Commit
update t5 .
shibing624 committed Jun 17, 2022
1 parent 81996d7 commit 8badf5e
Showing 3 changed files with 18 additions and 13 deletions.
20 changes: 11 additions & 9 deletions README.md
@@ -89,6 +89,7 @@ PS:
* ELECTRA model: a more efficient pre-training model jointly proposed by Stanford and Google; it learns contextual text representations that outperform BERT and XLNet trained with equivalent compute.
* ERNIE model: Baidu's knowledge-enhanced semantic representation model, with strong semantic representation capability adapted to Chinese; it surpasses the previous state of the art on 16 public datasets covering sentiment analysis, text matching, natural language inference, lexical analysis, reading comprehension, intelligent question answering, and more.
* MacBERT model: adapts whole-word masking and N-gram masking to Chinese and masks words with similar words, narrowing the gap between pre-training and fine-tuning relative to BERT; it adds an error-detection network alongside the correction network for end-to-end spelling correction.
* T5 model: a Transformer encoder-decoder model and a unified framework that casts every NLP task as a text-to-text task; pre-trained on a very large corpus, it achieves SOTA results across tasks.

# Demo

@@ -105,29 +106,30 @@ python examples/gradio_demo.py

# Evaluation

Provides the evaluation script [pycorrector/utils/eval.py](./pycorrector/utils/eval.py)
and the evaluation runner [examples/evaluate_models.py](./examples/evaluate_models.py); the runner has two functions:
Provides the evaluation runner [examples/evaluate_models.py](./examples/evaluate_models.py)

- sighan15 evaluation set: the SIGHAN 2015 test set [pycorrector/data/cn/sighan_2015/test.tsv](pycorrector/data/cn/sighan_2015/test.tsv)
- Uses the sighan15 evaluation set: the SIGHAN 2015 test set [pycorrector/data/cn/sighan_2015/test.tsv](pycorrector/data/cn/sighan_2015/test.tsv), already converted to Simplified Chinese.
- Evaluates correction precision and recall: computed at strict sentence level (Sentence Level), where a model output counts as correct only if it is exactly identical to the reference sentence, and as wrong otherwise.
- Evaluation criterion: correction precision and recall, computed at strict sentence level (Sentence Level), where a model output counts as correct only if it is exactly identical to the reference sentence, and as wrong otherwise.
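The strict sentence-level metric described above can be sketched as a small function. This is a minimal illustration mirroring the precision/recall formulas visible in `pycorrector/utils/eval.py`; the sentence pairs and the toy corrector below are invented for demonstration, not part of the library:

```python
def sentence_level_metrics(pairs, correct_fn):
    """Strict sentence-level evaluation: a prediction counts as a true
    positive only when it exactly matches the reference sentence."""
    TP = FP = FN = TN = 0
    for src, ref in pairs:
        pred = correct_fn(src)
        if src == ref:           # source had no error
            if pred == ref:
                TN += 1          # correctly left alone
            else:
                FP += 1          # spurious "correction"
        else:                    # source contained an error
            if pred == ref:
                TP += 1          # fixed exactly
            else:
                FN += 1          # missed, or fixed wrongly
    precision = TP / (TP + FP) if TP > 0 else 0.0
    recall = TP / (TP + FN) if TP > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    acc = (TP + TN) / len(pairs)
    return acc, precision, recall, f1

# Invented toy data: one fixable typo, one clean sentence, one missed typo.
pairs = [
    ("今天天汽很好", "今天天气很好"),  # typo the toy corrector fixes  -> TP
    ("我们去公园",   "我们去公园"),    # already correct, left alone   -> TN
    ("我门去公园",   "我们去公园"),    # typo the toy corrector misses -> FN
]
toy_correct = lambda s: s.replace("天汽", "天气")
acc, precision, recall, f1_score = sentence_level_metrics(pairs, toy_correct)
```

Because the criterion is all-or-nothing per sentence, a single residual character error makes the whole sentence count as wrong, which is why sentence-level scores run well below character-level ones.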

### Evaluation results

GPU: Tesla V100, 32 GB VRAM

| Dataset | Model | Backbone | GPU | Precision | Recall | F1 | QPS |
| :---------: | :---------: | :---------: | :------: | :---------: | :---------: | :---------: | :---------: |
| Sighan_15 | Rule | kenlm | cpu | 0.6860 | 0.1529 | 0.2500 | 9 |
| Sighan_15 | BERT | bert-base-chinese + MLM | gpu | 0.8029 | 0.4052 | 0.5386 | 2 |
| Sighan_15 | T5 | byt5-small | gpu | 0.5220 | 0.3941 | 0.4491 | 111 |
| Sighan_15 | Seq2Seq | convseq2seq | gpu | 0.2415 | 0.1436 | 0.1801 | 6 |
| **Sighan_15** | **MacBert** | **macbert4csc-base-chinese** | **gpu** | **0.8254** | **0.7311** | **0.7754** | **224** |
| Sighan_15 | rule(pycorrector.correct) | kenlm | cpu | 0.6860 | 0.1529 | 0.2500 | 9 |
| Sighan_15 | bert-correction | bert-base-chinese + MLM | gpu | 0.8029 | 0.4052 | 0.5386 | 2 |
| Sighan_15 | t5 | byt5-small | gpu | 0.5220 | 0.3941 | 0.4491 | 111 |
| Sighan_15 | mengzi-t5-base-chinese-correction | mengzi-t5-base | gpu | 0.8321 | 0.6390 | 0.7229 | 214 |
| Sighan_15 | convseq2seq-chinese-correction | convseq2seq | gpu | 0.2415 | 0.1436 | 0.1801 | 6 |
| **Sighan_15** | **macbert4csc-base-chinese** | **macbert-base-chinese** | **gpu** | **0.8254** | **0.7311** | **0.7754** | **224** |

### Conclusions

- The best-performing Chinese spelling correction model is **macbert**; the model name is *shibing624/macbert4csc-base-chinese*
- The best-performing Chinese grammatical error correction model is **seq2seq**; the model name is *convseq2seq*
- The most promising model is **T5**: without changing the model architecture, fine-tuning on Chinese correction data alone already achieves near-SOTA results on `SIGHAN 2015`
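As a quick consistency check on the table above, each reported F1 value is the harmonic mean of that row's precision and recall. A minimal sketch, independent of the library:

```python
def f1(p, r):
    """Harmonic mean of precision and recall."""
    return 2 * p * r / (p + r) if p + r else 0.0

# Reproduce reported F1 scores from the Precision/Recall columns.
macbert_f1 = round(f1(0.8254, 0.7311), 4)  # macbert4csc-base-chinese row
t5_f1 = round(f1(0.8321, 0.6390), 4)       # mengzi-t5-base fine-tuned row
```

`macbert_f1` comes out to 0.7754 and `t5_f1` to 0.7229, matching the table; the harmonic mean rewards balanced precision and recall, which is why macbert's strong recall gives it the lead despite T5's slightly higher precision.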

# Install

2 changes: 1 addition & 1 deletion examples/evaluate_models.py
@@ -40,7 +40,7 @@ def main(args):
from pycorrector.t5.t5_corrector import T5Corrector
model = T5Corrector()
eval.eval_sighan2015_by_model_batch(model.batch_t5_correct)
# Sentence Level: acc:0.5227, precision:0.5220, recall:0.3941, f1:0.4491, cost time:9.88 s
# Sentence Level: acc:0.7582, precision:0.8321, recall:0.6390, f1:0.7229, cost time:5.12 s
if args.data == 'sighan_15' and args.model == 'convseq2seq':
from pycorrector.seq2seq.seq2seq_corrector import Seq2SeqCorrector
model = Seq2SeqCorrector()
9 changes: 6 additions & 3 deletions pycorrector/utils/eval.py
@@ -246,7 +246,8 @@ def eval_corpus500_by_model(correct_fn, input_eval_path=eval_data_path, verbose=
recall = TP / (TP + FN) if TP > 0 else 0.0
f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0
print(
f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, cost time:{spend_time:.2f} s')
f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, '
f'cost time:{spend_time:.2f} s, total num: {total_num}')
return acc, precision, recall, f1


@@ -313,7 +314,8 @@ def eval_sighan2015_by_model(correct_fn, sighan_path=sighan_2015_path, verbose=T
recall = TP / (TP + FN) if TP > 0 else 0.0
f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0
print(
f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, cost time:{spend_time:.2f} s')
f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, '
f'cost time:{spend_time:.2f} s, total num: {total_num}')
return acc, precision, recall, f1


@@ -386,7 +388,8 @@ def eval_sighan2015_by_model_batch(correct_fn, sighan_path=sighan_2015_path, ver
recall = TP / (TP + FN) if TP > 0 else 0.0
f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0
print(
f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, cost time:{spend_time:.2f} s')
f'Sentence Level: acc:{acc:.4f}, precision:{precision:.4f}, recall:{recall:.4f}, f1:{f1:.4f}, '
f'cost time:{spend_time:.2f} s, total num: {total_num}')
return acc, precision, recall, f1


