Skip to content

Commit

Permalink
Merge pull request #71 from CompML/feature/#70/fix_prepare_data_assert
Browse files Browse the repository at this point in the history
fix _prepare_data
  • Loading branch information
nocotan authored Feb 17, 2021
2 parents dcd5764 + 4047c96 commit 0ec80c1
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
8 changes: 6 additions & 2 deletions prts/base/time_series_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,12 @@ def _shift(self, arr, num, fill_value=np.nan):
def _prepare_data(self, values_real, values_pred):

assert len(values_real) == len(values_pred)
assert np.allclose(np.unique(values_real), np.array([0, 1]))
assert np.allclose(np.unique(values_pred), np.array([0, 1]))
assert np.allclose(np.unique(values_real), np.array([0, 1])) or np.allclose(
np.unique(values_real), np.array([1])
)
assert np.allclose(np.unique(values_pred), np.array([0, 1])) or np.allclose(
np.unique(values_pred), np.array([1])
)

predicted_anomalies_ = np.argwhere(values_pred == 1).ravel()
predicted_anomalies_shift_forward = self._shift(predicted_anomalies_, 1, fill_value=predicted_anomalies_[0])
Expand Down
20 changes: 20 additions & 0 deletions tests/test_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,23 @@ def test_precision_function_with_invalid_bias(self):

with self.assertRaises(Exception):
ts_precision(real, pred, bias="Invalid")

def test_precision_function_with_all_zeros(self):
"""Test of ts_precision function with all zero values
"""

real = np.array([0, 0, 0, 0, 0])
pred = np.array([0, 0, 0, 0, 0])

with self.assertRaises(Exception):
ts_precision(real, pred)

def test_precision_function_with_all_ones(self):
"""Test of ts_precision function with all zero values
"""

real = np.array([1, 1, 1, 1, 1])
pred = np.array([1, 1, 1, 1, 1])

self.assertEqual(ts_precision(real, pred), 1.0)

0 comments on commit 0ec80c1

Please sign in to comment.