Commit

Fixes
ikergarcia1996 committed Sep 11, 2024
1 parent 041ef41 commit 4e4fd86
Showing 6 changed files with 226 additions and 13 deletions.
6 changes: 5 additions & 1 deletion .gitignore
@@ -125,4 +125,8 @@ dmypy.json

*.ipynb

models/
models/

.slurm/

wandb/
32 changes: 32 additions & 0 deletions scripts/finetune.sh
@@ -0,0 +1,32 @@
#!/bin/bash
#SBATCH --job-name=Odesia-finetune
#SBATCH --cpus-per-task=16
#SBATCH --nodes=1
#SBATCH --gres=gpu:2
#SBATCH --mem=64G
#SBATCH --output=.slurm/Odesia-finetune.out.txt
#SBATCH --error=.slurm/Odesia-finetune.err.txt


source /ikerlariak/igarcia945/envs/pytorch2/bin/activate


export LC_ALL=en_US.UTF-8
export LANG=en_US.UTF-8
export LANGUAGE=en_US.UTF-8
export TOKENIZERS_PARALLELISM=true
export TRANSFORMERS_NO_ADVISORY_WARNINGS=true
export WANDB_ENTITY=igarciaf
export WANDB_PROJECT=Odesia
export OMP_NUM_THREADS=16
export WANDB__SERVICE_WAIT=300

echo CUDA_VISIBLE_DEVICES "${CUDA_VISIBLE_DEVICES}"


export PYTHONPATH="$PYTHONPATH:$PWD"
torchrun --standalone --master_port 37227 --nproc_per_node=2 src/train.py train_configs/gemma2B.yaml
torchrun --standalone --master_port 37227 --nproc_per_node=2 src/train.py train_configs/llama8b.yaml

torchrun --standalone --master_port 37227 --nproc_per_node=2 src/evaluate.py --tasks all --model_name models/gemma-2b-it --output_dir results/finetune/gemma-2b-it
torchrun --standalone --master_port 37227 --nproc_per_node=2 src/evaluate.py --tasks all --model_name models/Llama-3.1-8B-Instruct --output_dir results/finetune/Llama-3.1-8B-Instruct
20 changes: 9 additions & 11 deletions src/train.py
@@ -6,10 +6,10 @@
from datasets import load_dataset
from tqdm import tqdm
from transformers import HfArgumentParser
from trl import SFTConfig, SFTTrainer, setup_chat_format
from trl import SFTConfig, SFTTrainer

from src.config.config import ModelArguments
from src.model.load_model import load_model, merge_lora_model
from src.model.load_model import load_model
from src.tasks import get_tasks


@@ -34,13 +34,13 @@ def train(training_args: SFTConfig, model_args: ModelArguments):
        rope_scaling_factor=model_args.rope_scaling_factor,
    )

    tasks = get_tasks(tokenizer=tokenizer, tasks="all")
    tasks = get_tasks(tokenizer=tokenizer, tasks=["exist_2022_t1_es"])

    train_dataset = []
    validation_dataset = []
    for task_name, task in tqdm(tasks.items(), desc="Loading datasets"):
        print(f"Loading dataset for task {task_name}")
        train = task.get_dataset_training(split="train")
        train = task.get_dataset_training(split="train")[:100]
        train_dataset.extend(train)
        print(f"Train dataset size: {len(train)}")
        dev = task.get_dataset_training(split="dev")
@@ -60,7 +60,6 @@ def train(training_args: SFTConfig, model_args: ModelArguments):
        for example in validation_dataset:
            print(json.dumps(example, ensure_ascii=False), file=f)

    model, tokenizer = setup_chat_format(model, tokenizer)
    dataset = load_dataset(
        "json",
        data_files={
@@ -71,20 +70,19 @@ def train(training_args: SFTConfig, model_args: ModelArguments):
        },
    )

    print(dataset)

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=dataset,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
    )

    trainer.train()

    merge_lora_model(
        weights_path=model_args.model_name_or_path,
        lora_weights_name_or_path=training_args.output_dir,
        output_path=training_args.output_dir,
    )
    trainer.save_model()


if __name__ == "__main__":
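
The train.py change drops the explicit merge_lora_model(...) step after training and relies on trainer.save_model(), which for a LoRA run typically writes only the adapter weights to output_dir. If a standalone merged checkpoint is needed later, it can be produced with peft roughly as follows (a minimal sketch, not part of this commit; the base model name and paths below are illustrative):

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative paths: adjust to the base model and adapter directory actually used.
base_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
adapter_dir = "models/Llama-3.1-8B-Instruct"

base = AutoModelForCausalLM.from_pretrained(base_name, torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, adapter_dir)

# Fold the LoRA deltas into the base weights and save a plain Hugging Face checkpoint.
merged = model.merge_and_unload()
merged.save_pretrained(adapter_dir + "-merged")
AutoTokenizer.from_pretrained(base_name).save_pretrained(adapter_dir + "-merged")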
117 changes: 117 additions & 0 deletions src/train2.py
@@ -0,0 +1,117 @@
import logging
import os
import sys

from datasets import load_dataset
from transformers import AutoTokenizer, HfArgumentParser
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel  # Also requires: pip install xformers

from src.config.config import ModelArguments
from src.model.model_utils import find_all_linear_names


def train(training_args: SFTConfig, model_args: ModelArguments):
    os.makedirs(training_args.output_dir, exist_ok=True)

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_args.model_name_or_path,
        dtype=None,  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
        load_in_4bit=True,  # Use 4bit quantization to reduce memory usage. Can be False
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path
    )  # FastLanguageModel doesn't load the chat template

    model = FastLanguageModel.get_peft_model(
        model,
        r=model_args.lora_r,
        target_modules=find_all_linear_names(model),
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,  # Dropout = 0 is currently optimized
        bias="none",  # Bias = "none" is currently optimized
        use_gradient_checkpointing=True,
        random_state=3407,
    )

    """
    tasks = get_tasks(tokenizer=tokenizer, tasks="all")
    train_dataset = []
    validation_dataset = []
    for task_name, task in tqdm(tasks.items(), desc="Loading datasets"):
        print(f"Loading dataset for task {task_name}")
        train = task.get_dataset_training(split="train")
        train_dataset.extend(train)
        print(f"Train dataset size: {len(train)}")
        dev = task.get_dataset_training(split="dev")
        validation_dataset.extend(dev)
        print(f"Validation dataset size: {len(dev)}")
    with open(os.path.join(training_args.output_dir, "train_dataset.jsonl"), "w") as f:
        for example in train_dataset:
            print(json.dumps(example, ensure_ascii=False), file=f)
    print(f"Full training dataset size: {len(train_dataset)}")
    print(f"Full validation dataset size: {len(validation_dataset)}")
    with open(
        os.path.join(training_args.output_dir, "validation_dataset.jsonl"), "w"
    ) as f:
        for example in validation_dataset:
            print(json.dumps(example, ensure_ascii=False), file=f)
    """
    print("Loading datasets")
    dataset = load_dataset(
        "json",
        data_files={
            "train": os.path.join(training_args.output_dir, "train_dataset.jsonl"),
            "validation": os.path.join(
                training_args.output_dir, "validation_dataset.jsonl"
            ),
        },
    )

    print(dataset)

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
    )

    trainer.train()

    trainer.save_model()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    parser = HfArgumentParser((ModelArguments, SFTConfig))
    logging.info(f"Sys args {sys.argv}")

    if len(sys.argv) > 0 and sys.argv[-1].endswith(".json"):
        # If we pass only one argument to the script, and it's the path to a json file,
        # let's parse it to get our arguments.
        logging.info(f"Loading json config {sys.argv[-1]}")
        model_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[-1])
        )

    elif len(sys.argv) > 0 and sys.argv[-1].endswith(".yaml"):
        # If we pass only one argument to the script, and it's the path to a yaml file,
        # let's parse it to get our arguments.
        logging.info(f"Loading yaml config {sys.argv[-1]}")
        model_args, training_args = parser.parse_yaml_file(
            yaml_file=os.path.abspath(sys.argv[-1])
        )
    else:
        logging.info("No config file passed, using command line arguments.")
        model_args, training_args = parser.parse_args_into_dataclasses()

    train(training_args, model_args)
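
train2.py attaches LoRA adapters to every linear projection via find_all_linear_names(model), a helper imported from src.model.model_utils that is not shown in this diff. A hypothetical sketch of what such a helper usually does (the repository's actual implementation may differ, for example by also handling 4-bit bitsandbytes linear layers):

from torch import nn

def find_all_linear_names_sketch(model: nn.Module) -> list[str]:
    # Collect the final attribute name of every nn.Linear module so LoRA
    # adapters can target all of them (q_proj, k_proj, v_proj, o_proj, ...).
    names = set()
    for full_name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            names.add(full_name.split(".")[-1])
    # The output head is usually excluded from LoRA targeting.
    names.discard("lm_head")
    return sorted(names)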
2 changes: 1 addition & 1 deletion train_configs/gemma2B.yaml
@@ -6,7 +6,7 @@ quantization: null
gradient_checkpointing: true
force_auto_device_map: false
use_flash_attention: true
packing: true
packing: false

output_dir: models/gemma-2b-it
overwrite_output_dir: true
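
The only functional change here is packing: true → packing: false. In TRL, packing concatenates tokenized examples into fixed-length blocks instead of training on one padded example per sequence. A rough illustration of the idea, not TRL's actual implementation, with made-up token ids:

def pack_examples(tokenized_examples, block_size=4, eos_token_id=2):
    # Concatenate all examples into one token stream, separated by EOS,
    # then re-chunk the stream into fixed-length blocks (no padding needed).
    stream = []
    for ids in tokenized_examples:
        stream.extend(ids + [eos_token_id])
    return [stream[i:i + block_size]
            for i in range(0, len(stream) - block_size + 1, block_size)]

print(pack_examples([[5, 6, 7], [8, 9], [10, 11, 12, 13]]))
# [[5, 6, 7, 2], [8, 9, 2, 10], [11, 12, 13, 2]]

With packing disabled, each chat-formatted example is instead padded or truncated to the maximum sequence length on its own.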
62 changes: 62 additions & 0 deletions train_configs/llama8b.yaml
@@ -0,0 +1,62 @@
#Training args
model_name_or_path: meta-llama/Meta-Llama-3.1-8B-Instruct
torch_dtype: bfloat16
use_lora: true
quantization: null
gradient_checkpointing: true
force_auto_device_map: false
use_flash_attention: true
packing: false

output_dir: models/Llama-3.1-8B-Instruct
overwrite_output_dir: true
load_best_model_at_end: false
metric_for_best_model:
greater_is_better: true
save_strategy: "no"
save_only_model: true
save_total_limit: 1

# evaluation
do_train: true
do_eval: true
do_predict: false
evaluation_strategy: "epoch"

per_device_train_batch_size: 8
per_device_eval_batch_size: 2
gradient_accumulation_steps: 8

# optimizer settings
optim: adamw_torch_fused
learning_rate: 0.0003
weight_decay: 0.001
num_train_epochs: 3
lr_scheduler_type: cosine
warmup_ratio: 0.1
adam_epsilon: 0.0000001

# lora settings
lora_r: 128
lora_alpha: 256
lora_dropout: 0.05
lora_target_modules:
- all

# reporting
logging_strategy: steps
logging_first_step: true
logging_steps: 5
report_to: wandb
run_name: "Llama-3.1-8B-Instruct"
disable_tqdm: false

# hub settings
push_to_hub: false
resume_from_checkpoint: false

# performance
bf16: true
fp16: false
torch_compile: false
ddp_find_unused_parameters: false
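
For reference, with the two-GPU launch in scripts/finetune.sh this config trains on an effective batch of 128 examples per optimizer step (a quick back-of-the-envelope check, assuming plain data parallelism across both GPUs):

per_device_train_batch_size = 8   # from llama8b.yaml
gradient_accumulation_steps = 8   # from llama8b.yaml
num_gpus = 2                      # torchrun --nproc_per_node=2 in scripts/finetune.sh
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch_size)  # 128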
