diff --git a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py index 8a8afa4d5a..898f612688 100755 --- a/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py +++ b/eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py @@ -13,9 +13,10 @@ def tokenize_dataset(opt, context_length): # Clean and Concat the dataset x = open(opt.src, "r").readlines() - xx = [_x for _x in x if _x != ' \n'] + xx = [_x for _x in x if _x != " \n"] print(xx[:2]) from onmt.transforms.tokenize import SentencePieceTransform + tokenizer = SentencePieceTransform(opt) tokenizer.warm_up() tokens = tokenizer._tokenize(xx) @@ -48,13 +49,13 @@ def evaluate(opt): max_seq_length = 4096 max_seq_length = 1000 seq_len = len(tokens) - print('seq_len: ', seq_len) + print("seq_len: ", seq_len) score_results = [] nlls = [] src = [] for begin_loc in range(0, seq_len, stride): end_loc = min(begin_loc + max_seq_length - 1, seq_len) - src.append(' '.join(tokens[begin_loc:end_loc])) + src.append(" ".join(tokens[begin_loc:end_loc])) start_time = time.time() score_results = engine.score_list(src=src) @@ -65,7 +66,10 @@ def evaluate(opt): engine.terminate() end_time = time.time() logger.info("total run time %.2f" % (end_time - start_time)) - logger.info("wikitext-2 perplexity with rolling likelihood and sliding window size 1000 and stride 512 %.2f" % (ppl)) # noqa: E501 + logger.info( + "wikitext-2 perplexity with rolling likelihood and sliding window size 1000 and stride 512 %.2f" # noqa: E501 + % (ppl) + ) def _get_parser():