106 changes: 106 additions & 0 deletions functionary/train/README.md
@@ -9,6 +9,9 @@ pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https
# Install Dependencies
pip install -e .

# Install flash-attention
pip install flash-attn==2.7.4.post1 --no-build-isolation

# Install Liger if using liger:
pip install -e .[liger]
```
@@ -209,3 +212,106 @@ After finishing training, you can merge the LoRA weights with the pretrained weights
```shell
python -m functionary.train.merge_lora_weight save_folder pretrained_path checkpoint model_max_length prompt_template_version
```



## DPO Training
For DPO training, first install TRL: `pip install trl==0.17.0`.
Here is an example command (single GPU via `torchrun`):
```shell
export WANDB_PROJECT=functionary
export WANDB_LOG_MODEL=all
torchrun --nproc_per_node=1 functionary/train/train_dpo.py \
--model_name_or_path Qwen/Qwen3-4B \
--train_data_path gen_train.jsonl \
--eval_data_path gen_dev.jsonl \
--bf16 True \
--output_dir test_output_dir \
--num_train_epochs 1 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--eval_accumulation_steps 1 \
--eval_strategy steps \
--eval_steps 100 \
--save_strategy no \
--logging_steps 5 \
--learning_rate 1e-5 \
--weight_decay 0. \
--warmup_steps 35 \
--lr_scheduler_type cosine_with_min_lr \
--lr_scheduler_kwargs "{\"min_lr_rate\": 0.1}" \
--tf32 True \
--gradient_checkpointing True \
--optim paged_adamw_8bit \
--max_length 32768 \
--use_liger True \
--prompt_template_version qwen2.5-text-only
```
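
The `gen_train.jsonl` / `gen_dev.jsonl` files are read by `functionary/train/train_dpo.py`: each line is a JSON object with `messages`, optional `tools`, and the preferred and rejected assistant messages under `selected_answer` and `rejected_answer`. Below is a minimal sketch of writing one such record; the message contents are made up for illustration:

```python
import json

# One preference record in the format expected by train_dpo.py (illustrative values).
record = {
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
    "tools": [],  # optional: tool/function definitions used by this sample
    "selected_answer": {"role": "assistant", "content": "The capital of France is Paris."},
    "rejected_answer": {"role": "assistant", "content": "I am not sure."},
}

with open("gen_train.jsonl", "a") as f:
    f.write(json.dumps(record) + "\n")
```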

Using DeepSpeed with LoRA (`--use_peft`):
```shell
export WANDB_PROJECT=functionary
export WANDB_LOG_MODEL=all
deepspeed functionary/train/train_dpo.py \
--model_name_or_path Qwen/Qwen3-4B \
--train_data_path gen_train.jsonl \
--eval_data_path gen_dev.jsonl \
--bf16 True \
--output_dir test_output_dir \
--num_train_epochs 1 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--eval_accumulation_steps 1 \
--eval_strategy steps \
--eval_steps 100 \
--save_strategy no \
--logging_steps 5 \
--learning_rate 1e-5 \
--weight_decay 0. \
--warmup_steps 35 \
--lr_scheduler_type cosine_with_min_lr \
--lr_scheduler_kwargs "{\"min_lr_rate\": 0.1}" \
--tf32 True \
--gradient_checkpointing True \
--optim paged_adamw_8bit \
--deepspeed functionary/train/ds_config/zero3_wo_offload.json \
--max_length 32768 \
--use_liger True \
--prompt_template_version qwen2.5-text-only \
--use_peft --lora_r 128 --lora_alpha 256 --lora_target_modules all-linear
```
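
The `--use_peft --lora_r 128 --lora_alpha 256 --lora_target_modules all-linear` flags are converted into a PEFT config by TRL's `get_peft_config`. Roughly, that corresponds to a `peft.LoraConfig` like the sketch below (a simplified approximation, not TRL's exact internals):

```python
from peft import LoraConfig

# Approximate LoRA setup implied by the CLI flags above (sketch only).
peft_config = LoraConfig(
    r=128,                        # --lora_r
    lora_alpha=256,               # --lora_alpha
    target_modules="all-linear",  # --lora_target_modules
    task_type="CAUSAL_LM",
)
```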

Using DeepSpeed for full fine-tuning (no LoRA):
```shell
export WANDB_PROJECT=functionary
export WANDB_LOG_MODEL=all
deepspeed functionary/train/train_dpo.py \
--model_name_or_path Qwen/Qwen3-4B \
--train_data_path gen_train.jsonl \
--eval_data_path gen_dev.jsonl \
--bf16 True \
--output_dir test_output_dir \
--num_train_epochs 1 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--eval_accumulation_steps 1 \
--eval_strategy steps \
--eval_steps 100 \
--save_strategy no \
--logging_steps 5 \
--learning_rate 1e-5 \
--weight_decay 0. \
--warmup_steps 35 \
--lr_scheduler_type cosine_with_min_lr \
--lr_scheduler_kwargs "{\"min_lr_rate\": 0.1}" \
--tf32 True \
--gradient_checkpointing True \
--optim paged_adamw_8bit \
--deepspeed functionary/train/ds_config/zero3_wo_offload.json \
--max_length 32768 \
--use_liger True \
--prompt_template_version qwen2.5-text-only
```
239 changes: 239 additions & 0 deletions functionary/train/train_dpo.py
@@ -0,0 +1,239 @@
import sys
import os

sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))

import json
from dataclasses import dataclass, field
from typing import Optional

import torch
import transformers

# from torch.utils.data import Dataset
from datasets import Dataset
from transformers import AutoTokenizer
from transformers.modeling_utils import is_deepspeed_zero3_enabled
from trl import DPOConfig, DPOTrainer, ModelConfig
from trl import get_kbit_device_map, get_peft_config, get_quantization_config

from functionary.prompt_template import get_prompt_template_by_version

LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))


def print_rank0(*arg):
if LOCAL_RANK == 0:
print(*arg)


# class FunctionaryDPODataset(Dataset):
# def __init__(self, data_path: str, prompt_template_version: str):
# with open(data_path, "r") as f:
# self.data = [json.loads(line) for line in f]
# # assume that data with the fields: tools; messages; chosen; rejected
# # chosen and rejected are assistant message; we will convert them to string
# self.list_prompts = []
# self.list_chosen = []
# self.list_rejected = []
# self.prompt_template = get_prompt_template_by_version(prompt_template_version)

# for item in self.data:
# messages = item["messages"]
# tools = item.get("tools", []) or []
# chosen = item["chosen"]
# rejected = item["rejected"]
# input_prompt = self.prompt_template.get_prompt_from_messages(
# messages, tools_or_functions=tools, add_generation_prompt=True
# )
# # compute the output prompt for chosen
# full_prompt_chosen = self.prompt_template.get_prompt_from_messages(
# messages + [chosen], tools_or_functions=tools
# )
# chosen_output = full_prompt_chosen[len(input_prompt) :]

# full_prompt_rejected = self.prompt_template.get_prompt_from_messages(
# messages + [rejected], tools_or_functions=tools
# )
# rejected_output = full_prompt_rejected[len(input_prompt) :]

# self.list_prompts.append(input_prompt)
# self.list_chosen.append(chosen_output)
# self.list_rejected.append(rejected_output)

# def __len__(self):
# return len(self.list_prompts)

# def __getitem__(self, idx):
# return {
# "prompt": self.list_prompts[idx],
# "chosen": self.list_chosen[idx],
# "rejected": self.list_rejected[idx],
# }


def get_dataset_from_jsonl(data_path: str, prompt_template_version: str):
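    """Build a preference dataset (prompt/chosen/rejected strings) from a JSONL file.

    Each line must be a JSON object with "messages", optional "tools", and the
    preferred / rejected assistant messages under "selected_answer" and
    "rejected_answer". The prompt template renders them into the prompt string
    and the two completion strings that DPOTrainer expects.
    """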
with open(data_path, "r") as f:
data = [json.loads(line) for line in f]
list_prompts = []
list_chosen = []
list_rejected = []
prompt_template = get_prompt_template_by_version(prompt_template_version)

for item in data:
messages = item["messages"]
tools = item.get("tools", []) or []
chosen = item["selected_answer"]
rejected = item["rejected_answer"]
input_prompt = prompt_template.get_prompt_from_messages(
messages, tools_or_functions=tools, add_generation_prompt=True
)
# compute the output prompt for chosen
full_prompt_chosen = prompt_template.get_prompt_from_messages(
messages + [chosen], tools_or_functions=tools
)
chosen_output = full_prompt_chosen[len(input_prompt) :]

full_prompt_rejected = prompt_template.get_prompt_from_messages(
messages + [rejected], tools_or_functions=tools
)
rejected_output = full_prompt_rejected[len(input_prompt) :]

list_prompts.append(input_prompt)
list_chosen.append(chosen_output)
list_rejected.append(rejected_output)

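    # Each resulting row is three plain strings, roughly:
    #   {"prompt": "<rendered conversation + assistant header>",
    #    "chosen": "<chosen completion>", "rejected": "<rejected completion>"}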
return Dataset.from_dict(
{"prompt": list_prompts, "chosen": list_chosen, "rejected": list_rejected}
)


@dataclass
class ModelArguments(ModelConfig):
model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-hf")


@dataclass
class TrainingArguments(DPOConfig):
use_liger: Optional[bool] = field(default=False)
prompt_template_version: str = field(
default="v2", metadata={"help": "choose prompt template to use for training"}
)


@dataclass
class DataArguments:
train_data_path: str = field(
default=None, metadata={"help": "Path to the training data."}
)
eval_data_path: str = field(
default=None, metadata={"help": "Path to the eval data."}
)


def trainer_save_model_safe(trainer: transformers.Trainer):
"""Saves the model in fsdp.FULL_STATE_DICT mode to have the model weights
in .bin file format which is loadable by HF Transformers"""
if trainer.accelerator.state.fsdp_plugin.state_dict_type.name != "FULL_STATE_DICT":
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model()


def main():
argument_parser = transformers.HfArgumentParser(
(DataArguments, TrainingArguments, ModelArguments)
)
data_args, training_args, model_args = argument_parser.parse_args_into_dataclasses()
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

if LOCAL_RANK == 0:
if not os.path.exists(training_args.output_dir):
os.mkdir(training_args.output_dir)

tokenizer.save_pretrained(training_args.output_dir)

train_ds = get_dataset_from_jsonl(
data_args.train_data_path, training_args.prompt_template_version
)
dev_ds = get_dataset_from_jsonl(
data_args.eval_data_path, training_args.prompt_template_version
)

print_rank0(f"train_ds: {len(train_ds)}")
print_rank0(f"dev_ds: {len(dev_ds)}")

quantization_config = get_quantization_config(model_args)
device_string = "cuda:" + str(LOCAL_RANK)
device_map = (
get_kbit_device_map()
if quantization_config is not None
else {"": device_string}
)
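    # FSDP and DeepSpeed ZeRO-3 manage parameter placement/sharding themselves,
    # so an explicit device_map would conflict with them.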
if len(training_args.fsdp) > 0 or is_deepspeed_zero3_enabled():
device_map = None

model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=device_map,
quantization_config=quantization_config,
)

if training_args.use_liger:
from liger_kernel.transformers import AutoLigerKernelForCausalLM

model_class = AutoLigerKernelForCausalLM
else:
model_class = transformers.AutoModelForCausalLM

model = model_class.from_pretrained(model_args.model_name_or_path, **model_kwargs)
model.resize_token_embeddings(len(tokenizer))
peft_config = get_peft_config(model_args)
ref_model = None
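    # Under ZeRO-3 the reference model cannot be derived by copying the policy,
    # so load a separate frozen copy; when PEFT/LoRA is used, the frozen base
    # weights already serve as the implicit reference.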
if is_deepspeed_zero3_enabled():
if peft_config is None:
ref_model = model_class.from_pretrained(
model_args.model_name_or_path, **model_kwargs
)

trainer = DPOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
train_dataset=train_ds,
eval_dataset=dev_ds,
processing_class=tokenizer,
peft_config=peft_config,
)
trainer.train()

# FSDP requires state_dict_type=FULL_STATE_DICT in order to save the model weights in .bin format
if trainer.is_fsdp_enabled:
trainer_save_model_safe(trainer=trainer)
else:
trainer.save_model()


if __name__ == "__main__":
main()