diff --git a/tests/test_recall.py b/tests/test_recall.py index 3e8b8fd..cd6ea2e 100644 --- a/tests/test_recall.py +++ b/tests/test_recall.py @@ -61,6 +61,27 @@ def test_RecallClass_score(self): with self.assertRaises(Exception): score = obj.score(real, pred) + def test_RecallClass_update_recall(self): + """Test of _update_recall function. + """ + + # test of the normal case + real = np.array([0, 1, 0, 0, 0]) + pred = np.array([1, 1, 0, 0, 0]) + + obj = TimeSeriesRecall() + real_anomalies, predicted_anomalies = obj._prepare_data(real, pred) + + score = obj._update_recall(real_anomalies, predicted_anomalies) + self.assertEqual(score, 1.0) + + # test of the empty case + empty_real = np.array([]) + empty_pred = np.array([]) + + score = obj._update_recall(empty_real, empty_pred) + self.assertEqual(score, 0.0) + def test_recall_function(self): """Test of ts_recall function. """