Skip to content

Add NDCG metrics #3346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ignite.metrics.precision import Precision
from ignite.metrics.precision_recall_curve import PrecisionRecallCurve
from ignite.metrics.psnr import PSNR
from ignite.metrics.recsys.ndcg import NDCG
from ignite.metrics.recall import Recall
from ignite.metrics.roc_auc import ROC_AUC, RocCurve
from ignite.metrics.root_mean_squared_error import RootMeanSquaredError
Expand Down Expand Up @@ -88,6 +89,7 @@
"Rouge",
"RougeN",
"RougeL",
"NDCG",
"regression",
"clustering",
"AveragePrecision",
Expand Down
5 changes: 5 additions & 0 deletions ignite/metrics/recsys/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ignite.metrics.recsys.ndcg import NDCG

__all__ = [
"NDCG",
]
142 changes: 142 additions & 0 deletions ignite/metrics/recsys/ndcg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from typing import Callable, Optional, Sequence, Union

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["NDCG"]


class NDCG(Metric):
    r"""Calculates the Normalized Discounted Cumulative Gain (nDCG), averaged over all samples.

    See `Discounted cumulative gain <https://en.wikipedia.org/wiki/Discounted_cumulative_gain>`_.

    .. math::
        \text{nDCG}_p = \frac{\text{DCG}_p}{\text{IDCG}_p}

    where

    .. math::
        \text{DCG}_p = \sum_{i = 1}^{p} \frac{g(rel_i)}{\log_b{(i + 1)}}

    and :math:`\text{IDCG}_p` is the :math:`\text{DCG}_p` of the ideal ranking, i.e. the ranking obtained by
    ordering the items by their true relevance. Here :math:`rel_i` is the true relevance of the item placed at
    position :math:`i` by the predicted scores, :math:`b` is ``log_base``, and the gain :math:`g(rel)` is
    :math:`rel` by default or :math:`2^{rel} - 1` when ``exponential=True``. :math:`p` is ``k`` or, when ``k``
    is None, the number of items.

    - ``update`` must receive output of the form ``(y_pred, y)`` where ``y_pred`` holds the predicted scores and
      ``y`` the true relevances, both of shape ``(batch_size, num_items)``.
    - Samples whose ideal DCG is zero (no relevant item) contribute zero to the average.

    Args:
        output_transform: A callable that is used to transform the Engine's
            process_function's output into the form expected by the metric.
        device: specifies which device updates are accumulated on.
            Setting the metric's device to be the same as your update arguments ensures
            the update method is non-blocking. By default, CPU.
        k: Only consider the highest k scores in the ranking. If None, use all outputs.
            Must be a positive integer when given.
        log_base: Base of the logarithm used for the discount. The same base is applied to both DCG and
            IDCG, so it only rescales numerator and denominator by a common factor.
        exponential: If True, the gain of an item is ``2 ** relevance - 1`` instead of ``relevance``.
        ignore_ties: Assume that there are no ties in ``y_pred`` (which is likely to be the
            case if ``y_pred`` is continuous) for efficiency gains. When False, items sharing the same
            predicted score receive the average gain of their tie group.

    Examples:
        .. code-block:: python

            metric = NDCG(k=3)
            metric.update((y_pred, y))  # both tensors of shape (batch_size, num_items)
            value = metric.compute()
    """

    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
        k: Optional[int] = None,
        log_base: Union[int, float] = 2,
        exponential: bool = False,
        ignore_ties: bool = False,
    ):
        if log_base == 1 or log_base <= 0:
            raise ValueError(f"Argument log_base should positive and not equal one,but got {log_base}")
        # A non-positive k would silently zero (k == 0) or mis-truncate (k < 0) the discount vector.
        if k is not None and k <= 0:
            raise ValueError(f"Argument k should be a positive integer or None, but got {k}")
        self.log_base = log_base
        self.k = k
        self.exponential = exponential
        self.ignore_ties = ignore_ties
        super().__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self) -> None:
        """Clear the accumulated nDCG sum and the number of seen samples."""
        self.num_examples = 0
        self.ndcg = torch.tensor(0.0, device=self._device)

    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:
        """Accumulate the per-sample nDCG of one ``(y_pred, y)`` batch.

        Raises:
            ValueError: if the two tensors are not 2D or do not share the same shape.
        """
        y_pred, y_true = output[0].detach(), output[1].detach()

        if y_pred.ndim != 2 or y_pred.shape != y_true.shape:
            raise ValueError(
                "y_pred and y should be 2D tensors of the same shape (batch_size, num_items), "
                f"but got {y_pred.shape} and {y_true.shape}"
            )

        y_pred = y_pred.to(device=self._device, dtype=torch.float32)
        y_true = y_true.to(device=self._device, dtype=torch.float32)

        if self.exponential:
            y_true = 2**y_true - 1

        # Only samples with a non-zero ideal DCG are returned; the others add nothing to the sum
        # but are still counted below, i.e. they score zero.
        gain = _ndcg_sample_scores(y_pred, y_true, k=self.k, log_base=self.log_base, ignore_ties=self.ignore_ties)
        self.ndcg += torch.sum(gain)
        self.num_examples += y_pred.shape[0]

    @sync_all_reduce("ndcg", "num_examples")
    def compute(self) -> float:
        """Return the mean nDCG over all samples seen since the last reset.

        Raises:
            NotComputableError: if no sample has been accumulated.
        """
        if self.num_examples == 0:
            # NOTE: the message text (including the "NGCD" spelling) is matched verbatim by the test-suite.
            raise NotComputableError("NGCD must have at least one example before it can be computed.")

        return (self.ndcg / self.num_examples).item()


