update t5 with mengzi-t5.
shibing624 committed Jun 16, 2022
1 parent 92b19c1 commit 7ad3b5d
Showing 6 changed files with 44 additions and 32 deletions.
2 changes: 1 addition & 1 deletion pycorrector/config.py
@@ -49,7 +49,7 @@
 macbert_model_dir = os.path.join(USER_DATA_DIR, 'macbert_models/chinese_finetuned_correction/')
 os.makedirs(macbert_model_dir, exist_ok=True)
 # t5 model file path
-t5_model_dir = os.path.join(USER_DATA_DIR, 't5_models/byt5-small-chinese-correction/')
+t5_model_dir = os.path.join(USER_DATA_DIR, 't5_models/mengzi-t5-base-chinese-correction/')
 os.makedirs(t5_model_dir, exist_ok=True)
 # convseq2seq model directory path
 convseq2seq_model_dir = os.path.join(USER_DATA_DIR, 'seq2seq_models/convseq2seq_correction/')
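With this change, a default-constructed corrector looks for local weights under the mengzi directory instead of the byt5 one. A rough sketch of where that resolves (USER_DATA_DIR's value here is an assumption about pycorrector's per-user cache, not copied from config.py):

```python
# Hedged sketch: where the new default model dir resolves. USER_DATA_DIR is
# an assumed value; check pycorrector's config.py for the actual definition.
import os

USER_DATA_DIR = os.path.expanduser('~/.pycorrector/datasets/')
t5_model_dir = os.path.join(USER_DATA_DIR, 't5_models/mengzi-t5-base-chinese-correction/')
os.makedirs(t5_model_dir, exist_ok=True)  # config.py creates it eagerly at import time
print(t5_model_dir)  # e.g. /home/<user>/.pycorrector/datasets/t5_models/mengzi-t5-base-chinese-correction/
```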
4 changes: 2 additions & 2 deletions pycorrector/t5/infer.py
@@ -13,7 +13,7 @@

 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--save_dir', type=str, default='output', help='save dir')
+    parser.add_argument('--save_dir', type=str, default='output/mengzi-t5-base-chinese-correction/', help='save dir')
     args = parser.parse_args()
     return args

@@ -26,7 +26,7 @@ def predict():
         "他带了黑色的包,也带了照像机",
     ]
     args = parse_args()
-    model_dir = os.path.join(args.save_dir, './byt5-small-chinese-correction')
+    model_dir = args.save_dir
     tokenizer = AutoTokenizer.from_pretrained(model_dir)
     model = T5ForConditionalGeneration.from_pretrained(model_dir)
     model.to(device)
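The hunks above only cover argument parsing and model loading; the generation step infer.py runs afterwards is along these lines (a minimal sketch assuming the fine-tuned weights exist under the new default save_dir; the exact decoding parameters in infer.py may differ):

```python
# Minimal inference sketch; the decoding settings are assumptions, not
# copied from infer.py.
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dir = 'output/mengzi-t5-base-chinese-correction/'
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = T5ForConditionalGeneration.from_pretrained(model_dir).to(device)

text = '他带了黑色的包,也带了照像机'  # intentionally erroneous test sentence
inputs = tokenizer(text, return_tensors='pt').to(device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_length=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```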
8 changes: 4 additions & 4 deletions pycorrector/t5/t5_corrector.py
@@ -42,17 +42,17 @@ def get_errors(corrected_text, origin_text):

 class T5Corrector(object):
     def __init__(self, model_dir=config.t5_model_dir):
-        self.name = 'byt5_corrector'
+        self.name = 't5_corrector'
         t1 = time.time()
         bin_path = os.path.join(model_dir, 'pytorch_model.bin')
         if not os.path.exists(bin_path):
-            model_dir = "shibing624/byt5-small-chinese-correction"
+            model_dir = "shibing624/mengzi-t5-base-chinese-correction"
             logger.warning(f'local model {bin_path} not exists, use default HF model {model_dir}')
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
         self.model = T5ForConditionalGeneration.from_pretrained(model_dir)
         self.model.to(device)
         logger.debug("Use device: {}".format(device))
-        logger.debug('Loaded byt5 correction model: %s, spend: %.3f s.' % (model_dir, time.time() - t1))
+        logger.debug('Loaded t5 correction model: %s, spend: %.3f s.' % (model_dir, time.time() - t1))

     def t5_correct(self, text: str, max_length: int = 128):
         """
@@ -107,7 +107,7 @@ def batch_t5_correct(self, texts: List[str], max_length: int = 128):


 if __name__ == "__main__":
-    m = T5Corrector('./output/byt5-small-chinese-correction/')
+    m = T5Corrector('./output/mengzi-t5-base-chinese-correction/')
     error_sentences = [
         '少先队员因该为老人让坐',
         '少 先 队 员 因 该 为 老人让坐',
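Usage is unchanged apart from the rename; only the hub fallback id differs. A minimal usage sketch (assuming no local pytorch_model.bin, so the corrector pulls shibing624/mengzi-t5-base-chinese-correction from the Hugging Face hub on first use):

```python
# Minimal usage sketch. The (corrected_text, details) return shape follows
# the pattern of pycorrector's other correctors and is an assumption here.
from pycorrector.t5.t5_corrector import T5Corrector

m = T5Corrector()  # no local weights -> falls back to the HF hub model
corrected, details = m.t5_correct('少先队员因该为老人让坐')
print(corrected, details)
```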
56 changes: 33 additions & 23 deletions pycorrector/t5/train.py
@@ -11,7 +11,9 @@
 from transformers import AutoTokenizer, T5ForConditionalGeneration
 from transformers import HfArgumentParser, TrainingArguments, Trainer, set_seed
 from datasets import load_dataset, Dataset
+from loguru import logger

+os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
 pwd_path = os.path.abspath(os.path.dirname(__file__))

@@ -123,13 +125,21 @@ class ModelArguments:

 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--train_path', type=str,
-                        default=os.path.join(pwd_path, '../data/cn/sighan_2015/train.tsv'),
-                        help='SIGHAN dataset')
-    parser.add_argument('--test_path', type=str,
-                        default=os.path.join(pwd_path, '../data/cn/sighan_2015/test.tsv'),
-                        help='SIGHAN dataset')
-    parser.add_argument('--save_dir', type=str, default='output', help='save dir')
+    parser.add_argument('--train_path', type=str, default=os.path.join(pwd_path, '../data/cn/sighan_2015/train.tsv'),
+                        help='train dataset')
+    parser.add_argument('--test_path', type=str, default=os.path.join(pwd_path, '../data/cn/sighan_2015/test.tsv'),
+                        help='test dataset')
+    parser.add_argument('--save_dir', type=str, default='./output/mengzi-t5-base-chinese-correction/', help='save dir')
+    parser.add_argument('--model_name_or_path', type=str, default='Langboat/mengzi-t5-base', help='pretrained model')
+    parser.add_argument('--max_len', type=int, default=128, help='max length')
+    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
+    parser.add_argument('--logging_steps', type=int, default=100, help='logging steps num')
+    parser.add_argument('--warmup_steps', type=int, default=200, help='warmup steps num')
+    parser.add_argument('--eval_steps', type=int, default=250, help='eval steps num')
+    parser.add_argument('--epochs', type=int, default=10, help='train epochs num')
+    parser.add_argument('--max_steps', type=int, default=5000, help='train max steps')
+    parser.add_argument("--do_train", action="store_true", help="whether to run training")
+    parser.add_argument("--do_eval", action="store_true", help="whether to run eval")
     args = parser.parse_args()

     return args
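One behavioral note on this hunk: `do_train` and `do_eval` used to be hardcoded to `True` in `args_dict` (see the next hunk) but are now `store_true` flags that default to `False`, so a training run must pass them explicitly, e.g. `python train.py --do_train --do_eval`. With the new defaults, the effective per-device train batch is `batch_size` × `gradient_accumulation_steps` = 32 × 4 = 128, and since `max_steps` is now always set (5000 by default), the HF Trainer lets it take precedence over `num_train_epochs`.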
@@ -149,32 +159,32 @@ def load(self):
 def train():
     args = parse_args()
     args_dict = {
-        "model_name_or_path": 'Langboat/mengzi-t5-base',
-        "max_len": 128,
-        "output_dir": os.path.join(args.save_dir, './mengzi-t5-base-chinese-correction'),
+        "model_name_or_path": args.model_name_or_path,
+        "max_len": args.max_len,
+        "output_dir": args.save_dir,
         "overwrite_output_dir": True,
-        "per_device_train_batch_size": 64,
-        "per_device_eval_batch_size": 64,
+        "per_device_train_batch_size": args.batch_size,
+        "per_device_eval_batch_size": args.batch_size,
         "gradient_accumulation_steps": 4,
         "learning_rate": 5e-4,
-        "warmup_steps": 250,
-        "logging_steps": 100,
+        "warmup_steps": args.warmup_steps,
+        "logging_steps": args.logging_steps,
         "evaluation_strategy": "steps",
-        "eval_steps": 250,
-        "num_train_epochs": 4,
-        "do_train": True,
-        "do_eval": True,
+        "eval_steps": args.eval_steps,
+        "num_train_epochs": args.epochs,
+        "do_train": args.do_train,
+        "do_eval": args.do_eval,
         "fp16": False,
         "use_cache": False,
-        # "max_steps": 5000,
+        "max_steps": args.max_steps,
     }
     parser = HfArgumentParser(
         (ModelArguments, DataTrainingArguments, TrainingArguments))
     model_args, data_args, training_args = parser.parse_dict(args_dict)
     set_seed(training_args.seed)
     if args.train_path.endswith('.tsv'):
         dataset = load_dataset('text', data_files={'train': [args.train_path], 'test': args.test_path})
-        print(dataset)
+        logger.info(dataset)
         train_dataset = dataset['train']
         valid_dataset = dataset['test']
     elif args.train_path.endswith('.json'):
@@ -185,8 +195,8 @@ def train():
         d = CscDataset(args.test_path)
         data_dict = d.load()
         valid_dataset = Dataset.from_dict(data_dict, split='test')
-        print(train_dataset)
-        print(valid_dataset)
+        logger.info(train_dataset)
+        logger.info(valid_dataset)
     else:
         raise ValueError('train_path must be tsv or json')

@@ -205,7 +215,7 @@ def train():
     tokenizer.model_max_length = 128
     model.config.max_length = 128

-    print('train_dataset:', train_dataset[:3])
+    logger.info(f'train_dataset: {train_dataset[:3]}')

     def tokenize_dataset(tokenizer, dataset, max_len):
         def convert_to_features(example_batch):
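The diff cuts off inside `tokenize_dataset`/`convert_to_features`. For orientation, the step it performs is roughly the following (a hedged reconstruction: SIGHAN tsv rows are `wrong_text<TAB>right_text` pairs tokenized into encoder inputs and decoder labels; the actual column names and padding options in train.py may differ):

```python
# Hedged reconstruction of the truncated tokenization step; padding strategy
# and column names are assumptions, not copied from train.py.
def tokenize_dataset(tokenizer, dataset, max_len):
    def convert_to_features(example_batch):
        src_texts, trg_texts = [], []
        for example in example_batch['text']:
            src, trg = example.split('\t', 1)  # "wrong_text<TAB>right_text"
            src_texts.append(src)
            trg_texts.append(trg)
        input_encodings = tokenizer(src_texts, truncation=True,
                                    padding='max_length', max_length=max_len)
        target_encodings = tokenizer(trg_texts, truncation=True,
                                     padding='max_length', max_length=max_len)
        return {
            'input_ids': input_encodings['input_ids'],
            'attention_mask': input_encodings['attention_mask'],
            'labels': target_encodings['input_ids'],  # decoder targets
        }
    return dataset.map(convert_to_features, batched=True)
```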
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,4 +1,5 @@
 jieba>=0.39
 pypinyin
 numpy
-six
+six
+loguru
3 changes: 2 additions & 1 deletion setup.py
@@ -45,7 +45,8 @@
         "jieba",
         "pypinyin",
         "numpy",
-        "six"
+        "six",
+        "loguru",
     ],
     packages=find_packages(exclude=['tests']),
     package_dir={'pycorrector': 'pycorrector'},
