diff --git a/prts/base/time_series_metrics.py b/prts/base/time_series_metrics.py index 9ee224d..258c435 100644 --- a/prts/base/time_series_metrics.py +++ b/prts/base/time_series_metrics.py @@ -127,8 +127,12 @@ def _shift(self, arr, num, fill_value=np.nan): def _prepare_data(self, values_real, values_pred): assert len(values_real) == len(values_pred) - assert np.allclose(np.unique(values_real), np.array([0, 1])) - assert np.allclose(np.unique(values_pred), np.array([0, 1])) + assert np.allclose(np.unique(values_real), np.array([0, 1])) or np.allclose( + np.unique(values_real), np.array([1]) + ) + assert np.allclose(np.unique(values_pred), np.array([0, 1])) or np.allclose( + np.unique(values_pred), np.array([1]) + ) predicted_anomalies_ = np.argwhere(values_pred == 1).ravel() predicted_anomalies_shift_forward = self._shift(predicted_anomalies_, 1, fill_value=predicted_anomalies_[0]) diff --git a/tests/test_precision.py b/tests/test_precision.py index e1d6d3d..1c0a79f 100644 --- a/tests/test_precision.py +++ b/tests/test_precision.py @@ -135,3 +135,23 @@ def test_precision_function_with_invalid_bias(self): with self.assertRaises(Exception): ts_precision(real, pred, bias="Invalid") + + def test_precision_function_with_all_zeros(self): + """Test of ts_precision function with all zero values + """ + + real = np.array([0, 0, 0, 0, 0]) + pred = np.array([0, 0, 0, 0, 0]) + + with self.assertRaises(Exception): + ts_precision(real, pred) + + def test_precision_function_with_all_ones(self): + """Test of ts_precision function with all zero values + """ + + real = np.array([1, 1, 1, 1, 1]) + pred = np.array([1, 1, 1, 1, 1]) + + self.assertEqual(ts_precision(real, pred), 1.0) +