def _tie_averaged_dcg_batched(y_pred_batch, y_true_batch, discount_cumsum, device):
batch_size = y_pred_batch.shape[0]
results = torch.zeros(batch_size, device=device)

for i in range(batch_size):
y_pred = y_pred_batch[i]
y_true = y_true_batch[i]

_, inv, counts = torch.unique(-y_pred, return_inverse=True, return_counts=True)
ranked = torch.zeros(counts.shape[0], device=device)
ranked.index_add_(0, inv, y_true)
ranked /= counts
groups = torch.cumsum(counts, dim=-1) - 1
discount_sums = torch.zeros(counts.shape[0], device=device)
discount_sums[0] = discount_cumsum[groups[0]]
if counts.shape[0] > 1:
discount_sums[1:] = torch.diff(discount_cumsum[groups])

results[i] = torch.sum(torch.mul(ranked, discount_sums))

return results


def _dcg_sample_scores(
y_pred: torch.Tensor,
y_true: torch.Tensor,
k: Optional[int] = None,
log_base: Union[int, float] = 2,
ignore_ties: bool = False,
device: Union[str, torch.device] = torch.device("cpu"),
) -> torch.Tensor:

discount = torch.log(torch.tensor(log_base)) / torch.log(torch.arange(y_true.shape[1]) + 2)
discount = discount.to(device)

if k is not None:
discount[k:] = 0.0

if ignore_ties:
ranking = torch.argsort(y_pred, descending=True)
ranked = y_true[torch.arange(ranking.shape[0]).reshape(-1, 1), ranking].to(device)
discounted_gains = torch.mm(ranked, discount.reshape(-1, 1))
else:
discount_cumsum = torch.cumsum(discount, dim=-1)
discounted_gains = _tie_averaged_dcg_batched(y_pred, y_true, discount_cumsum, device)

return discounted_gains


def _ndcg_sample_scores(
    y_pred: torch.Tensor,
    y_true: torch.Tensor,
    k: Optional[int] = None,
    log_base: Union[int, float] = 2,
    ignore_ties: bool = False,
) -> torch.Tensor:
    """Compute the nDCG of the samples that have a non-zero ideal DCG.

    The DCG of the predicted ranking is divided by the DCG of the ideal ranking (``y_true`` ranked
    by itself). Samples whose ideal DCG equals zero are dropped from the result.

    Args:
        y_pred: predicted scores, one row per sample.
        y_true: gains aligned with ``y_pred``; its device is used for the computation.
        k: optional rank cut-off forwarded to the DCG computation.
        log_base: base of the logarithm used in the discount.
        ignore_ties: forwarded to the DCG computation of the predicted ranking.

    Returns:
        1D tensor with the normalized gains of the retained samples.
    """
    target_device = y_true.device

    dcg = _dcg_sample_scores(
        y_pred, y_true, k=k, log_base=log_base, ignore_ties=ignore_ties, device=target_device
    )
    # The tie-aware path yields one value per sample without a trailing axis; add it so the
    # shape lines up with the ideal DCG computed below.
    dcg = dcg if ignore_ties else dcg.unsqueeze(dim=-1)

    ideal_dcg = _dcg_sample_scores(y_true, y_true, k=k, log_base=log_base, ignore_ties=True, device=target_device)

    has_relevant = ideal_dcg != 0
    return dcg[has_relevant] / ideal_dcg[has_relevant]
