diff --git a/tests/ignite/metrics/test_roc_curve.py b/tests/ignite/metrics/test_roc_curve.py index 002c8975356..9eb983696ab 100644 --- a/tests/ignite/metrics/test_roc_curve.py +++ b/tests/ignite/metrics/test_roc_curve.py @@ -168,8 +168,8 @@ def update(engine, i): y = idist.all_gather(y) y_pred = idist.all_gather(y_pred) - expected_fpr, expected_tpr, expected_thresholds = roc_curve(y, y_pred) + expected_fpr, expected_tpr, expected_thresholds = roc_curve(y.cpu().numpy(), y_pred.cpu().numpy()) - assert expected_fpr == pytest.approx(fpr) - assert expected_tpr == pytest.approx(tpr) - assert expected_thresholds == pytest.approx(thresholds) + assert expected_fpr == pytest.approx(fpr.cpu().numpy()) + assert expected_tpr == pytest.approx(tpr.cpu().numpy()) + assert expected_thresholds == pytest.approx(thresholds.cpu().numpy())