diff --git a/ignite/metrics/nlp/bleu.py b/ignite/metrics/nlp/bleu.py
index ed3b14b4dc52..0ca724a2ddc3 100644
--- a/ignite/metrics/nlp/bleu.py
+++ b/ignite/metrics/nlp/bleu.py
@@ -2,6 +2,7 @@
 from typing import Any, Callable, Sequence, Tuple, Union
 
 import torch
+from torch import Tensor
 
 from ignite.exceptions import NotComputableError
 from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
@@ -236,12 +237,12 @@ def _corpus_bleu(self, references: Sequence[Sequence[Sequence[Any]]], candidates
     @reinit__is_reduced
     def reset(self) -> None:
         if self.average == "macro":
-            self._sum_of_bleu = torch.tensor(0.0, dtype=torch.double, device=self._device)
+            self._sum_of_bleu = torch.tensor(0.0, dtype=self._double_dtype, device=self._device)
             self._num_sentences = 0
 
         if self.average == "micro":
-            self.p_numerators = torch.zeros(self.ngrams_order + 1)
-            self.p_denominators = torch.zeros(self.ngrams_order + 1)
+            self.p_numerators = torch.zeros(self.ngrams_order + 1, dtype=self._double_dtype)
+            self.p_denominators = torch.zeros(self.ngrams_order + 1, dtype=self._double_dtype)
             self.hyp_length_sum = 0
             self.ref_length_sum = 0
 
@@ -278,8 +279,9 @@ def _compute_micro(self) -> float:
         )
         return bleu_score
 
-    def compute(self) -> None:
+    def compute(self) -> Union[None, Tensor, float]:
         if self.average == "macro":
             return self._compute_macro()
         elif self.average == "micro":
             return self._compute_micro()
+        return None
diff --git a/tests/ignite/metrics/nlp/test_bleu.py b/tests/ignite/metrics/nlp/test_bleu.py
index 9de9c6de78c5..b191cd8ded6f 100644
--- a/tests/ignite/metrics/nlp/test_bleu.py
+++ b/tests/ignite/metrics/nlp/test_bleu.py
@@ -44,10 +44,10 @@ def test_wrong_inputs():
     )
 
 
-def _test(candidates, references, average, smooth="no_smooth", smooth_nltk_fn=None, ngram_range=8):
+def _test(candidates, references, average, smooth="no_smooth", smooth_nltk_fn=None, ngram_range=8, device="cpu"):
     for i in range(1, ngram_range):
         weights = tuple([1 / i] * i)
-        bleu = Bleu(ngram=i, average=average, smooth=smooth)
+        bleu = Bleu(ngram=i, average=average, smooth=smooth, device=device)
 
         if average == "macro":
             with warnings.catch_warnings():
@@ -64,51 +64,56 @@ def _test(candidates, references, average, smooth="no_smooth", smooth_nltk_fn=No
             assert pytest.approx(reference) == bleu._corpus_bleu(references, candidates)
 
         bleu.update((candidates, references))
-        assert pytest.approx(reference) == bleu.compute()
+        computed = bleu.compute()
+        if isinstance(computed, torch.Tensor):
+            computed = computed.cpu().item()
+        assert pytest.approx(reference) == computed
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu(candidates, references):
-    _test(candidates, references, "macro")
+def test_macro_bleu(candidates, references, available_device):
+    _test(candidates, references, "macro", device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu(candidates, references):
-    _test(candidates, references, "micro")
+def test_micro_bleu(candidates, references, available_device):
+    _test(candidates, references, "micro", device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu_smooth1(candidates, references):
-    _test(candidates, references, "macro", "smooth1", SmoothingFunction().method1)
+def test_macro_bleu_smooth1(candidates, references, available_device):
+    _test(candidates, references, "macro", "smooth1", SmoothingFunction().method1, device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu_smooth1(candidates, references):
-    _test(candidates, references, "micro", "smooth1", SmoothingFunction().method1)
+def test_micro_bleu_smooth1(candidates, references, available_device):
+    _test(candidates, references, "micro", "smooth1", SmoothingFunction().method1, device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu_nltk_smooth2(candidates, references):
-    _test(candidates, references, "macro", "nltk_smooth2", SmoothingFunction().method2)
+def test_macro_bleu_nltk_smooth2(candidates, references, available_device):
+    _test(candidates, references, "macro", "nltk_smooth2", SmoothingFunction().method2, device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu_nltk_smooth2(candidates, references):
-    _test(candidates, references, "micro", "nltk_smooth2", SmoothingFunction().method2)
+def test_micro_bleu_nltk_smooth2(candidates, references, available_device):
+    _test(candidates, references, "micro", "nltk_smooth2", SmoothingFunction().method2, device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu_smooth2(candidates, references):
-    _test(candidates, references, "macro", "smooth2", SmoothingFunction().method2, 3)
+def test_macro_bleu_smooth2(candidates, references, available_device):
+    _test(candidates, references, "macro", "smooth2", SmoothingFunction().method2, 3, available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu_smooth2(candidates, references):
-    _test(candidates, references, "micro", "smooth2", SmoothingFunction().method2, 3)
+def test_micro_bleu_smooth2(candidates, references, available_device):
+    _test(candidates, references, "micro", "smooth2", SmoothingFunction().method2, 3, device=available_device)
 
 
-def test_accumulation_macro_bleu():
-    bleu = Bleu(ngram=4, smooth="smooth2")
+def test_accumulation_macro_bleu(available_device):
+    bleu = Bleu(ngram=4, smooth="smooth2", device=available_device)
+    assert bleu._device == torch.device(available_device)
+
     bleu.update(([corpus.cand_1], [corpus.references_1]))
     bleu.update(([corpus.cand_2a], [corpus.references_2]))
     bleu.update(([corpus.cand_2b], [corpus.references_2]))
@@ -120,8 +125,10 @@ def test_accumulation_macro_bleu():
     assert bleu.compute() == value / 4
 
 
-def test_accumulation_micro_bleu():
-    bleu = Bleu(ngram=4, smooth="smooth2", average="micro")
+def test_accumulation_micro_bleu(available_device):
+    bleu = Bleu(ngram=4, smooth="smooth2", average="micro", device=available_device)
+    assert bleu._device == torch.device(available_device)
+
     bleu.update(([corpus.cand_1], [corpus.references_1]))
     bleu.update(([corpus.cand_2a], [corpus.references_2]))
     bleu.update(([corpus.cand_2b], [corpus.references_2]))
@@ -133,8 +140,9 @@ def test_accumulation_micro_bleu():
     assert bleu.compute() == value
 
 
-def test_bleu_batch_macro():
-    bleu = Bleu(ngram=4)
+def test_bleu_batch_macro(available_device):
+    bleu = Bleu(ngram=4, device=available_device)
+    assert bleu._device == torch.device(available_device)
 
     # Batch size 3
     hypotheses = [corpus.cand_1, corpus.cand_2a, corpus.cand_2b]
@@ -148,7 +156,11 @@ def test_bleu_batch_macro():
             + sentence_bleu(refs[1], hypotheses[1])
             + sentence_bleu(refs[2], hypotheses[2])
         ) / 3
-    assert pytest.approx(bleu.compute()) == reference_bleu_score
+    computed = bleu.compute()
+    if isinstance(computed, torch.Tensor):
+        computed = computed.cpu().item()
+
+    assert pytest.approx(computed) == reference_bleu_score
 
     value = 0
     for _hypotheses, _refs in zip(hypotheses, refs):
@@ -156,14 +168,17 @@
         bleu.update(([_hypotheses], [_refs]))
 
     ref_1 = value / len(refs)
-    ref_2 = bleu.compute()
+    computed = bleu.compute()
+    if isinstance(computed, torch.Tensor):
+        computed = computed.cpu().item()
 
     assert pytest.approx(ref_1) == reference_bleu_score
-    assert pytest.approx(ref_2) == reference_bleu_score
+    assert pytest.approx(computed) == reference_bleu_score
 
 
-def test_bleu_batch_micro():
-    bleu = Bleu(ngram=4, average="micro")
+def test_bleu_batch_micro(available_device):
+    bleu = Bleu(ngram=4, average="micro", device=available_device)
+    assert bleu._device == torch.device(available_device)
 
     # Batch size 3
     hypotheses = [corpus.cand_1, corpus.cand_2a, corpus.cand_2b]
@@ -187,8 +202,10 @@ def test_bleu_batch_micro():
         (corpus.cand_1, corpus.references_1),
     ],
 )
-def test_n_gram_counter(candidates, references):
-    bleu = Bleu(ngram=4)
+def test_n_gram_counter(candidates, references, available_device):
+    bleu = Bleu(ngram=4, device=available_device)
+    assert bleu._device == torch.device(available_device)
+
     hyp_length, ref_length = bleu._n_gram_counter([references], [candidates], Counter(), Counter())
 
     assert hyp_length == len(candidates)
@@ -212,9 +229,9 @@ def _test_macro_distrib_integration(device):
     def update(_, i):
         return data[i + size * rank]
 
-    def _test(metric_device):
+    def _test(device):
         engine = Engine(update)
-        m = Bleu(ngram=4, smooth="smooth2")
+        m = Bleu(ngram=4, smooth="smooth2", device=device)
         m.attach(engine, "bleu")
 
         engine.run(data=list(range(size)), max_epochs=1)
@@ -256,7 +273,7 @@ def update(_, i):
 
     def _test(metric_device):
         engine = Engine(update)
-        m = Bleu(ngram=4, smooth="smooth2", average="micro")
+        m = Bleu(ngram=4, smooth="smooth2", average="micro", device=metric_device)
         m.attach(engine, "bleu")
 
         engine.run(data=list(range(size)), max_epochs=1)
diff --git a/tests/ignite/metrics/nlp/test_rouge.py b/tests/ignite/metrics/nlp/test_rouge.py
index 5d8562866c83..5dbf4c9bde8f 100644
--- a/tests/ignite/metrics/nlp/test_rouge.py
+++ b/tests/ignite/metrics/nlp/test_rouge.py
@@ -84,9 +84,9 @@ def test_wrong_inputs():
         (2, "abcdef", "zbdfz", (0, 0)),
     ],
 )
-def test_rouge_n_alpha(ngram, candidate, reference, expected):
+def test_rouge_n_alpha(ngram, candidate, reference, expected, available_device):
     for alpha in [0, 1, 0.3, 0.5, 0.8]:
-        rouge = RougeN(ngram=ngram, alpha=alpha)
+        rouge = RougeN(ngram=ngram, alpha=alpha, device=available_device)
         rouge.update(([candidate], [[reference]]))
         results = rouge.compute()
         assert results[f"Rouge-{ngram}-P"] == expected[0]
@@ -101,7 +101,7 @@ def test_rouge_n_alpha(ngram, candidate, reference, expected):
 @pytest.mark.parametrize(
     "candidates, references", [corpus.sample_1, corpus.sample_2, corpus.sample_3, corpus.sample_4, corpus.sample_5]
 )
-def test_rouge_metrics(candidates, references):
+def test_rouge_metrics(candidates, references, available_device):
     for multiref in ["average", "best"]:
         # PERL 1.5.5 reference
         apply_avg = multiref == "average"
@@ -123,7 +123,8 @@ def test_rouge_metrics(candidates, references):
 
         lower_split_candidates = [candidate.lower().split() for candidate in candidates]
 
-        m = Rouge(variants=[1, 2, 4, "L"], multiref=multiref, alpha=0.5)
+        m = Rouge(variants=[1, 2, 4, "L"], multiref=multiref, alpha=0.5, device=available_device)
+        assert m._device == torch.device(available_device)
         m.update((lower_split_candidates, lower_split_references))
 
         results = m.compute()
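
Note for reviewers (not part of the patch): with this change, Bleu.compute() may return a torch.Tensor (macro average, allocated on the metric's device) or a plain float (micro average). A minimal usage sketch of how calling code might normalize the result, mirroring the pattern used in the updated tests; the example sentences are illustrative only.

import torch

from ignite.metrics import Bleu

# Tokenized hypothesis and its reference set (illustrative data only).
candidates = [["the", "cat", "sat", "on", "the", "mat"]]
references = [[["there", "is", "a", "cat", "on", "the", "mat"]]]

bleu = Bleu(ngram=4, smooth="smooth1", average="macro", device="cpu")
bleu.update((candidates, references))

score = bleu.compute()
if isinstance(score, torch.Tensor):
    # Macro average returns a tensor on the metric's device; unwrap it for reporting.
    score = score.cpu().item()
print(f"BLEU-4: {score:.4f}")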