# prepare_ultrachat.py
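"""Prepare the UltraChat-200k SFT splits as litdata streaming datasets.

Each conversation is rendered with the model's chat template and split into a
prompt (everything up to the final assistant turn) and a reference response.
Over-length examples are filtered out, and the result is written as chunked
files via litdata's optimize() for streaming during training.

Usage (flags mirror the argparse definitions at the bottom of this file):

    python prepare_ultrachat.py --model_type phi-3 --output_dir data --num_proc 64
"""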
import argparse
import os
from functools import partial

from datasets import load_dataset
from litdata import optimize
from transformers import AutoTokenizer

MODELS = {
    "phi-3": "microsoft/Phi-3-mini-4k-instruct",
    "llama-2": "meta-llama/Llama-2-7b-chat-hf",
    "stablelm": "stabilityai/stablelm-zephyr-3b",
}
# Token budget: prompts may use up to MAX_LENGTH - MAX_OUTPUT_LENGTH (1536)
# tokens, responses up to MAX_OUTPUT_LENGTH (512) tokens.
MAX_LENGTH = 2048
MAX_OUTPUT_LENGTH = 512


def tokenize(example, tokenizer, generation_prompt):
    # UltraChat SFT splits store conversations under "messages"; the
    # preference splits use "chosen".
    column = "messages" if "messages" in example else "chosen"
    text = tokenizer.apply_chat_template(
        example[column], tokenize=False, add_generation_prompt=False
    )
    # Split off the last assistant turn: everything up to (and including) the
    # final generation prompt is the model input, and the remainder is the
    # reference response used at generation time.
    messages = text.split(generation_prompt)
    input_text = generation_prompt.join(messages[:-1]) + generation_prompt
    output_text = messages[-1]
    # Full conversation ids for teacher-forced training ...
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    res = {"model_inputs": {"input_ids": input_ids, "labels": input_ids.clone()}}
    # ... and prompt-only ids for generation.
    gen_input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    res["model_inputs_gen"] = {"input_ids": gen_input_ids}
    res["response"] = output_text
    return res
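

# A rough illustration (not from the repo, values abbreviated): with the
# phi-3 template, a two-turn conversation renders to something like
#   "<|user|>\nHi<|end|>\n<|assistant|>\nHello!<|end|>\n<|user|>\nBye<|end|>\n<|assistant|>\nBye!<|end|>\n"
# so splitting on generation_prompt = "<|assistant|>\n" makes messages[-1]
# the final assistant reply, and rejoining messages[:-1] plus the prompt
# marker reconstructs the input used for generation.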


def filter_length(example, tokenizer, max_input_len, max_output_len):
    # Drop examples whose full conversation, prompt, or response exceeds
    # the configured token budgets.
    max_length = max_input_len + max_output_len
    if example["model_inputs"]["input_ids"].size(1) > max_length:
        return False
    if example["model_inputs_gen"]["input_ids"].size(1) > max_input_len:
        return False
    output_tokens = tokenizer(example["response"], return_tensors="pt").input_ids
    if output_tokens.size(1) > max_output_len:
        return False
    return True


def fn(index, data):
    # Adapter for litdata.optimize: given a sample index, yield the
    # corresponding preprocessed example so it is serialized into the
    # output chunks.
    yield data[index]


def prepare_train(args, tokenizer, generation_prompt):
    dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
    column_names = list(dataset.features)
    dataset = dataset.map(
        tokenize,
        fn_kwargs={"tokenizer": tokenizer, "generation_prompt": generation_prompt},
        num_proc=args.num_proc,
        desc="Applying chat template",
        remove_columns=column_names,
    )
    # Return torch tensors instead of lists so filter_length can call .size().
    dataset = dataset.with_format("torch")
    dataset = dataset.filter(
        filter_length,
        fn_kwargs={
            "tokenizer": tokenizer,
            "max_input_len": MAX_LENGTH - MAX_OUTPUT_LENGTH,
            "max_output_len": MAX_OUTPUT_LENGTH,
        },
        num_proc=args.num_proc,
    )
    os.makedirs(args.output_dir, exist_ok=True)
    # Serialize the processed samples into litdata streaming chunks.
    optimize(
        fn=partial(fn, data=dataset),
        inputs=list(range(len(dataset))),
        output_dir=os.path.join(args.output_dir, args.model_type, "train"),
        num_workers=16,
        chunk_bytes="500MB",
    )


def prepare_test(args, tokenizer, generation_prompt):
    dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="test_sft")
    column_names = list(dataset.features)
    dataset = dataset.map(
        tokenize,
        fn_kwargs={"tokenizer": tokenizer, "generation_prompt": generation_prompt},
        num_proc=args.num_proc,
        desc="Applying chat template",
        remove_columns=column_names,
    )
    dataset = dataset.with_format("torch")
    dataset = dataset.filter(
        filter_length,
        fn_kwargs={
            "tokenizer": tokenizer,
            "max_input_len": MAX_LENGTH - MAX_OUTPUT_LENGTH,
            "max_output_len": MAX_OUTPUT_LENGTH,
        },
        num_proc=args.num_proc,
    )
    # Keep a fixed, shuffled subset of 2,000 examples for evaluation.
    ds = dataset.train_test_split(test_size=2000, seed=42, shuffle=True)
    dataset = ds["test"]
    os.makedirs(args.output_dir, exist_ok=True)
    optimize(
        fn=partial(fn, data=dataset),
        inputs=list(range(len(dataset))),
        output_dir=os.path.join(args.output_dir, args.model_type, "test"),
        num_workers=2,
        chunk_bytes="500MB",
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        type=str,
        choices=list(MODELS.keys()),
        default="phi-3",
        help="Teacher type",
    )
    parser.add_argument("--output_dir", type=str, default="data")
    parser.add_argument(
        "--num_proc", type=int, default=64, help="Number of workers for processing"
    )
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(MODELS[args.model_type])
    if args.model_type == "phi-3":
        # https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/blob/main/sample_finetune.py#L141
        # Pad with the unk token rather than the eos token to prevent
        # endless generation.
        tokenizer.pad_token = tokenizer.unk_token
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
        tokenizer.padding_side = "right"

    # Marker that precedes each assistant turn in the model's chat template;
    # tokenize() splits on it to separate prompts from responses.
    if args.model_type in ["phi-3", "stablelm"]:
        generation_prompt = "<|assistant|>\n"
    elif args.model_type in ["llama-2"]:
        generation_prompt = " [/INST] "
    else:
        raise NotImplementedError(args.model_type)

    prepare_train(args, tokenizer, generation_prompt)
    prepare_test(args, tokenizer, generation_prompt)
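
# A minimal loading sketch (an assumption, not part of this script): the
# chunks written above can be read back with litdata's StreamingDataset.
# With the default flags, the train split lands in "data/phi-3/train":
#
#   from litdata import StreamingDataset
#   ds = StreamingDataset(input_dir="data/phi-3/train")
#   sample = ds[0]  # {"model_inputs": ..., "model_inputs_gen": ..., "response": ...}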