forked from PaddlePaddle/PaddleNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
228 lines (188 loc) · 9.77 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import argparse
import os
import sys
import random
import time
from scipy import stats
import numpy as np
import paddle
import paddle.nn.functional as F
import paddlenlp as ppnlp
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import LinearDecayWithWarmup
from model import SimCSE
from data import read_simcse_text, read_text_pair, convert_example, create_dataloader
from data import word_repetition
# yapf: disable
# Command-line interface of the training script. Parsing happens at module
# level, so the resulting `args` namespace is a global read by do_train().
parser = argparse.ArgumentParser()
# --- output / input-shape / optimisation settings ---
parser.add_argument("--save_dir", default='./checkpoint', type=str, help="The output directory where the model checkpoints will be written.")
# NOTE: the two adjacent literals are concatenated, so the first one must end with a space.
parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization. "
    "Sequences longer than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.")
parser.add_argument("--output_emb_size", default=0, type=int, help="Output_embedding_size, 0 means use hidden_size as output embedding size.")
parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--epochs", default=1, type=int, help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion", default=0.0, type=float, help="Linear warmup proportion over the training process.")
parser.add_argument("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.")
parser.add_argument("--seed", type=int, default=1000, help="Random seed for initialization.")
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
# --- step-based scheduling of checkpoints, evaluation and early stop ---
parser.add_argument('--save_steps', type=int, default=10000, help="Step interval for saving checkpoint.")
parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override epochs.")
parser.add_argument('--eval_steps', type=int, default=10000, help="Step interval for evaluation.")
# --- data files (both are mandatory) ---
parser.add_argument("--train_set_file", type=str, required=True, help="The full path of train_set_file.")
parser.add_argument("--test_set_file", type=str, required=True, help="The full path of test_set_file.")
# --- model / augmentation hyper-parameters ---
parser.add_argument("--margin", default=0.0, type=float, help="Margin between pos_sample and neg_samples.")
parser.add_argument("--scale", default=20, type=int, help="Scale for pair-wise margin_rank_loss.")
parser.add_argument("--dropout", default=0.1, type=float, help="Dropout for pretrained model encoder.")
parser.add_argument("--dup_rate", default=0.32, type=float, help="duplicate rate for word repetition.")
parser.add_argument("--infer_with_fc_pooler", action='store_true', help="Whether to use the fc layer after the cls embedding at inference time.")
args = parser.parse_args()
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Args:
        seed: Value handed, in order, to Python's ``random`` module,
            NumPy's global generator and Paddle.
    """
    for seed_fn in (random.seed, np.random.seed, paddle.seed):
        seed_fn(seed)
def do_evaluate(model, tokenizer, data_loader, with_pooler=False):
    """Evaluate sentence-pair similarity with Spearman rank correlation.

    Args:
        model: Object providing ``eval()``, ``train()`` and
            ``get_pooled_embedding(input_ids, token_type_ids, with_pooler=...)``.
        tokenizer: Not used by this function; kept so existing callers
            keep working.
        data_loader: Iterable yielding 5-element batches
            ``(query_input_ids, query_token_type_ids, title_input_ids,
            title_token_type_ids, label)``.
        with_pooler: Forwarded unchanged to ``get_pooled_embedding``.

    Returns:
        A ``(spearman_corr, total_num)`` tuple: the Spearman correlation
        between the gold labels and the predicted similarities, and the
        number of evaluated examples.
    """
    model.eval()

    total_num = 0
    sims = []
    labels = []

    try:
        # Evaluation never back-propagates, so disable autograd bookkeeping
        # to avoid building (and holding memory for) an unused graph.
        with paddle.no_grad():
            for batch in data_loader:
                (query_input_ids, query_token_type_ids, title_input_ids,
                 title_token_type_ids, label) = batch
                total_num += len(label)

                query_cls_embedding = model.get_pooled_embedding(
                    query_input_ids,
                    query_token_type_ids,
                    with_pooler=with_pooler)
                title_cls_embedding = model.get_pooled_embedding(
                    title_input_ids,
                    title_token_type_ids,
                    with_pooler=with_pooler)

                # Row-wise dot product of the two embeddings.
                # NOTE(review): this equals cosine similarity only if the
                # model returns L2-normalised embeddings - confirm in model.py.
                cosine_sim = paddle.sum(
                    query_cls_embedding * title_cls_embedding, axis=-1)

                sims.append(cosine_sim.numpy())
                labels.append(label.numpy())

        sims = np.concatenate(sims, axis=0)
        labels = np.concatenate(labels, axis=0)
        spearman_corr = stats.spearmanr(labels, sims).correlation
    finally:
        # Always hand the model back in training mode, even if scoring failed.
        model.train()

    return spearman_corr, total_num
def do_train():
    """Train a SimCSE model on top of the pretrained ``ernie-1.0`` encoder.

    All settings come from the module-level ``args`` namespace. The function
    builds the train/dev data loaders, wraps the model in
    ``paddle.DataParallel``, and runs the optimisation loop. On rank 0 only it
    logs every 10 steps, evaluates every ``args.eval_steps`` steps and writes
    a checkpoint (weights + tokenizer) every ``args.save_steps`` steps under
    ``args.save_dir/model_<global_step>``.

    Returns:
        None. Returns early once ``global_step`` reaches ``args.max_steps``
        when that option is positive.
    """
    paddle.set_device(args.device)
    rank = paddle.distributed.get_rank()
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args.seed)

    train_ds = load_dataset(
        read_simcse_text, data_path=args.train_set_file, lazy=False)
    dev_ds = load_dataset(
        read_text_pair, data_path=args.test_set_file, lazy=False)

    pretrained_model = ppnlp.transformers.ErnieModel.from_pretrained(
        'ernie-1.0',
        hidden_dropout_prob=args.dropout,
        attention_probs_dropout_prob=args.dropout)
    tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')

    trans_func = partial(
        convert_example,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length)

    # Training batches carry four padded fields and no label.
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # query_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # query_segment
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # title_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # title_segment
    ): [data for data in fn(samples)]

    # Dev batches carry the same four fields plus an int64 label.
    dev_batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # query_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # query_segment
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # title_input
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # title_segment
        Stack(dtype="int64"),  # labels
    ): [data for data in fn(samples)]

    train_data_loader = create_dataloader(
        train_ds,
        mode='train',
        batch_size=args.batch_size,
        batchify_fn=batchify_fn,
        trans_fn=trans_func)
    dev_data_loader = create_dataloader(
        dev_ds,
        mode='eval',
        batch_size=args.batch_size,
        batchify_fn=dev_batchify_fn,
        trans_fn=trans_func)

    model = SimCSE(
        pretrained_model,
        margin=args.margin,
        scale=args.scale,
        output_emb_size=args.output_emb_size)

    if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
        state_dict = paddle.load(args.init_from_ckpt)
        model.set_dict(state_dict)
        print("warmup from:{}".format(args.init_from_ckpt))
    elif args.init_from_ckpt:
        # A mistyped path used to be ignored silently; say so and carry on
        # with the pretrained weights, as before.
        print("init_from_ckpt is not a file, ignored:{}".format(
            args.init_from_ckpt))

    model = paddle.DataParallel(model)

    num_training_steps = args.max_steps if args.max_steps > 0 else len(
        train_data_loader) * args.epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
                                         args.warmup_proportion)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    # A set keeps the per-parameter membership test below O(1).
    decay_params = {
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    }

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    global_step = 0
    tic_train = time.time()
    for epoch in range(1, args.epochs + 1):
        for step, batch in enumerate(train_data_loader, start=1):
            (query_input_ids, query_token_type_ids, title_input_ids,
             title_token_type_ids) = batch

            # Optional word-repetition augmentation applied to both sides.
            if args.dup_rate > 0:
                query_input_ids, query_token_type_ids = word_repetition(
                    query_input_ids, query_token_type_ids, args.dup_rate)
                title_input_ids, title_token_type_ids = word_repetition(
                    title_input_ids, title_token_type_ids, args.dup_rate)

            loss = model(
                query_input_ids=query_input_ids,
                title_input_ids=title_input_ids,
                query_token_type_ids=query_token_type_ids,
                title_token_type_ids=title_token_type_ids)

            global_step += 1
            if global_step % 10 == 0 and rank == 0:
                # Speed is averaged over the 10 steps since the last log line.
                print(
                    "global step %d, epoch: %d, batch: %d, loss: %.5f, speed: %.2f step/s"
                    % (global_step, epoch, step, loss,
                       10 / (time.time() - tic_train)))
                tic_train = time.time()

            if global_step % args.eval_steps == 0 and rank == 0:
                # Runs before this step's backward/optimizer update, so it
                # scores the parameters as they were after the previous step.
                # need better way to get model Layers
                spearman_corr, total_num = do_evaluate(
                    model._layers, tokenizer, dev_data_loader,
                    args.infer_with_fc_pooler)
                print("global step: {}, spearman_corr: {:.4f}, total_num: {}".
                      format(global_step, spearman_corr, total_num))

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

            if global_step % args.save_steps == 0 and rank == 0:
                save_dir = os.path.join(args.save_dir,
                                        "model_%d" % global_step)
                # exist_ok avoids the check-then-create race of the old code.
                os.makedirs(save_dir, exist_ok=True)
                save_param_path = os.path.join(save_dir,
                                               'model_state.pdparams')
                paddle.save(model.state_dict(), save_param_path)
                tokenizer.save_pretrained(save_dir)

            if args.max_steps > 0 and global_step >= args.max_steps:
                return
# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    do_train()