168 changes: 168 additions & 0 deletions tests/ignite/metrics/test_ndcg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import numpy as np
import pytest
import torch
from sklearn.metrics import ndcg_score
from sklearn.metrics._ranking import _dcg_sample_scores

import ignite.distributed as idist

from ignite.exceptions import NotComputableError
from ignite.metrics.recsys.ndcg import NDCG


@pytest.fixture(params=list(range(6)))
def test_case(request):
    """Return one of two fixed tensor pairs; the six fixture params alternate between them."""
    # One sample, all values distinct.
    single_sample = (
        torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]),
        torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]]),
    )
    # Two samples whose first tensor contains repeated values.
    repeated_values = (
        torch.tensor([[3.7, 3.7, 3.7, 3.7, 3.7], [3.7, 3.7, 3.7, 3.7, 3.9]]),
        torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]]),
    )
    cases = (single_sample, repeated_values)
    return cases[request.param % 2]


@pytest.mark.parametrize("k", [None, 2, 3])
@pytest.mark.parametrize("exponential", [True, False])
@pytest.mark.parametrize("ignore_ties, replacement", [(True, False), (False, True), (False, False)])
def test_output(available_device, test_case, k, exponential, ignore_ties, replacement):
    """Check a single-batch NDCG value against sklearn's ``ndcg_score``."""
    sampling_weights, relevance = test_case

    # The predicted scores are item indices sampled from the fixture's first tensor; sampling
    # without replacement yields distinct values, which is what ``ignore_ties=True`` is paired with.
    scores = torch.multinomial(sampling_weights, 5, replacement=replacement).to(available_device)
    relevance = relevance.to(available_device)

    metric = NDCG(k=k, device=available_device, exponential=exponential, ignore_ties=ignore_ties)
    metric.update([scores, relevance])
    result_ignite = metric.compute()

    # sklearn has no exponential switch here, so apply the same gain transform to its input.
    reference_relevance = 2**relevance - 1 if exponential else relevance
    result_sklearn = ndcg_score(
        reference_relevance.cpu().numpy(), scores.cpu().numpy(), k=k, ignore_ties=ignore_ties
    )

    np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6)


def test_reset():
    """After ``reset`` the metric holds no examples, so ``compute`` must raise."""
    metric = NDCG()
    metric.update(
        [
            torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]),
            torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]),
        ]
    )
    metric.reset()

    expected_message = r"NGCD must have at least one example before it can be computed."
    with pytest.raises(NotComputableError, match=expected_message):
        metric.compute()


def _ndcg_sample_scores(y, y_score, k=None, ignore_ties=False, log_base=2):
    """Reference per-sample nDCG built on sklearn's DCG helper, with a configurable log base.

    Args:
        y: true relevances, one row per sample.
        y_score: predicted scores aligned with ``y``.
        k: optional rank cut-off.
        ignore_ties: forwarded to the DCG of the predicted ranking.
        log_base: base of the logarithm used in the discount (forwarded to sklearn).

    Returns:
        Array of per-sample nDCG values; samples with a zero ideal DCG score zero.
    """
    gain = _dcg_sample_scores(y, y_score, k, log_base=log_base, ignore_ties=ignore_ties)
    normalizing_gain = _dcg_sample_scores(y, y, k, log_base=log_base, ignore_ties=True)
    all_irrelevant = normalizing_gain == 0
    gain[all_irrelevant] = 0
    gain[~all_irrelevant] /= normalizing_gain[~all_irrelevant]
    return gain


