Skip to content

Commit 92b19c1

Browse files
committed
update t5 train model.
1 parent 188c41a commit 92b19c1

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

pycorrector/t5/train.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,12 @@ def load(self):
149149
def train():
150150
args = parse_args()
151151
args_dict = {
152-
"model_name_or_path": 'google/byt5-small',
152+
"model_name_or_path": 'Langboat/mengzi-t5-base',
153153
"max_len": 128,
154-
"output_dir": os.path.join(args.save_dir, './byt5-small-chinese-correction'),
154+
"output_dir": os.path.join(args.save_dir, './mengzi-t5-base-chinese-correction'),
155155
"overwrite_output_dir": True,
156-
"per_device_train_batch_size": 32,
157-
"per_device_eval_batch_size": 32,
156+
"per_device_train_batch_size": 64,
157+
"per_device_eval_batch_size": 64,
158158
"gradient_accumulation_steps": 4,
159159
"learning_rate": 5e-4,
160160
"warmup_steps": 250,
@@ -166,7 +166,7 @@ def train():
166166
"do_eval": True,
167167
"fp16": False,
168168
"use_cache": False,
169-
"max_steps": 5000, # default 5000
169+
# "max_steps": 5000,
170170
}
171171
parser = HfArgumentParser(
172172
(ModelArguments, DataTrainingArguments, TrainingArguments))

0 commit comments

Comments
 (0)