Skip to content

Commit 68c22af

Browse files
committed
fixed tokenization
1 parent de2a558 commit 68c22af

File tree

2 files changed

+8
-13
lines changed

2 files changed

+8
-13
lines changed

eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,15 @@
1010

1111
def tokenize_dataset(opt, context_length):
1212
print("Tokenization...")
13-
1413
# Clean and Concat the dataset
1514
x = open(opt.src, "r").readlines()
1615
xx = [_x for _x in x if _x != " \n"]
17-
print(xx[:2])
1816
from onmt.transforms.tokenize import SentencePieceTransform
1917

2018
tokenizer = SentencePieceTransform(opt)
2119
tokenizer.warm_up()
2220
tokens = tokenizer._tokenize(xx)
2321
print("Done !")
24-
print(len(tokens))
25-
print(tokens[:100])
2622
return tokens
2723

2824

@@ -46,23 +42,21 @@ def evaluate(opt):
4642

4743
# Score the dataset.
4844
stride = 512
49-
max_seq_length = 4096
5045
max_seq_length = 2048
46+
engine_opt.batch_type = "sents"
47+
engine_opt.batch_size = 1
5148
seq_len = len(tokens)
52-
print("seq_len: ", seq_len)
5349
score_results = []
5450
nlls = []
5551
src = []
5652
for begin_loc in range(0, seq_len, stride):
57-
end_loc = min(begin_loc + max_seq_length - 1, seq_len)
53+
end_loc = min(begin_loc + max_seq_length, seq_len)
5854
src.append(" ".join(tokens[begin_loc:end_loc]))
59-
6055
start_time = time.time()
6156
score_results = engine.score_list(src=src)
6257
nlls = [_score for (_score, _length) in score_results]
6358
lengths = [_length for (_score, _length) in score_results]
6459
ppl = np.exp(-np.sum(nlls) / np.sum(lengths))
65-
print(ppl)
6660
engine.terminate()
6761
end_time = time.time()
6862
logger.info("total run time %.2f" % (end_time - start_time))

onmt/translate/translator.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ def _score(self, infer_iter):
584584
self.with_scores = True
585585
scored_bucket = {}
586586
for batch, bucket_idx in infer_iter:
587-
batch_data = self.translate_batch(batch, attn_debug=False)
587+
batch_data = self.translate_batch(batch, attn_debug=False, scoring=True)
588588
batch_gold_scores = batch_data["gold_score"].cpu().numpy().tolist()
589589
batch_tgt_lengths = batch["tgtlen"].cpu().numpy().tolist()
590590
batch_inds_in_bucket = batch["ind_in_bucket"]
@@ -1001,8 +1001,9 @@ def _align_forward(self, batch, predictions):
10011001
"""
10021002
raise NotImplementedError
10031003

1004-
def translate_batch(self, batch, attn_debug):
1004+
def translate_batch(self, batch, attn_debug, scoring=False):
10051005
"""Translate a batch of sentences."""
1006+
max_length = 0 if scoring else self.max_length
10061007
with torch.no_grad():
10071008
if self.sample_from_topk != 0 or self.sample_from_topp != 0:
10081009
decode_strategy = GreedySearchLM(
@@ -1015,7 +1016,7 @@ def translate_batch(self, batch, attn_debug):
10151016
batch_size=len(batch["srclen"]),
10161017
global_scorer=self.global_scorer,
10171018
min_length=self.min_length,
1018-
max_length=self.max_length,
1019+
max_length=max_length,
10191020
block_ngram_repeat=self.block_ngram_repeat,
10201021
exclusion_tokens=self._exclusion_idxs,
10211022
return_attention=attn_debug or self.replace_unk,
@@ -1039,7 +1040,7 @@ def translate_batch(self, batch, attn_debug):
10391040
n_best=self.n_best,
10401041
global_scorer=self.global_scorer,
10411042
min_length=self.min_length,
1042-
max_length=self.max_length,
1043+
max_length=max_length,
10431044
return_attention=attn_debug or self.replace_unk,
10441045
block_ngram_repeat=self.block_ngram_repeat,
10451046
exclusion_tokens=self._exclusion_idxs,

0 commit comments

Comments (0)