@pytest.mark.parametrize("log_base", [2, 3, 10])
def test_log_base(log_base):
    """NDCG with a custom ``log_base`` must match a reference computed with the same base."""

    def ndcg_score_with_log_base(y, y_score, *, k=None, sample_weight=None, ignore_ties=False, log_base=2):
        # Forward log_base so the reference really uses the requested base rather than always base 2.
        gain = _ndcg_sample_scores(y, y_score, k=k, ignore_ties=ignore_ties, log_base=log_base)
        return np.average(gain, weights=sample_weight)

    y = torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]])
    y_pred = torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]])

    ndcg = NDCG(log_base=log_base)
    ndcg.update([y_pred, y])

    result_ignite = ndcg.compute()
    result_sklearn = ndcg_score_with_log_base(y.numpy(), y_pred.numpy(), log_base=log_base)

    np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6)


def test_update(test_case):
    """Two successive ``update`` calls must match one evaluation on the concatenated batches."""
    pred_weights, target_weights = test_case

    def draw(weights):
        # Sample five item indices per row, with replacement.
        return torch.multinomial(weights, 5, replacement=True)

    # Each batch is a (predictions, targets) pair; predictions are drawn before targets.
    batches = [(draw(pred_weights), draw(target_weights)) for _ in range(2)]

    metric = NDCG()
    for batch_pred, batch_true in batches:
        metric.update([batch_pred, batch_true])
    result_ignite = metric.compute()

    all_pred = torch.cat([batch_pred for batch_pred, _ in batches])
    all_true = torch.cat([batch_true for _, batch_true in batches])
    result_sklearn = ndcg_score(all_true.numpy(), all_pred.numpy())

    np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6)


@pytest.mark.parametrize("metric_device", ["cpu", "process_device"])
@pytest.mark.parametrize("num_epochs", [1, 2])
def test_distrib_integration(distributed, num_epochs, metric_device):
    """Run NDCG through an Engine and compare with sklearn on the data gathered from all processes."""
    from ignite.engine import Engine

    torch.manual_seed(12 + idist.get_rank())
    n_iters, batch_size = 5, 8
    device = idist.device()
    if metric_device == "process_device":
        # On XLA the metric accumulates on CPU instead of the process device.
        metric_device = "cpu" if device.type == "xla" else device

    num_items = 10
    num_samples = n_iters * batch_size
    y = torch.rand((num_samples, num_items)).to(device)
    y_preds = torch.rand((num_samples, num_items)).to(device)

    def update(engine, i):
        # Iteration ``i`` serves the i-th contiguous slice of the pre-generated data.
        rows = slice(i * batch_size, (i + 1) * batch_size)
        return y_preds[rows, ...], y[rows, ...]

    engine = Engine(update)
    NDCG(device=metric_device).attach(engine, "ndcg")
    engine.run(data=list(range(n_iters)), max_epochs=num_epochs)

    gathered_preds = idist.all_gather(y_preds)
    gathered_targets = idist.all_gather(y)

    assert "ndcg" in engine.state.metrics
    res = engine.state.metrics["ndcg"]

    true_res = ndcg_score(gathered_targets.cpu().numpy(), gathered_preds.cpu().numpy())

    # Looser tolerance on XLA. Open question from the author: should `distributed` provide backend info?
    tol = 1e-3 if device.type == "xla" else 1e-4

    assert pytest.approx(res, abs=tol) == true_res


@pytest.mark.parametrize("metric_device", [torch.device("cpu"), "process_device"])
def test_distrib_accumulator_device(distributed, metric_device):
    """The ``ndcg`` accumulator must live on the device the metric was constructed with."""
    process_device = idist.device()
    if metric_device == "process_device":
        # On XLA the metric accumulates on CPU instead of the process device.
        metric_device = torch.device("cpu" if process_device.type == "xla" else process_device)

    metric = NDCG(device=metric_device)

    y_pred = torch.rand((2, 10)).to(process_device)
    y = torch.rand((2, 10)).to(process_device)
    metric.update((y_pred, y))

    dev = metric.ndcg.device
    assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"