adds available device to nlp tests #3335 #3385

Merged · 20 commits · Apr 24, 2025

10 changes: 6 additions & 4 deletions ignite/metrics/nlp/bleu.py
@@ -2,6 +2,7 @@
 from typing import Any, Callable, Sequence, Tuple, Union

 import torch
+from torch import Tensor

 from ignite.exceptions import NotComputableError
 from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
@@ -236,12 +237,12 @@ def _corpus_bleu(self, references: Sequence[Sequence[Sequence[Any]]], candidates
     @reinit__is_reduced
     def reset(self) -> None:
         if self.average == "macro":
-            self._sum_of_bleu = torch.tensor(0.0, dtype=torch.double, device=self._device)
+            self._sum_of_bleu = torch.tensor(0.0, dtype=self._double_dtype, device=self._device)
             self._num_sentences = 0

         if self.average == "micro":
-            self.p_numerators = torch.zeros(self.ngrams_order + 1)
-            self.p_denominators = torch.zeros(self.ngrams_order + 1)
+            self.p_numerators = torch.zeros(self.ngrams_order + 1, dtype=self._double_dtype)
+            self.p_denominators = torch.zeros(self.ngrams_order + 1, dtype=self._double_dtype)
             self.hyp_length_sum = 0
             self.ref_length_sum = 0

@@ -278,8 +279,9 @@ def _compute_micro(self) -> float:
         )
         return bleu_score

-    def compute(self) -> None:
+    def compute(self) -> Union[None, Tensor, float]:
         if self.average == "macro":
             return self._compute_macro()
         elif self.average == "micro":
             return self._compute_micro()
         return None
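
Note: the switch from a hard-coded torch.double to self._double_dtype matters on backends without float64 support, Apple's MPS in particular. The selection logic behind _double_dtype is not shown in this diff; a minimal sketch of the kind of rule the Metric base class presumably encodes (function name hypothetical):

    import torch

    def pick_double_dtype(device: torch.device) -> torch.dtype:
        # Assumption: widest float the backend supports; MPS has no float64.
        return torch.float32 if device.type == "mps" else torch.float64
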
87 changes: 52 additions & 35 deletions tests/ignite/metrics/nlp/test_bleu.py
@@ -44,10 +44,10 @@ def test_wrong_inputs():
     )


-def _test(candidates, references, average, smooth="no_smooth", smooth_nltk_fn=None, ngram_range=8):
+def _test(candidates, references, average, smooth="no_smooth", smooth_nltk_fn=None, ngram_range=8, device="cpu"):
     for i in range(1, ngram_range):
         weights = tuple([1 / i] * i)
-        bleu = Bleu(ngram=i, average=average, smooth=smooth)
+        bleu = Bleu(ngram=i, average=average, smooth=smooth, device=device)

         if average == "macro":
             with warnings.catch_warnings():
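
Note: every test below takes a new available_device pytest fixture. Its definition is not part of this diff; a minimal sketch of what such a fixture plausibly does, as an assumption rather than the repository's actual conftest.py:

    import pytest
    import torch

    @pytest.fixture
    def available_device():
        # Hypothetical stand-in: pick the best backend the host offers.
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"
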
@@ -64,51 +64,56 @@ def _test(candidates, references, average, smooth="no_smooth", smooth_nltk_fn=No
             assert pytest.approx(reference) == bleu._corpus_bleu(references, candidates)

         bleu.update((candidates, references))
-        assert pytest.approx(reference) == bleu.compute()
+        computed = bleu.compute()
+        if isinstance(computed, torch.Tensor):
+            computed = computed.cpu().item()
+        assert pytest.approx(reference) == computed


 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu(candidates, references):
-    _test(candidates, references, "macro")
+def test_macro_bleu(candidates, references, available_device):
+    _test(candidates, references, "macro", device=available_device)


 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu(candidates, references):
-    _test(candidates, references, "micro")
+def test_micro_bleu(candidates, references, available_device):
+    _test(candidates, references, "micro", device=available_device)


 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu_smooth1(candidates, references):
-    _test(candidates, references, "macro", "smooth1", SmoothingFunction().method1)
+def test_macro_bleu_smooth1(candidates, references, available_device):
+    _test(candidates, references, "macro", "smooth1", SmoothingFunction().method1, device=available_device)


 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu_smooth1(candidates, references):
-    _test(candidates, references, "micro", "smooth1", SmoothingFunction().method1)
+def test_micro_bleu_smooth1(candidates, references, available_device):
+    _test(candidates, references, "micro", "smooth1", SmoothingFunction().method1, device=available_device)


 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu_nltk_smooth2(candidates, references):
-    _test(candidates, references, "macro", "nltk_smooth2", SmoothingFunction().method2)
+def test_macro_bleu_nltk_smooth2(candidates, references, available_device):
+    _test(candidates, references, "macro", "nltk_smooth2", SmoothingFunction().method2, device=available_device)


 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu_nltk_smooth2(candidates, references):
-    _test(candidates, references, "micro", "nltk_smooth2", SmoothingFunction().method2)
+def test_micro_bleu_nltk_smooth2(candidates, references, available_device):
+    _test(candidates, references, "micro", "nltk_smooth2", SmoothingFunction().method2, device=available_device)


 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu_smooth2(candidates, references):
-    _test(candidates, references, "macro", "smooth2", SmoothingFunction().method2, 3)
+def test_macro_bleu_smooth2(candidates, references, available_device):
+    _test(candidates, references, "macro", "smooth2", SmoothingFunction().method2, 3, available_device)


 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu_smooth2(candidates, references):
-    _test(candidates, references, "micro", "smooth2", SmoothingFunction().method2, 3)
+def test_micro_bleu_smooth2(candidates, references, available_device):
+    _test(candidates, references, "micro", "smooth2", SmoothingFunction().method2, 3, device=available_device)


-def test_accumulation_macro_bleu():
-    bleu = Bleu(ngram=4, smooth="smooth2")
+def test_accumulation_macro_bleu(available_device):
+    bleu = Bleu(ngram=4, smooth="smooth2", device=available_device)
+    assert bleu._device == torch.device(available_device)
+
     bleu.update(([corpus.cand_1], [corpus.references_1]))
     bleu.update(([corpus.cand_2a], [corpus.references_2]))
     bleu.update(([corpus.cand_2b], [corpus.references_2]))
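
Note: the recurring isinstance guard reflects the new compute() contract per the bleu.py changes above: macro averaging apparently yields a 0-dim tensor on the metric's device, micro averaging a plain float. Isolated as a helper (hypothetical, for illustration only), the pattern is:

    import torch

    def to_scalar(value):
        # Bleu.compute() returns a tensor (macro) or a float (micro);
        # move tensors to CPU and unwrap so pytest.approx can compare.
        if isinstance(value, torch.Tensor):
            return value.cpu().item()
        return value
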
@@ -120,8 +125,10 @@ def test_accumulation_macro_bleu():
     assert bleu.compute() == value / 4


-def test_accumulation_micro_bleu():
-    bleu = Bleu(ngram=4, smooth="smooth2", average="micro")
+def test_accumulation_micro_bleu(available_device):
+    bleu = Bleu(ngram=4, smooth="smooth2", average="micro", device=available_device)
+    assert bleu._device == torch.device(available_device)
+
     bleu.update(([corpus.cand_1], [corpus.references_1]))
     bleu.update(([corpus.cand_2a], [corpus.references_2]))
     bleu.update(([corpus.cand_2b], [corpus.references_2]))
@@ -133,8 +140,9 @@ def test_accumulation_micro_bleu():
     assert bleu.compute() == value


-def test_bleu_batch_macro():
-    bleu = Bleu(ngram=4)
+def test_bleu_batch_macro(available_device):
+    bleu = Bleu(ngram=4, device=available_device)
+    assert bleu._device == torch.device(available_device)

     # Batch size 3
     hypotheses = [corpus.cand_1, corpus.cand_2a, corpus.cand_2b]
@@ -148,22 +156,29 @@ def test_bleu_batch_macro():
         + sentence_bleu(refs[1], hypotheses[1])
         + sentence_bleu(refs[2], hypotheses[2])
     ) / 3
-    assert pytest.approx(bleu.compute()) == reference_bleu_score
+    computed = bleu.compute()
+    if isinstance(computed, torch.Tensor):
+        computed = computed.cpu().item()
+
+    assert pytest.approx(computed) == reference_bleu_score

     value = 0
     for _hypotheses, _refs in zip(hypotheses, refs):
         value += bleu._sentence_bleu(_refs, _hypotheses)
         bleu.update(([_hypotheses], [_refs]))

     ref_1 = value / len(refs)
-    ref_2 = bleu.compute()
+    computed = bleu.compute()
+    if isinstance(computed, torch.Tensor):
+        computed = computed.cpu().item()

     assert pytest.approx(ref_1) == reference_bleu_score
-    assert pytest.approx(ref_2) == reference_bleu_score
+    assert pytest.approx(computed) == reference_bleu_score


-def test_bleu_batch_micro():
-    bleu = Bleu(ngram=4, average="micro")
+def test_bleu_batch_micro(available_device):
+    bleu = Bleu(ngram=4, average="micro", device=available_device)
+    assert bleu._device == torch.device(available_device)

     # Batch size 3
     hypotheses = [corpus.cand_1, corpus.cand_2a, corpus.cand_2b]
@@ -187,8 +202,10 @@ def test_bleu_batch_micro():
         (corpus.cand_1, corpus.references_1),
     ],
 )
-def test_n_gram_counter(candidates, references):
-    bleu = Bleu(ngram=4)
+def test_n_gram_counter(candidates, references, available_device):
+    bleu = Bleu(ngram=4, device=available_device)
+    assert bleu._device == torch.device(available_device)
+
     hyp_length, ref_length = bleu._n_gram_counter([references], [candidates], Counter(), Counter())
     assert hyp_length == len(candidates)

@@ -212,9 +229,9 @@ def _test_macro_distrib_integration(device):
     def update(_, i):
         return data[i + size * rank]

-    def _test(metric_device):
+    def _test(device):
         engine = Engine(update)
-        m = Bleu(ngram=4, smooth="smooth2")
+        m = Bleu(ngram=4, smooth="smooth2", device=device)
         m.attach(engine, "bleu")

         engine.run(data=list(range(size)), max_epochs=1)
@@ -256,7 +273,7 @@ def update(_, i):

     def _test(metric_device):
         engine = Engine(update)
-        m = Bleu(ngram=4, smooth="smooth2", average="micro")
+        m = Bleu(ngram=4, smooth="smooth2", average="micro", device=metric_device)
         m.attach(engine, "bleu")

         engine.run(data=list(range(size)), max_epochs=1)
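
Note: the distributed-integration tests above follow ignite's usual attach/run/read cycle. A condensed, non-distributed sketch of that cycle, with toy data invented here:

    from ignite.engine import Engine
    from ignite.metrics import Bleu

    # One (candidates, references) pair per iteration; batches are indices.
    data = [([["a", "cat", "sat"]], [[["a", "cat", "sat", "here"]]])]

    def update(_, i):
        return data[i]

    engine = Engine(update)
    m = Bleu(ngram=2, smooth="smooth2", device="cpu")
    m.attach(engine, "bleu")
    engine.run(data=list(range(len(data))), max_epochs=1)
    print(engine.state.metrics["bleu"])
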
9 changes: 5 additions & 4 deletions tests/ignite/metrics/nlp/test_rouge.py
@@ -84,9 +84,9 @@ def test_wrong_inputs():
(2, "abcdef", "zbdfz", (0, 0)),
],
)
def test_rouge_n_alpha(ngram, candidate, reference, expected):
def test_rouge_n_alpha(ngram, candidate, reference, expected, available_device):
for alpha in [0, 1, 0.3, 0.5, 0.8]:
rouge = RougeN(ngram=ngram, alpha=alpha)
rouge = RougeN(ngram=ngram, alpha=alpha, device=available_device)
rouge.update(([candidate], [[reference]]))
results = rouge.compute()
assert results[f"Rouge-{ngram}-P"] == expected[0]
@@ -101,7 +101,7 @@ def test_rouge_n_alpha(ngram, candidate, reference, expected):
 @pytest.mark.parametrize(
     "candidates, references", [corpus.sample_1, corpus.sample_2, corpus.sample_3, corpus.sample_4, corpus.sample_5]
 )
-def test_rouge_metrics(candidates, references):
+def test_rouge_metrics(candidates, references, available_device):
     for multiref in ["average", "best"]:
         # PERL 1.5.5 reference
         apply_avg = multiref == "average"
@@ -123,7 +123,8 @@ def test_rouge_metrics(candidates, references):

         lower_split_candidates = [candidate.lower().split() for candidate in candidates]

-        m = Rouge(variants=[1, 2, 4, "L"], multiref=multiref, alpha=0.5)
+        m = Rouge(variants=[1, 2, 4, "L"], multiref=multiref, alpha=0.5, device=available_device)
+        assert m._device == torch.device(available_device)
         m.update((lower_split_candidates, lower_split_references))
         results = m.compute()

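
Note: the Rouge changes mirror the Bleu ones: the constructor's device argument places metric state on the chosen backend. A hedged standalone usage sketch, with sentences invented for illustration:

    import torch

    from ignite.metrics import Rouge

    device = "cuda" if torch.cuda.is_available() else "cpu"
    m = Rouge(variants=[1, "L"], multiref="best", alpha=0.5, device=device)

    candidate = "the cat sat on the mat".split()
    reference = "a cat sat on a mat".split()
    m.update(([candidate], [[reference]]))
    print(m.compute())  # dict with Rouge-1 and Rouge-L precision/recall/F entries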