-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathexperiment_utils.py
51 lines (42 loc) · 1.32 KB
/
experiment_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import sacrebleu
# Ten fixed integer seeds.
# NOTE(review): presumably used by callers to make repeated experiment runs
# reproducible -- confirm against usage elsewhere in the project.
SEEDS = [
    553589, 456178, 817304, 6277, 792418,
    707983, 249859, 272618, 760402, 472974,
]
@dataclass
class Testset:
    """Source sentences and reference translations of one test set.

    ``str(instance)`` yields ``"<testset>.<language_pair>"``.
    """

    # Name of the test set, e.g. "wmt21".
    testset: str
    # Hyphen-separated language pair, e.g. "en-de".
    language_pair: str
    # Source-side segments, one string per segment.
    source_sentences: List[str]
    # Reference translations; from_wmt() guarantees the same length as
    # source_sentences.
    references: List[str]

    @property
    def src_lang(self) -> str:
        """Return the part of ``language_pair`` before the first hyphen."""
        return self.language_pair.split("-")[0]

    @property
    def tgt_lang(self) -> str:
        """Return the second hyphen-separated part of ``language_pair``."""
        return self.language_pair.split("-")[1]

    @staticmethod
    def _read_lines(path) -> List[str]:
        """Read a UTF-8 text file and return its lines without newlines."""
        # Iterating the file breaks only at real newlines. str.splitlines()
        # would also break at characters such as U+2028 or U+0085 inside a
        # segment, which could shift source and reference out of alignment.
        # The encoding is explicit so the result does not depend on the
        # platform's locale.
        with Path(path).open(encoding="utf-8") as handle:
            return [line.rstrip("\n") for line in handle]

    @classmethod
    def from_wmt(
        cls,
        wmt: str,
        language_pair: str,
        limit_segments: Optional[int] = None,
    ) -> "Testset":
        """Build a test set from the files sacrebleu provides.

        Args:
            wmt: Test set name; must be ``"wmt21"`` or ``"wmt22"``.
            language_pair: Language pair passed on to sacrebleu.
            limit_segments: If not ``None``, keep only the first
                ``limit_segments`` source/reference pairs.

        Returns:
            A ``Testset`` holding the source file's lines and the lines of
            the first reference file returned by sacrebleu.

        Raises:
            ValueError: If ``wmt`` is not a supported test set, or if the
                source and reference files differ in their number of lines.
        """
        # Explicit raises instead of ``assert`` so that the checks are not
        # stripped when Python runs with ``-O``.
        if wmt not in {"wmt21", "wmt22"}:
            raise ValueError(
                f"Unsupported test set {wmt!r}; expected 'wmt21' or 'wmt22'"
            )

        src_path = sacrebleu.get_source_file(wmt, language_pair)
        # Only the first reference file is used.
        ref_path = sacrebleu.get_reference_files(wmt, language_pair)[0]

        source_sequences = cls._read_lines(src_path)
        references = cls._read_lines(ref_path)
        if len(source_sequences) != len(references):
            raise ValueError(
                f"Source has {len(source_sequences)} lines but reference has "
                f"{len(references)} lines for {wmt}.{language_pair}"
            )

        if limit_segments is not None:
            source_sequences = source_sequences[:limit_segments]
            references = references[:limit_segments]

        return cls(
            testset=wmt,
            language_pair=language_pair,
            source_sentences=source_sequences,
            references=references,
        )

    def __str__(self) -> str:
        """Return ``"<testset>.<language_pair>"``."""
        return f"{self.testset}.{self.language_pair}"