Skip to content

Commit

Permalink
fix assert
Browse files Browse the repository at this point in the history
  • Loading branch information
ryoherisson committed Feb 17, 2021
1 parent 0bb9ef6 commit 4047c96
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions prts/base/time_series_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,12 @@ def _shift(self, arr, num, fill_value=np.nan):
def _prepare_data(self, values_real, values_pred):

assert len(values_real) == len(values_pred)
assert np.allclose(np.unique(values_real), np.array([0, 1])) or np.allclose(np.unique(values_real), np.array([1]))
assert np.allclose(np.unique(values_pred), np.array([0, 1])) or np.allclose(np.unique(values_pred), np.array([1]))
assert np.allclose(np.unique(values_real), np.array([0, 1])) or np.allclose(
np.unique(values_real), np.array([1])
)
assert np.allclose(np.unique(values_pred), np.array([0, 1])) or np.allclose(
np.unique(values_pred), np.array([1])
)

predicted_anomalies_ = np.argwhere(values_pred == 1).ravel()
predicted_anomalies_shift_forward = self._shift(predicted_anomalies_, 1, fill_value=predicted_anomalies_[0])
Expand Down

0 comments on commit 4047c96

Please sign in to comment.