import copy
import numpy as np
import time
+import re
from onmt.inference_engine import InferenceEnginePY
import onmt.opts as opts
from onmt.utils.logging import init_logger
from onmt.utils.parse import ArgumentParser
from onmt.utils.misc import use_gpu, set_random_seed


+def wikitext_detokenizer(line):
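+    # Undo WikiText's word-level tokenization spacing (e.g. " @-@ " -> "-")
+    # so each raw line reads as natural text before sub-word tokenization.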
+    string = line
+    # contractions
+    string = string.replace("s '", "s'")
+    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
+    # number separators
+    string = string.replace(" @-@ ", "-")
+    string = string.replace(" @,@ ", ",")
+    string = string.replace(" @.@ ", ".")
+    # punctuation
+    string = string.replace(" : ", ": ")
+    string = string.replace(" ; ", "; ")
+    string = string.replace(" . ", ". ")
+    string = string.replace(" ! ", "! ")
+    string = string.replace(" ? ", "? ")
+    string = string.replace(" , ", ", ")
+    # double brackets
+    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
+    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
+    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
+    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
+    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
+    # miscellaneous
+    string = string.replace("= = = =", "====")
+    string = string.replace("= = =", "===")
+    string = string.replace("= =", "==")
+    string = string.replace(" " + chr(176) + " ", chr(176))
+    string = string.replace(" \n", "\n")
+    string = string.replace("\n ", "\n")
+    string = string.replace(" N ", " 1 ")
+    string = string.replace(" 's", "'s")
+    return string
+
+
def tokenize_dataset(opt, context_length):
    print("Tokenization...")
    # Clean and Concat the dataset
-    x = open(opt.src, "r").readlines()
-    xx = [_x for _x in x if _x != " \n"]
-    from onmt.transforms.tokenize import SentencePieceTransform
+    xx = open(opt.src, "r").readlines()
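+    # Pick the sub-word tokenizer that matches the transform listed in opt.transforms.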
+    if "sentencepiece" in opt.transforms:
+        from onmt.transforms.tokenize import SentencePieceTransform
+
+        tokenizer = SentencePieceTransform(opt)
+    elif "onmt_tokenize" in opt.transforms:
+        from onmt.transforms.tokenize import ONMTTokenizerTransform

-    tokenizer = SentencePieceTransform(opt)
+        tokenizer = ONMTTokenizerTransform(opt)
+    else:
+        raise ValueError("No valid tokenizer found")
    tokenizer.warm_up()
-    tokens = tokenizer._tokenize(xx)
-    print("Done !")
+    print("warmup done")
+    # joiner = tokenizer._tokenize("\n")
+    tokens = []
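+    # Detokenize each raw line, then sub-word tokenize it and extend the running token stream.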
+    for x in xx:
+        tokens += tokenizer._tokenize([wikitext_detokenizer(x)])
+        # tokens += tokenizer._tokenize([x])
+    print("Tokenization Done !")
    return tokens

@@ -38,7 +84,7 @@ def evaluate(opt):
    set_random_seed(opt.seed, use_gpu(opt))

    # Tokenize the dataset.
-    opt.src = "wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
+    opt.src = "eval_llm/WIKITEXT2/wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
    tokens = tokenize_dataset(opt, context_length=512)

    # Build the translator (along with the model).
@@ -47,8 +93,8 @@ def evaluate(opt):
    engine = InferenceEnginePY(engine_opt)

    # Score the dataset.
-    stride = 512
-    max_seq_length = 4096
+    stride = 256
+    max_seq_length = 512

    seq_len = len(tokens)
    src = []
@@ -75,8 +121,7 @@ def evaluate(opt):
    end_time = time.time()
    logger.info("total run time %.2f" % (end_time - start_time))
    logger.info(
-        "wikitext-2 perplexity with rolling likelihood and sliding window size 1000 and stride 512 %.2f"  # noqa: E501
-        % (ppl)
+        "wikitext-2 perplexity with rolling likelihood: %.2f" % (ppl)  # noqa: E501
    )
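
The hunks above only change the window parameters (stride 512 -> 256, max_seq_length 4096 -> 512) and the log message; the loop that slices the token stream into overlapping windows is outside the shown hunks. A minimal, self-contained sketch of how such windows are typically built from these two values (illustrative only; variable names below are not taken from the file):

    stride = 256
    max_seq_length = 512
    tokens = list(range(1200))  # stand-in for the tokenized corpus
    windows = []
    for begin in range(0, len(tokens), stride):
        end = min(begin + max_seq_length, len(tokens))
        # each 512-token window overlaps the previous one by 256 tokens
        windows.append(tokens[begin:end])
        if end == len(tokens):
            break
    # each window is then scored, the negative log-likelihoods are summed over
    # the non-overlapping tokens, and ppl = exp(total_nll / total_scored_tokens)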