Skip to content

Commit 3ef4683

Browse files
authored
Enhance SacreBLEU integration (#125)
1 parent 89cbfe2 commit 3ef4683

File tree

7 files changed

+109
-7
lines changed

7 files changed

+109
-7
lines changed

doc/metrics/bleu.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# BLEU
22
Our BLEU implementation is a wrapper around [SacreBLEU](https://github.com/mjpost/sacrebleu).
33
Although BLEU was intended to be a corpus-level metric, we have only implemented the sentence-level version.
4-
See `sacrebleu.sentence_bleu` for details.
4+
See `sacrebleu.BLEU` for details.
55
The metric is registered under the name `sent-bleu`.
66

77
## Setting Up

doc/metrics/chrf.md

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# chrF/chrF++
2+
We use the chrF/chrF++ implementation of [SacreBLEU](https://github.com/mjpost/sacrebleu).
3+
See `sacrebleu.CHRF` for details.
4+
The metric is registered under the name `chrf`.
5+
6+
## Setting Up
7+
No setup is required.

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ pytest
99
repro==0.1.1
1010
requests
1111
scipy>=1.5.2
12-
sacrebleu==1.5.1
12+
sacrebleu==2.0.0

sacrerouge/metrics/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sacrerouge.metrics.blanc import Blanc
77
from sacrerouge.metrics.bleu import SentBleu
88
from sacrerouge.metrics.bleurt import Bleurt
9+
from sacrerouge.metrics.chrf import ChrF
910
from sacrerouge.metrics.meteor import Meteor
1011
from sacrerouge.metrics.moverscore import MoverScore
1112
from sacrerouge.metrics.pyramid_score import PyramidScore

sacrerouge/metrics/bleu.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from sacrebleu import sentence_bleu
2+
from sacrebleu import BLEU
33
from typing import List
44

55
from sacrerouge.common.util import flatten
@@ -16,11 +16,13 @@ class SentBleu(ReferenceBasedMetric):
1616
def __init__(self, **kwargs) -> None:
1717
"""
1818
Args:
19-
**kwargs: The kwargs that will be passed to `sacrebleu.sentence_bleu`. See
19+
**kwargs: The kwargs that will be passed to `sacrebleu.BLEU`. See
2020
the `sacrebleu` documentation for details.
2121
"""
2222
super().__init__()
23-
self.kwargs = kwargs
23+
if "effective_order" not in kwargs:
24+
kwargs["effective_order"] = True
25+
self.bleu = BLEU(**kwargs)
2426

2527
def score_multi_all(
2628
self,
@@ -34,6 +36,6 @@ def score_multi_all(
3436
scores_list.append([])
3537
for summary in summaries:
3638
summary = flatten(summary)
37-
score = sentence_bleu(summary, references, **self.kwargs)
39+
score = self.bleu.sentence_score(summary, references)
3840
scores_list[-1].append(MetricsDict({'sent-bleu': score.score}))
39-
return scores_list
41+
return scores_list

sacrerouge/metrics/chrf.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import logging
2+
from typing import List
3+
4+
import sacrebleu
5+
6+
from sacrerouge.common.util import flatten
7+
from sacrerouge.data import MetricsDict
8+
from sacrerouge.data.types import ReferenceType
9+
from sacrerouge.data.types import SummaryType
10+
from sacrerouge.metrics import Metric, ReferenceBasedMetric
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
@Metric.register('chrf')
16+
class ChrF(ReferenceBasedMetric):
17+
def __init__(self, **kwargs) -> None:
18+
"""
19+
Args:
20+
**kwargs: The kwargs that will be passed to `sacrebleu.CHRF`. See
21+
the `sacrebleu` documentation for details.
22+
"""
23+
super().__init__()
24+
self.chrf = sacrebleu.CHRF(**kwargs)
25+
26+
def score_multi_all(
27+
self,
28+
summaries_list: List[List[SummaryType]],
29+
references_list: List[List[ReferenceType]],
30+
**kwargs,
31+
) -> List[List[MetricsDict]]:
32+
scores_list = []
33+
for summaries, references in zip(summaries_list, references_list):
34+
references = [flatten(reference) for reference in references]
35+
scores_list.append([])
36+
for summary in summaries:
37+
summary = flatten(summary)
38+
score = self.chrf.sentence_score(summary, references)
39+
scores_list[-1].append(MetricsDict({'chrf': score.score}))
40+
return scores_list

sacrerouge/tests/metrics/chrf_test.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from sacrerouge.common.testing.metric_test_cases import ReferenceBasedMetricTestCase
2+
from sacrerouge.common.testing.util import sacrerouge_command_exists
3+
from sacrerouge.metrics.chrf import ChrF
4+
5+
6+
class TestChrF(ReferenceBasedMetricTestCase):
7+
def test_chrf(self):
8+
# This is a regression test, not necessarily a test for correctness
9+
metric = ChrF()
10+
expected_output = [
11+
{'chrf': 43.85388187268078},
12+
{'chrf': 52.60988684061151},
13+
{'chrf': 49.65855270200033},
14+
{'chrf': 41.43422564802467},
15+
{'chrf': 42.53477752245672},
16+
{'chrf': 41.91433502501302},
17+
{'chrf': 49.91729287293663},
18+
{'chrf': 49.63089847829706},
19+
{'chrf': 43.484464138558174},
20+
{'chrf': 56.073961048238665},
21+
{'chrf': 48.94445553665466},
22+
{'chrf': 46.85719853900437}
23+
]
24+
super().assert_expected_output(metric, expected_output)
25+
26+
def test_chrf_examples(self):
27+
chrf = ChrF()
28+
29+
hypothesis = 'The dog bit the man.'
30+
references = ['The dog bit the man.', 'The dog had bit the man.']
31+
expected = 100.0
32+
actual = chrf.score(hypothesis, references)['chrf']
33+
self.assertAlmostEqual(expected, actual)
34+
35+
hypothesis = 'It wasn\'t surprising.'
36+
references = ['It was not unexpected.', 'No one was surprised.']
37+
expected = 35.34635301425291
38+
actual = chrf.score(hypothesis, references)['chrf']
39+
self.assertAlmostEqual(expected, actual)
40+
41+
hypothesis = 'The man had just bitten him.'
42+
references = ['The man bit him first.', 'The man had bitten the dog.']
43+
expected = 51.87740799504779
44+
actual = chrf.score(hypothesis, references)['chrf']
45+
self.assertAlmostEqual(expected, actual)
46+
47+
def test_chrf_order_invariant(self):
48+
metric = ChrF()
49+
self.assert_order_invariant(metric)
50+
51+
def test_command_exists(self):
52+
assert sacrerouge_command_exists(['chrf'])

0 commit comments

Comments
 (0)