Skip to content

Commit 3ba1abc

Browse files
authored
feature(STEF-2717): save last valid rolling aggregate during training (#811)
* feature(STEF-2717): save last valid rolling aggregate during training
* add backwards compatibility for RollingAggregatesAdder and SampleWeighter
* remove backwards compatibility for SampleWeighter (other branch)
* remove unused imports
1 parent 28ca5e1 commit 3ba1abc

2 files changed

Lines changed: 193 additions & 43 deletions

File tree

packages/openstef-models/src/openstef_models/transforms/time_domain/rolling_aggregates_adder.py

Lines changed: 77 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,19 @@
22
#
33
# SPDX-License-Identifier: MPL-2.0
44

5-
"""Transform for extracting trend features from time series data.
6-
7-
This module provides functionality to compute trend-based features that capture
8-
long-term patterns and movements in time series data, helping improve forecasting
9-
accuracy by identifying underlying trends.
10-
"""
5+
"""Transform for adding rolling aggregate features to time series data."""
116

127
import logging
138
from datetime import timedelta
14-
from typing import Literal, cast, override
9+
from typing import Any, Literal, override
1510

1611
import pandas as pd
1712
from pydantic import Field, PrivateAttr
1813

1914
from openstef_core.base_model import BaseConfig
2015
from openstef_core.datasets import TimeSeriesDataset
2116
from openstef_core.datasets.validation import validate_required_columns
17+
from openstef_core.exceptions import NotFittedError
2218
from openstef_core.transforms import TimeSeriesTransform
2319
from openstef_core.types import LeadTime
2420
from openstef_core.utils import timedelta_to_isoformat
@@ -29,11 +25,13 @@
2925
class RollingAggregatesAdder(BaseConfig, TimeSeriesTransform):
3026
"""Transform that adds rolling aggregate features to time series data.
3127
32-
This transform computes rolling aggregate statistics (e.g., mean, median, min, max)
28+
Computes rolling aggregate statistics (e.g., mean, median, min, max)
3329
over a specified rolling window and adds these as new features to the dataset.
34-
It is useful for capturing recent trends and patterns in the data.
30+
It is useful for capturing recent trends and patterns in the data. Handles
31+
missing target data during inference via a fallback strategy:
3532
36-
The rolling aggregates are computed on the specified columns of the dataset.
33+
1. Forward-fill from last computed aggregate
34+
2. Use last valid aggregate from training
3735
3836
Example:
3937
>>> import pandas as pd
@@ -55,6 +53,7 @@ class RollingAggregatesAdder(BaseConfig, TimeSeriesTransform):
5553
... aggregation_functions=["mean", "max"],
5654
... horizons=[LeadTime.from_string("PT36H")],
5755
... )
56+
>>> transform.fit(dataset)
5857
>>> transformed_dataset = transform.transform(dataset)
5958
>>> result = transformed_dataset.data[['rolling_mean_load_PT2H', 'rolling_max_load_PT2H']]
6059
>>> print(result.round(1).head(3))
@@ -65,62 +64,97 @@ class RollingAggregatesAdder(BaseConfig, TimeSeriesTransform):
6564
2025-01-01 02:00:00 115.0 120.0
6665
"""
6766

68-
feature: str = Field(
69-
description="Feature to compute rolling aggregates for.",
70-
)
71-
horizons: list[LeadTime] = Field(
72-
description="List of forecast horizons.",
73-
min_length=1,
74-
)
67+
feature: str = Field(description="Feature to compute rolling aggregates for.")
68+
horizons: list[LeadTime] = Field(description="List of forecast horizons.", min_length=1)
7569
rolling_window_size: timedelta = Field(
7670
default=timedelta(hours=24),
7771
description="Rolling window size for the aggregation.",
7872
)
7973
aggregation_functions: list[AggregationFunction] = Field(
8074
default_factory=lambda: ["median", "min", "max"],
81-
description="List of aggregation functions to compute over the rolling window. ",
75+
description="Aggregation functions to compute over the rolling window.",
8276
)
8377

84-
_logger: logging.Logger = PrivateAttr(default=logging.getLogger(__name__))
85-
86-
def _transform_pandas(self, df: pd.DataFrame) -> pd.DataFrame:
87-
rolling_df = cast(
88-
pd.DataFrame,
89-
df[self.feature].dropna().rolling(window=self.rolling_window_size).agg(self.aggregation_functions), # pyright: ignore[reportUnknownMemberType, reportCallIssue, reportArgumentType]
90-
)
91-
# Fill missing values with the last known value
92-
rolling_df = rolling_df.reindex(df.index).ffill()
78+
_logger: logging.Logger = PrivateAttr(default_factory=lambda: logging.getLogger(__name__))
79+
_last_valid_aggregates: dict[str, float] = PrivateAttr(default_factory=dict[str, float])
80+
_is_fitted: bool = PrivateAttr(default=False)
9381

82+
def _make_column_name(self, func: AggregationFunction) -> str:
9483
suffix = timedelta_to_isoformat(td=self.rolling_window_size)
95-
rolling_df = rolling_df.rename(
96-
columns={func: f"rolling_{func}_{self.feature}_{suffix}" for func in self.aggregation_functions}
97-
)
84+
return f"rolling_{func}_{self.feature}_{suffix}"
9885

99-
return pd.concat([df, rolling_df], axis=1)
86+
def _compute_rolling_aggregates(self, series: pd.Series) -> pd.DataFrame:
87+
return series.dropna().rolling(window=self.rolling_window_size).agg(self.aggregation_functions) # type: ignore[return-value]
88+
89+
@override
90+
def fit(self, data: TimeSeriesDataset) -> None:
91+
"""Compute and store last valid aggregates from training data for fallback."""
92+
validate_required_columns(df=data.data, required_columns=[self.feature])
93+
94+
rolling_df = self._compute_rolling_aggregates(data.data[self.feature])
95+
96+
for func in self.aggregation_functions:
97+
valid_rows = rolling_df[func].dropna()
98+
if not valid_rows.empty:
99+
self._last_valid_aggregates[self._make_column_name(func)] = float(valid_rows.iloc[-1])
100+
101+
self._is_fitted = True
100102

101103
@override
102104
def transform(self, data: TimeSeriesDataset) -> TimeSeriesDataset:
103-
if len(self.aggregation_functions) == 0:
104-
self._logger.warning(
105-
"No aggregation functions specified for RollingAggregatesAdder. Returning original data."
106-
)
105+
"""Add rolling aggregate features, using fallbacks for missing values.
106+
107+
Returns:
108+
Dataset with rolling aggregate feature columns added.
109+
110+
Raises:
111+
NotFittedError: If fit() has not been called.
112+
"""
113+
if not self.aggregation_functions:
114+
self._logger.warning("No aggregation functions specified. Returning original data.")
107115
return data
108116

109117
if len(self.horizons) > 1:
110-
self._logger.warning(
111-
"Multiple horizons for RollingAggregatesAdder is not yet supported. Returning original data."
112-
)
118+
self._logger.warning("Multiple horizons not yet supported. Returning original data.")
113119
return data
114120

115121
validate_required_columns(df=data.data, required_columns=[self.feature])
116-
return data.pipe_pandas(self._transform_pandas)
122+
123+
if not self._is_fitted:
124+
raise NotFittedError(self.__class__.__name__)
125+
126+
# Compute rolling aggregates and apply fallback for missing values
127+
result_df = self._compute_and_apply_fallback(data.data)
128+
return data.copy_with(result_df)
129+
130+
def _compute_and_apply_fallback(self, df: pd.DataFrame) -> pd.DataFrame:
131+
rolling_df = self._compute_rolling_aggregates(df[self.feature])
132+
rolling_df = rolling_df.reindex(df.index).ffill()
133+
134+
# Rename columns and apply last valid fallback
135+
column_mapping = {func: self._make_column_name(func) for func in self.aggregation_functions}
136+
rolling_df = rolling_df.rename(columns=column_mapping)
137+
138+
for col in column_mapping.values():
139+
if col in self._last_valid_aggregates:
140+
rolling_df[col] = rolling_df[col].fillna(self._last_valid_aggregates[col]) # pyright: ignore[reportUnknownMemberType]
141+
142+
if rolling_df[col].isna().any():
143+
self._logger.warning("Column '%s' has NaN values after fallback.", col)
144+
145+
return pd.concat([df, rolling_df], axis=1)
117146

118147
@override
119148
def features_added(self) -> list[str]:
120-
return [
121-
f"rolling_{func}_{self.feature}_{timedelta_to_isoformat(self.rolling_window_size)}"
122-
for func in self.aggregation_functions
123-
]
149+
return [self._make_column_name(func) for func in self.aggregation_functions]
150+
151+
@override
152+
def __setstate__(self, state: Any) -> None: # TODO(#799): delete after stable release
153+
if "_last_valid_aggregates" not in state["__pydantic_private__"]:
154+
state["__pydantic_private__"]["_last_valid_aggregates"] = {}
155+
if "_is_fitted" not in state["__pydantic_private__"]:
156+
state["__pydantic_private__"]["_is_fitted"] = True
157+
return super().__setstate__(state)
124158

125159

126160
__all__ = ["RollingAggregatesAdder"]

packages/openstef-models/tests/unit/transforms/time_domain/test_rolling_aggregates_adder.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def test_rolling_aggregate_features_basic():
3131
)
3232

