Skip to content

Commit

Permalink
update t5 with mengzi-t5 infer.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Jun 17, 2022
1 parent 7ad3b5d commit 81996d7
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 9 deletions.
25 changes: 19 additions & 6 deletions pycorrector/t5/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,25 @@ def parse_args():


def predict():
example_sentences = ["我跟我朋唷打算去法国玩儿。",
"少先队员因该为老人让坐。",
"我们是新时代的接斑人",
"我咪路,你能给我指路吗?",
"他带了黑色的包,也带了照像机",
]
example_sentences = [
"我跟我朋唷打算去法国玩儿。",
"少先队员因该为老人让坐。",
"我们是新时代的接斑人",
"我咪路,你能给我指路吗?",
"他带了黑色的包,也带了照像机",
'因为爸爸在看录音机,所以我没得看',
'不过在许多传统国家,女人向未得到平等',
'妈妈说:"别趴地上了,快起来,你还吃饭吗?",我说:"好。"就扒起来了。',
'你说:“怎么办?”我怎么知道?',
'我父母们常常说:“那时候吃的东西太少,每天只能吃一顿饭。”想一想,人们都快要饿死,谁提出化肥和农药的污染。',
'这本新书《居里夫人传》将的很生动有趣',
'֍我喜欢吃鸡,公鸡、母鸡、白切鸡、乌鸡、紫燕鸡……֍新的食谱',
'注意:“跨类保护”不等于“全类保护”。',
'12.——对比文件中未公开的数值和对比文件中已经公开的中间值具有新颖性;',
'《著作权法》(2020修正)第23条:“自然人的作品,其发表权、本法第',
'三步检验法(三步检验标准)(three-step test):若要',
'三步检验法“三步‘检验’标准”(three-step test):若要',
]
args = parse_args()
model_dir = args.save_dir
tokenizer = AutoTokenizer.from_pretrained(model_dir)
Expand Down
11 changes: 8 additions & 3 deletions pycorrector/t5/t5_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
from pycorrector.utils.logger import logger
from pycorrector import config
from pycorrector.utils.tokenizer import split_text_by_maxlen
from pycorrector.utils.text_utils import is_chinese

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
unk_tokens = [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤', '\t', ', '玕', '']
unk_tokens = [' ', '“', '”', '‘', '’', '琊', '\n', '…', '擤', '\t', '玕', '']


def get_errors(corrected_text, origin_text):
Expand All @@ -31,10 +32,13 @@ def get_errors(corrected_text, origin_text):
corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
continue
if ori_char != corrected_text[i]:
if ori_char.lower() == corrected_text[i]:
# pass english upper char
if not is_chinese(ori_char):
# pass not chinese char
corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
continue
if not is_chinese(corrected_text[i]):
corrected_text = corrected_text[:i] + corrected_text[i + 1:]
continue
sub_details.append((ori_char, corrected_text[i], i, i + 1))
sub_details = sorted(sub_details, key=operator.itemgetter(2))
return corrected_text, sub_details
Expand Down Expand Up @@ -132,6 +136,7 @@ def batch_t5_correct(self, texts: List[str], max_length: int = 128):
'12.——对比文件中未公开的数值和对比文件中已经公开的中间值具有新颖性;',
'《著作权法》(2020修正)第23条:“自然人的作品,其发表权、本法第',
'三步检验法(三步检验标准)(three-step test):若要',
'三步检验法“三步‘检验’标准”(three-step test):若要',
]
t1 = time.time()
for sent in error_sentences:
Expand Down

0 comments on commit 81996d7

Please sign in to comment.