Commit 5deb20e

various fixes (#2563)
1 parent a02330a commit 5deb20e

File tree

eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py
onmt/inputters/text_corpus.py
onmt/inputters/text_utils.py
onmt/opts.py
onmt/utils/parse.py

5 files changed: +138 -26 lines changed

eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py

+56 -11
@@ -1,24 +1,70 @@
 import copy
 import numpy as np
 import time
+import re
 from onmt.inference_engine import InferenceEnginePY
 import onmt.opts as opts
 from onmt.utils.logging import init_logger
 from onmt.utils.parse import ArgumentParser
 from onmt.utils.misc import use_gpu, set_random_seed
 
 
+def wikitext_detokenizer(line):
+    string = line
+    # contractions
+    string = string.replace("s '", "s'")
+    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
+    # number separators
+    string = string.replace(" @-@ ", "-")
+    string = string.replace(" @,@ ", ",")
+    string = string.replace(" @.@ ", ".")
+    # punctuation
+    string = string.replace(" : ", ": ")
+    string = string.replace(" ; ", "; ")
+    string = string.replace(" . ", ". ")
+    string = string.replace(" ! ", "! ")
+    string = string.replace(" ? ", "? ")
+    string = string.replace(" , ", ", ")
+    # double brackets
+    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
+    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
+    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
+    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
+    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
+    # miscellaneous
+    string = string.replace("= = = =", "====")
+    string = string.replace("= = =", "===")
+    string = string.replace("= =", "==")
+    string = string.replace(" " + chr(176) + " ", chr(176))
+    string = string.replace(" \n", "\n")
+    string = string.replace("\n ", "\n")
+    string = string.replace(" N ", " 1 ")
+    string = string.replace(" 's", "'s")
+    return string
+
+
 def tokenize_dataset(opt, context_length):
     print("Tokenization...")
     # Clean and Concat the dataset
-    x = open(opt.src, "r").readlines()
-    xx = [_x for _x in x if _x != " \n"]
-    from onmt.transforms.tokenize import SentencePieceTransform
+    xx = open(opt.src, "r").readlines()
+    if "sentencepiece" in opt.transforms:
+        from onmt.transforms.tokenize import SentencePieceTransform
+
+        tokenizer = SentencePieceTransform(opt)
+    elif "onmt_tokenize" in opt.transforms:
+        from onmt.transforms.tokenize import ONMTTokenizerTransform
 
-    tokenizer = SentencePieceTransform(opt)
+        tokenizer = ONMTTokenizerTransform(opt)
+    else:
+        raise ValueError("No valid tokenizer found")
     tokenizer.warm_up()
-    tokens = tokenizer._tokenize(xx)
-    print("Done !")
+    print("warmup done")
+    # joiner = tokenizer._tokenize("\n")
+    tokens = []
+    for x in xx:
+        tokens += tokenizer._tokenize([wikitext_detokenizer(x)])
+        # tokens += tokenizer._tokenize([x])
+    print("Tokenization Done !")
     return tokens
 

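The helper above reverses WikiText's tokenized formatting (the " @-@ ", " @,@ ", " @.@ " number markers, spaced punctuation, and "= =" headings) before the text is re-tokenized. A quick illustration of its effect; the input line is invented, but uses the dataset's real markers:

    line = "The 2 @,@ 500 @-@ seat hall opened in 1 @.@ 5 years"
    print(wikitext_detokenizer(line))
    # -> "The 2,500-seat hall opened in 1.5 years"
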
@@ -38,7 +84,7 @@ def evaluate(opt):
     set_random_seed(opt.seed, use_gpu(opt))
 
     # Tokenize the dataset.
-    opt.src = "wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
+    opt.src = "eval_llm/WIKITEXT2/wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
     tokens = tokenize_dataset(opt, context_length=512)
 
     # Build the translator (along with the model.
@@ -47,8 +93,8 @@ def evaluate(opt):
     engine = InferenceEnginePY(engine_opt)
 
     # Score the dataset.
-    stride = 512
-    max_seq_length = 4096
+    stride = 256
+    max_seq_length = 512
 
     seq_len = len(tokens)
     src = []
@@ -75,8 +121,7 @@ def evaluate(opt):
     end_time = time.time()
     logger.info("total run time %.2f" % (end_time - start_time))
     logger.info(
-        "wikitext-2 perplexity with rolling likelihood and sliding window size 1000 and stride 512 %.2f"  # noqa: E501
-        % (ppl)
+        "wikitext-2 perplexity with rolling likelihood: %.2f" % (ppl)  # noqa: E501
     )
 

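For context on the stride/max_seq_length change: in a sliding-window ("rolling likelihood") evaluation, each 512-token window is scored, but only the tokens a window adds beyond the previous one count toward perplexity, so the 256-token overlap serves purely as conditioning context. A minimal, self-contained sketch of that arithmetic, not the benchmark script's actual loop; score_fn stands in for the model:

    import math

    def rolling_ppl(tokens, score_fn, max_seq_length=512, stride=256):
        """Sliding-window perplexity sketch; score_fn(window) must return
        one log-probability per token, conditioned within the window."""
        nll, n_scored, prev_end = 0.0, 0, 0
        for begin in range(0, len(tokens), stride):
            end = min(begin + max_seq_length, len(tokens))
            window_log_probs = score_fn(tokens[begin:end])
            # Count only the tokens this window adds beyond the last one,
            # so every token is scored exactly once.
            new = end - prev_end
            nll -= sum(window_log_probs[-new:])
            n_scored += new
            prev_end = end
            if end == len(tokens):
                break
        return math.exp(nll / n_scored)

    # Dummy stand-in model: uniform over a 100-token vocabulary.
    uniform = lambda window: [math.log(1 / 100)] * len(window)
    print(rolling_ppl(list(range(1000)), uniform))  # 100.0
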
onmt/inputters/text_corpus.py

+72 -8
@@ -38,6 +38,63 @@ def exfile_open(filename, *args, **kwargs):
         _file.close()
 
 
+class BlockwiseCorpus(object):
+    """A corpus class for reading a single file block by block."""
+
+    def __init__(self, name, file_path, block_size=4096):
+        """Initialize file path and block size."""
+        self.id = name
+        self.file_path = file_path
+        self.block_size = block_size
+
+    def load(self, offset=0, stride=1):
+        """
+        Load file and iterate by blocks.
+        `offset` and `stride` allow iterating only on every
+        `stride` block, starting from `offset`.
+        """
+
+        def make_ex(block_content):
+            example = {
+                "src": block_content,
+                "tgt": block_content,
+                "src_original": block_content,
+                "tgt_original": block_content,
+            }
+            return example
+
+        with open(self.file_path, mode="r", encoding="utf-8") as file:
+            block_content = ""
+            block_index = 0
+
+            while True:
+                chunk = file.read(self.block_size)
+                if not chunk:
+                    break
+
+                if (block_index // stride) % stride == offset:
+                    block_content += chunk
+
+                if len(chunk) < self.block_size:
+                    # Reached end of file
+                    yield make_ex(block_content)
+                    break
+
+                if len(block_content) >= self.block_size:
+                    yield make_ex(block_content)
+                    block_content = ""
+                block_index += 1
+
+    def __str__(self):
+        cls_name = type(self).__name__
+        return (
+            f"{cls_name}({self.id}, {self.file_path}, {self.file_path}"
+            f"align={None}, "
+            f"n_src_feats={0}, "
+            f'src_feats_defaults="{None}")'
+        )
+
+
 class ParallelCorpus(object):
     """A parallel corpus file pair that can be loaded to iterate."""
 
@@ -117,14 +174,21 @@ def get_corpora(opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None):
     if task == CorpusTask.TRAIN:
         for corpus_id, corpus_dict in opts.data.items():
             if corpus_id != CorpusName.VALID:
-                corpora_dict[corpus_id] = ParallelCorpus(
-                    corpus_id,
-                    corpus_dict["path_src"],
-                    corpus_dict["path_tgt"],
-                    corpus_dict["path_align"],
-                    n_src_feats=opts.n_src_feats,
-                    src_feats_defaults=opts.src_feats_defaults,
-                )
+                if corpus_dict.get("path_txt", None) is None:
+                    corpora_dict[corpus_id] = ParallelCorpus(
+                        corpus_id,
+                        corpus_dict["path_src"],
+                        corpus_dict["path_tgt"],
+                        corpus_dict["path_align"],
+                        n_src_feats=opts.n_src_feats,
+                        src_feats_defaults=opts.src_feats_defaults,
+                    )
+                else:
+                    corpora_dict[corpus_id] = BlockwiseCorpus(
+                        corpus_id,
+                        corpus_dict["path_txt"],
+                        block_size=8192,  # number of characters
+                    )
     elif task == CorpusTask.VALID:
         if CorpusName.VALID in opts.data.keys():
             corpora_dict[CorpusName.VALID] = ParallelCorpus(

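A minimal sketch of driving the new BlockwiseCorpus on its own; the path is a placeholder, and inside OpenNMT-py the instance is built by get_corpora above whenever a corpus declares path_txt:

    # Placeholder path; any plain-text file works.
    corpus = BlockwiseCorpus("corpus_1", "wiki.train.raw", block_size=8192)
    for ex in corpus.load(offset=0, stride=1):
        # Each example maps the same ~8192-character block to both src
        # and tgt, matching the single-file language-modeling setup.
        print(ex["src"][:60])
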
onmt/inputters/text_utils.py

-1
@@ -149,7 +149,6 @@ def numericalize(vocabs, example):
         for fv, feat in zip(vocabs["src_feats"], example["src"]["feats"]):
             numeric_feats.append(fv(feat.split(" ")))
         numeric["src"]["feats"] = numeric_feats
-
     return numeric
 
 
onmt/opts.py

+1 -2
@@ -872,8 +872,7 @@ def model_opts(parser):
     group.add(
         "--rotary_interleave",
         "-rotary_interleave",
-        type=bool,
-        default=True,
+        action="store_true",
         help="Interleave the head dimensions when rotary"
         " embeddings are applied."
         " Otherwise the head dimensions are sliced in half."

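The switch from type=bool to action="store_true" fixes a classic argparse pitfall: bool() applied to any non-empty string is True, so a type=bool option could never actually be disabled from the command line. Note the default therefore flips from always-True to False-unless-passed. A small standalone demonstration with hypothetical flag names:

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--old_flag", type=bool, default=True)  # broken pattern
    p.add_argument("--new_flag", action="store_true")      # fixed pattern

    print(p.parse_args(["--old_flag", "False"]).old_flag)  # True (!)
    print(p.parse_args([]).new_flag)                       # False
    print(p.parse_args(["--new_flag"]).new_flag)           # True
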
onmt/utils/parse.py

+9 -4
@@ -39,9 +39,10 @@ def _validate_data(cls, opt):
             # Check path
             path_src = corpus.get("path_src", None)
             path_tgt = corpus.get("path_tgt", None)
-            if path_src is None:
+            path_txt = corpus.get("path_txt", None)
+            if path_src is None and path_txt is None:
                 raise ValueError(
-                    f"Corpus {cname} src path is required."
+                    f"Corpus {cname} src/txt path is required."
                     "tgt path is also required for non language"
                     " modeling tasks."
                 )
@@ -57,8 +58,12 @@ def _validate_data(cls, opt):
                 corpus["path_tgt"] = path_src
                 corpora[cname] = corpus
                 path_tgt = path_src
-            cls._validate_file(path_src, info=f"{cname}/path_src")
-            cls._validate_file(path_tgt, info=f"{cname}/path_tgt")
+            if path_src is not None:
+                cls._validate_file(path_src, info=f"{cname}/path_src")
+            if path_txt is not None:
+                cls._validate_file(path_txt, info=f"{cname}/path_txt")
+            if path_tgt is not None:
+                cls._validate_file(path_tgt, info=f"{cname}/path_tgt")
             path_align = corpus.get("path_align", None)
             if path_align is None:
                 if hasattr(opt, "lambda_align") and opt.lambda_align > 0.0:

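Taken together with the text_corpus.py change, a training corpus can now be declared with a single path_txt instead of a path_src/path_tgt pair. A hypothetical YAML sketch; the paths are invented, and the valid set still goes through ParallelCorpus, so it keeps both sides:

    data:
        corpus_1:
            path_txt: data/wiki.train.raw   # routed to BlockwiseCorpus
        valid:
            path_src: data/wiki.valid.raw
            path_tgt: data/wiki.valid.raw
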