3333
# Act
34+
transform.fit(dataset)
3435
result = transform.transform(dataset)
3536

3637
# Assert
@@ -73,6 +74,7 @@ def test_rolling_aggregate_features_with_nan():
7374
)
7475

7576
# Act
77+
transform.fit(dataset)
7678
result = transform.transform(dataset)
7779

7880
# Assert
@@ -99,10 +101,123 @@ def test_rolling_aggregate_features_missing_column_raises_error():
99101
)
100102

101103
# Act & Assert
104+
with pytest.raises(MissingColumnsError, match="Missing required columns"):
105+
transform.fit(dataset)
106+
102107
with pytest.raises(MissingColumnsError, match="Missing required columns"):
103108
transform.transform(dataset)
104109

105110

111+
def test_rolling_aggregate_features_empty_feature_on_fit():
112+
"""Test that transform applies fallback strategy when feature is fully missing during inference."""
113+
# Arrange
114+
train_data = pd.DataFrame(
115+
{"load": [np.nan, np.nan, np.nan]},
116+
index=pd.date_range("2023-01-01 00:00:00", periods=3, freq="1h"),
117+
)
118+
train_dataset = TimeSeriesDataset(train_data, sample_interval=timedelta(hours=1))
119+
120+
transform = RollingAggregatesAdder(
121+
feature="load",
122+
rolling_window_size=timedelta(hours=2),
123+
aggregation_functions=["mean"],
124+
horizons=[LeadTime.from_string("PT36H")],
125+
)
126+
127+
# Act
128+
transform.fit(train_dataset)
129+
result = transform.transform(train_dataset)
130+
131+
# Assert
132+
assert "rolling_mean_load_PT2H" in result.data.columns
133+
assert result.data["rolling_mean_load_PT2H"].isna().all()
134+
135+
136+
def test_rolling_aggregate_features_partial_missing_during_inference():
137+
"""Test that transform computes fresh aggregates when recent data is available."""
138+
# Arrange - training data
139+
train_data = pd.DataFrame(
140+
{"load": [10.0, 20.0, 30.0]},
141+
index=pd.date_range("2023-01-01 00:00:00", periods=3, freq="1h"),
142+
)
143+
train_dataset = TimeSeriesDataset(train_data, sample_interval=timedelta(hours=1))
144+
145+
# Inference data: some recent values available, then NaN for forecast horizon
146+
test_data = pd.DataFrame(
147+
{"load": [40.0, 50.0, np.nan, np.nan]},
148+
index=pd.date_range("2023-01-01 03:00:00", periods=4, freq="1h"),
149+
)
150+
test_dataset = TimeSeriesDataset(test_data, sample_interval=timedelta(hours=1))
151+
152+
transform = RollingAggregatesAdder(
153+
feature="load",
154+
rolling_window_size=timedelta(hours=2),
155+
aggregation_functions=["mean", "max"],
156+
horizons=[LeadTime.from_string("PT36H")],
157+
)
158+
159+
# Act
160+
transform.fit(train_dataset)
161+
result = transform.transform(test_dataset)
162+
163+
# Assert
164+
assert not result.data["rolling_mean_load_PT2H"].isna().any()
165+
assert not result.data["rolling_max_load_PT2H"].isna().any()
166+
167+
# First row: only 40 in window → mean=40, max=40
168+
assert result.data["rolling_mean_load_PT2H"].iloc[0] == 40.0
169+
assert result.data["rolling_max_load_PT2H"].iloc[0] == 40.0
170+
171+
# Second row: [40, 50] in window → mean=45, max=50
172+
assert result.data["rolling_mean_load_PT2H"].iloc[1] == 45.0
173+
assert result.data["rolling_max_load_PT2H"].iloc[1] == 50.0
174+
175+
# Third and fourth rows: NaN target, forward-fill from last computed
176+
assert result.data["rolling_mean_load_PT2H"].iloc[2] == 45.0
177+
assert result.data["rolling_max_load_PT2H"].iloc[2] == 50.0
178+
assert result.data["rolling_mean_load_PT2H"].iloc[3] == 45.0
179+
assert result.data["rolling_max_load_PT2H"].iloc[3] == 50.0
180+
181+
182+
def test_rolling_aggregate_fallback_uses_last_valid_from_training():
183+
"""Test fallback uses last valid aggregate from training when inference data is all NaN."""
184+
# Arrange
185+
train_data = pd.DataFrame(
186+
{"load": [10.0, 20.0, 30.0, 40.0, 50.0]},
187+
index=pd.date_range("2023-01-01 00:00:00", periods=5, freq="1h"),
188+
)
189+
train_dataset = TimeSeriesDataset(train_data, sample_interval=timedelta(hours=1))
190+
191+
# Inference data with no valid target values
192+
test_data = pd.DataFrame(
193+
{"load": [np.nan, np.nan, np.nan]},
194+
index=pd.date_range("2023-01-01 03:00:00", periods=3, freq="1h"),
195+
)
196+
test_dataset = TimeSeriesDataset(test_data, sample_interval=timedelta(hours=1))
197+
198+
transform = RollingAggregatesAdder(
199+
feature="load",
200+
rolling_window_size=timedelta(hours=2),
201+
aggregation_functions=["mean", "max"],
202+
horizons=[LeadTime.from_string("PT36H")],
203+
)
204+
205+
# Act
206+
transform.fit(train_dataset)
207+
result = transform.transform(test_dataset)
208+
209+
# Assert - all values filled with last valid aggregate from training
210+
# Last valid from training: mean of [40, 50] = 45.0, max = 50.0
211+
assert "rolling_mean_load_PT2H" in result.data.columns
212+
assert "rolling_max_load_PT2H" in result.data.columns
213+
assert not result.data["rolling_mean_load_PT2H"].isna().any()
214+
assert not result.data["rolling_max_load_PT2H"].isna().any()
215+
216+
for i in range(3):
217+
assert result.data["rolling_mean_load_PT2H"].iloc[i] == 45.0
218+
assert result.data["rolling_max_load_PT2H"].iloc[i] == 50.0
219+
220+
106221
def test_rolling_aggregate_features_default_parameters():
107222
"""Test transform works with default parameters."""
108223
# Arrange
@@ -118,6 +233,7 @@ def test_rolling_aggregate_features_default_parameters():
118233
)
119234

120235
# Act
236+
transform.fit(dataset)
121237
result = transform.transform(dataset)
122238

123239
# Assert - default is 24-hour window with median, min, max

0 commit comments

Comments (0)