-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathppo.py
More file actions
84 lines (70 loc) · 2.87 KB
/
Copy pathppo.py
File metadata and controls
84 lines (70 loc) · 2.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from accelerate import PartialState
from transformers import AutoTokenizer, AutoModelForCausalLM, HfArgumentParser, BitsAndBytesConfig, GPTQConfig
from trl import PPOTrainer, ScriptArguments, PPOConfig, ModelConfig, get_peft_config
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
import shutil
import torch
import gc
from fine_tuning.ult_ttt_dataset import get_ttt_dataset
from fine_tuning.ult_ttt_reward import TTTReward
from fine_tuning.c_dataset import get_c_dataset
from fine_tuning.c_reward import CReward
from fine_tuning.cc_dataset import get_cc_dataset
from fine_tuning.cc_reward import CCReward
from scripts.args import GeneralArguments
def prepare_dataset(dataset, tokenizer, *, text_field="query"):
    """Tokenize the raw-text column of ``dataset`` for PPO training.

    Args:
        dataset: Object with a ``datasets``-style ``map`` method (presumably a
            ``datasets.Dataset`` from one of the ``get_*_dataset`` helpers —
            confirm). It must contain a ``text_field`` column of strings.
        tokenizer: Tokenizer callable. It is invoked on each batch of texts with
            ``padding=True`` and ``truncation=True``.
        text_field: Name of the column holding the raw text. Defaults to
            ``"query"``. The column is removed from the mapped result, so only
            the tokenizer outputs (plus any other existing columns) remain.

    Returns:
        The result of ``dataset.map``.
    """

    def tokenize_function(examples):
        # With batched=True, examples[text_field] is a list of strings.
        # NOTE(review): padding=True pads to the longest text within each *map*
        # batch, so stored sequences from different map batches can differ in
        # length; confirm downstream collation re-pads as expected.
        return tokenizer(examples[text_field], padding=True, truncation=True)

    return dataset.map(tokenize_function, batched=True, remove_columns=[text_field])
def _load_game(game, tokenizer):
    """Return ``(dataset, reward_model, value_model)`` for the named ``game``.

    Raises:
        ValueError: If ``game`` is not one of the supported game names.
    """
    # Each supported game maps to its dataset loader and its reward class. The
    # reward class is instantiated twice so the reward and value models are
    # distinct objects.
    registry = {
        "ult-ttt": (get_ttt_dataset, TTTReward),
        "connect-4": (get_c_dataset, CReward),
        "xiangqi": (get_cc_dataset, CCReward),
    }
    try:
        load_dataset, reward_cls = registry[game]
    except KeyError:
        raise ValueError(f"Unsupported game: {game}") from None
    return load_dataset(), reward_cls(tokenizer), reward_cls(tokenizer)


def main():
    """Parse CLI arguments, build the policy and game-specific models, then run PPO.

    Command-line arguments are parsed into ``ScriptArguments`` (unused),
    ``PPOConfig``, ``ModelConfig`` and the project's ``GeneralArguments``. The
    ``game`` field of ``GeneralArguments`` selects the dataset and the
    reward/value models. After training, the model is saved to
    ``training_args.output_dir`` and completions are generated.

    Raises:
        ValueError: If ``general_args.game`` is not a supported game.
    """
    # Release cached GPU memory and collect garbage before loading large models.
    torch.cuda.empty_cache()
    gc.collect()

    parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig, GeneralArguments))
    _, training_args, model_args, general_args = parser.parse_args_into_dataclasses()

    # Start from a clean output directory: any previous contents are deleted.
    shutil.rmtree(training_args.output_dir, ignore_errors=True)

    # The tokenizer must honour trust_remote_code just like the model does;
    # otherwise checkpoints that ship a custom tokenizer fail to load here.
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        padding_side="left",
        trust_remote_code=model_args.trust_remote_code,
    )
    if tokenizer.pad_token is None:
        # NOTE(review): padding then shares the EOS token id; confirm the
        # trainer's padding/EOS handling is acceptable for these prompts.
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

    state = PartialState()
    # device_map="auto" shards the model across all visible GPUs, which
    # Accelerate rejects for training in any distributed mode. Under a
    # multi-process launch, pin each process's copy to its own local device;
    # a single-process run keeps the original "auto" placement.
    if state.num_processes == 1:
        device_map = "auto"
    else:
        device_map = {"": state.local_process_index}

    model_kwargs = dict(
        device_map=device_map,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)

    dataset, reward_model, value_model = _load_game(general_args.game, tokenizer)

    # The local main process tokenizes first; presumably the other processes
    # then reuse its cached map result.
    with state.local_main_process_first():
        dataset = prepare_dataset(dataset, tokenizer)

    peft_config = get_peft_config(model_args)
    trainer = PPOTrainer(
        args=training_args,
        processing_class=tokenizer,
        model=model,
        reward_model=reward_model,
        value_model=value_model,
        ref_model=None,
        train_dataset=dataset,
        # Evaluation reuses the training prompts.
        eval_dataset=dataset,
        peft_config=peft_config,
    )
    trainer.train()
    trainer.save_model(training_args.output_dir)
    trainer.generate_completions()


if __name__ == "__main__":
    main()