From 8badf5ebafebb5bf53b62cec1c7b6dfc7af350cb Mon Sep 17 00:00:00 2001 From: shibing624 Date: Fri, 17 Jun 2022 16:08:01 +0800 Subject: [PATCH] update t5 . --- README.md | 20 +++++++++++--------- examples/evaluate_models.py | 2 +- pycorrector/utils/eval.py | 9 ++++++--- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 52a324f0..a3635c87 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,7 @@ PS: * ELECTRA模型:斯坦福和谷歌联合提出的一种更具效率的预训练模型,学习文本上下文表示优于同等计算资源的BERT和XLNet * ERNIE模型:百度提出的基于知识增强的语义表示模型,有可适配中文的强大语义表征能力。在情感分析、文本匹配、自然语言推理、词法分析、阅读理解、智能问答等16个公开数据集上超越世界领先技术 * MacBERT模型:使用全词掩蔽和N-Gram掩蔽策略适配中文表达,和通过用其相似的单词来掩盖单词,相较BERT缩小了训练前和微调阶段之间的差距,加入错误检测和纠正网络端到端纠正文本拼写错误 +* T5模型:基于Transformer的Encoder-Decoder模型,把所有NLP任务转化为Text-to-text任务的统一框架,在超大数据集上得到预训练大模型,在各任务取得SOTA效果。 # Demo @@ -105,12 +106,11 @@ python examples/gradio_demo.py # Evaluation -提供评估脚本[pycorrector/utils/eval.py](./pycorrector/utils/eval.py) -和评估执行脚本[examples/evaluate_models.py](./examples/evaluate_models.py),该脚本有两个功能: +提供评估脚本[examples/evaluate_models.py](./examples/evaluate_models.py): -- sighan15评估集:SIGHAN2015的测试集[pycorrector/data/cn/sighan_2015/test.tsv](pycorrector/data/cn/sighan_2015/test.tsv) +- 使用sighan15评估集:SIGHAN2015的测试集[pycorrector/data/cn/sighan_2015/test.tsv](pycorrector/data/cn/sighan_2015/test.tsv) ,已经转为简体中文。 -- 评估纠错准召率:采用严格句子粒度(Sentence Level)计算方式,把模型纠正之后的与正确句子完成相同的视为正确,否则为错。 +- 评估标准:纠错准召率,采用严格句子粒度(Sentence Level)计算方式,把模型纠正之后的与正确句子完成相同的视为正确,否则为错。 ### 评估结果 @@ -118,16 +118,18 @@ GPU:Tesla V100,显存 32 GB | 数据集 | 模型 | Backbone | GPU | Precision | Recall | F1 | QPS | | :---------: | :---------: | :---------: | :------: | :---------: | :---------: | :---------: | :---------: | -| Sighan_15 | Rule | kenlm | cpu | 0.6860 | 0.1529 | 0.2500 | 9 | -| Sighan_15 | BERT | bert-base-chinese + MLM | gpu | 0.8029 | 0.4052 | 0.5386 | 2 | -| Sighan_15 | T5 | byt5-small | gpu | 0.5220 | 0.3941 | 0.4491 | 111 | -| Sighan_15 | Seq2Seq | convseq2seq | gpu | 0.2415 | 0.1436 | 0.1801 | 6 | -| **Sighan_15** | **MacBert** | **macbert4csc-base-chinese** | **gpu** | **0.8254** | **0.7311** | **0.7754** | **224** | +| Sighan_15 | rule(pycorrector.correct) | kenlm | cpu | 0.6860 | 0.1529 | 0.2500 | 9 | +| Sighan_15 | bert-correction | bert-base-chinese + MLM | gpu | 0.8029 | 0.4052 | 0.5386 | 2 | +| Sighan_15 | t5 | byt5-small | gpu | 0.5220 | 0.3941 | 0.4491 | 111 | +| Sighan_15 | mengzi-t5-base-chinese-correction | mengzi-t5-base | gpu | 0.8321 | 0.6390 | 0.7229 | 214 | +| Sighan_15 | convseq2seq-chinese-correction | convseq2seq | gpu | 0.2415 | 0.1436 | 0.1801 | 6 | +| **Sighan_15** | **macbert4csc-base-chinese** | **macbert-base-chinese** | **gpu** | **0.8254** | **0.7311** | **0.7754** | **224** | ### 结论 - 中文拼写纠错模型效果最好的是**macbert**,模型名称是*shibing624/macbert4csc-base-chinese* - 中文语法纠错模型效果最好的是**seq2seq**,模型名称是*convseq2seq* +- 最具潜力的模型是**T5**,未改变模型结构,仅fine-tune中文纠错数据集,已经在`SIGHAN 2015`取得接近SOTA的效果 # Install diff --git a/examples/evaluate_models.py b/examples/evaluate_models.py index d90e8986..cbab9f3c 100644 --- a/examples/evaluate_models.py +++ b/examples/evaluate_models.py @@ -40,7 +40,7 @@ def main(args): from pycorrector.t5.t5_corrector import T5Corrector model = T5Corrector() eval.eval_sighan2015_by_model_batch(model.batch_t5_correct) - # Sentence Level: acc:0.5227, precision:0.5220, recall:0.3941, f1:0.4491, cost time:9.88 s + # Sentence Level: acc:0.7582, precision:0.8321, recall:0.6390, f1:0.7229, cost time:5.12 s if args.data == 'sighan_15' and args.model == 'convseq2seq': from pycorrector.seq2seq.seq2seq_corrector import Seq2SeqCorrector model = Seq2SeqCorrector() diff --git a/pycorrector/utils/eval.py b/pycorrector/utils/eval.py index d30f5881..4cca14b4 100644 --- a/pycorrector/utils/eval.py +++ b/pycorrector/utils/eval.py @@ -246,7 +246,8 @@ def eval_corpus500_by_model(correct_fn, input_eval_path=eval_data_path, verbose= recall = TP / (TP + FN) if TP > 0 else 0.0 f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0 print( - f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, cost time:{spend_time:.2f} s') + f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, ' + f'cost time:{spend_time:.2f} s, total num: {total_num}') return acc, precision, recall, f1 @@ -313,7 +314,8 @@ def eval_sighan2015_by_model(correct_fn, sighan_path=sighan_2015_path, verbose=T recall = TP / (TP + FN) if TP > 0 else 0.0 f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0 print( - f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, cost time:{spend_time:.2f} s') + f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, ' + f'cost time:{spend_time:.2f} s, total num: {total_num}') return acc, precision, recall, f1 @@ -386,7 +388,8 @@ def eval_sighan2015_by_model_batch(correct_fn, sighan_path=sighan_2015_path, ver recall = TP / (TP + FN) if TP > 0 else 0.0 f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0 print( - f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, cost time:{spend_time:.2f} s') + f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, ' + f'cost time:{spend_time:.2f} s, total num: {total_num}') return acc, precision, recall, f1