From 79f58fa2811d139ae97e74bc713a1fecb68246ce Mon Sep 17 00:00:00 2001 From: Iskander Gaba Date: Tue, 26 Dec 2023 20:21:31 +0100 Subject: [PATCH] Fixed handling acf_arr so that incorrect periods that are adjacent to correct ones are no longer considered just because the latter's ACF coefficient was set to -1 after verification --- auto_period_finder/finder.py | 19 ++++++++++--------- tests/test_finder.py | 4 ++-- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/auto_period_finder/finder.py b/auto_period_finder/finder.py index 38dbc51..7b534e5 100644 --- a/auto_period_finder/finder.py +++ b/auto_period_finder/finder.py @@ -175,17 +175,18 @@ def __find_periods( acf_kwargs: Dict[str, Union[int, bool, None]], ) -> list: periods = [] - acf_array = np.array(acf(y, nlags=len(y), **acf_kwargs)) + acf_arr = np.array(acf(y, nlags=len(y), **acf_kwargs)) + acf_arr_work = acf_arr.copy() # Eliminate the trivial seasonality period of 1 - acf_array[0] = -1 + acf_arr_work[0] = -1 while True: # i is a period candidate: It cannot be greater than half the timeseries length - i = acf_array[: acf_array.size // 2].argmax() + i = acf_arr_work[: acf_arr_work.size // 2].argmax() # No more periods left or the maximum number of periods has been found - if acf_array[i] == -1 or ( + if acf_arr_work[i] == -1 or ( max_period_count is not None and len(periods) == max_period_count ): return periods @@ -193,19 +194,19 @@ def __find_periods( # Check that i and all of its multiples are local maxima elif all( [ - acf_array[i * j - 1] < acf_array[i * j] - and acf_array[i * j] > acf_array[i * j + 1] - for j in range(1, len(acf_array) // i - 1) + acf_arr[i * j - 1] < acf_arr[i * j] + and acf_arr[i * j] > acf_arr[i * j + 1] + for j in range(1, len(acf_arr) // i - 1) ] ): # Add to period return list periods.append(i) # Ignore i and its multiplies - acf_array[[i * j for j in range(1, len(acf_array) // i)]] = -1 + acf_arr_work[[i * j for j in range(1, len(acf_arr_work) // i)]] = -1 # Not a period, ignore it else: - acf_array[i] = -1 + acf_arr_work[i] = -1 @staticmethod def __seasonality_strength(seasonal, resid): diff --git a/tests/test_finder.py b/tests/test_finder.py index 4e9dc64..d317a27 100644 --- a/tests/test_finder.py +++ b/tests/test_finder.py @@ -41,11 +41,11 @@ def test_find_strongest_period_var_wise_stl_custom(): strongest_period_var = period_finder.fit_find_strongest_var( decomposer=Decomposer.STL, decomposer_kwargs={"seasonal_deg": 0} ) - assert strongest_period_var == 132 + assert strongest_period_var == 180 def test_find_strongest_period_var_wise_moving_averages(): data = co2.load().data.resample("M").mean().ffill() period_finder = AutoPeriodFinder(data) strongest_period_var = period_finder.fit_find_strongest_var() - assert strongest_period_var == 132 + assert strongest_period_var == 176