22#
33# SPDX-License-Identifier: MPL-2.0
44
5- """Transform for extracting trend features from time series data.
6-
7- This module provides functionality to compute trend-based features that capture
8- long-term patterns and movements in time series data, helping improve forecasting
9- accuracy by identifying underlying trends.
10- """
5+ """Transform for adding rolling aggregate features to time series data."""
116
127import logging
138from datetime import timedelta
14- from typing import Literal , cast , override
9+ from typing import Any , Literal , override
1510
1611import pandas as pd
1712from pydantic import Field , PrivateAttr
1813
1914from openstef_core .base_model import BaseConfig
2015from openstef_core .datasets import TimeSeriesDataset
2116from openstef_core .datasets .validation import validate_required_columns
17+ from openstef_core .exceptions import NotFittedError
2218from openstef_core .transforms import TimeSeriesTransform
2319from openstef_core .types import LeadTime
2420from openstef_core .utils import timedelta_to_isoformat
2925class RollingAggregatesAdder (BaseConfig , TimeSeriesTransform ):
3026 """Transform that adds rolling aggregate features to time series data.
3127
32- This transform computes rolling aggregate statistics (e.g., mean, median, min, max)
28+ Computes rolling aggregate statistics (e.g., mean, median, min, max)
3329 over a specified rolling window and adds these as new features to the dataset.
34- It is useful for capturing recent trends and patterns in the data.
30+ It is useful for capturing recent trends and patterns in the data. Handles
31+ missing target data during inference via a fallback strategy:
3532
36- The rolling aggregates are computed on the specified columns of the dataset.
33+ 1. Forward-fill from last computed aggregate
34+ 2. Use last valid aggregate from training
3735
3836 Example:
3937 >>> import pandas as pd
@@ -55,6 +53,7 @@ class RollingAggregatesAdder(BaseConfig, TimeSeriesTransform):
5553 ... aggregation_functions=["mean", "max"],
5654 ... horizons=[LeadTime.from_string("PT36H")],
5755 ... )
56+ >>> transform.fit(dataset)
5857 >>> transformed_dataset = transform.transform(dataset)
5958 >>> result = transformed_dataset.data[['rolling_mean_load_PT2H', 'rolling_max_load_PT2H']]
6059 >>> print(result.round(1).head(3))
@@ -65,62 +64,97 @@ class RollingAggregatesAdder(BaseConfig, TimeSeriesTransform):
6564 2025-01-01 02:00:00 115.0 120.0
6665 """
6766
68- feature : str = Field (
69- description = "Feature to compute rolling aggregates for." ,
70- )
71- horizons : list [LeadTime ] = Field (
72- description = "List of forecast horizons." ,
73- min_length = 1 ,
74- )
67+ feature : str = Field (description = "Feature to compute rolling aggregates for." )
68+ horizons : list [LeadTime ] = Field (description = "List of forecast horizons." , min_length = 1 )
7569 rolling_window_size : timedelta = Field (
7670 default = timedelta (hours = 24 ),
7771 description = "Rolling window size for the aggregation." ,
7872 )
7973 aggregation_functions : list [AggregationFunction ] = Field (
8074 default_factory = lambda : ["median" , "min" , "max" ],
81- description = "List of aggregation functions to compute over the rolling window. " ,
75+ description = "Aggregation functions to compute over the rolling window." ,
8276 )
8377
84- _logger : logging .Logger = PrivateAttr (default = logging .getLogger (__name__ ))
85-
86- def _transform_pandas (self , df : pd .DataFrame ) -> pd .DataFrame :
87- rolling_df = cast (
88- pd .DataFrame ,
89- df [self .feature ].dropna ().rolling (window = self .rolling_window_size ).agg (self .aggregation_functions ), # pyright: ignore[reportUnknownMemberType, reportCallIssue, reportArgumentType]
90- )
91- # Fill missing values with the last known value
92- rolling_df = rolling_df .reindex (df .index ).ffill ()
78+ _logger : logging .Logger = PrivateAttr (default_factory = lambda : logging .getLogger (__name__ ))
79+ _last_valid_aggregates : dict [str , float ] = PrivateAttr (default_factory = dict [str , float ])
80+ _is_fitted : bool = PrivateAttr (default = False )
9381
def _make_column_name(self, func: AggregationFunction) -> str:
    """Build the output column name for one aggregation function.

    The name has the form ``rolling_<func>_<feature>_<window>``, where
    ``<window>`` is the configured rolling window size rendered by
    ``timedelta_to_isoformat``.

    Args:
        func: Aggregation function the column holds values for.

    Returns:
        Name of the feature column for ``func``.
    """
    return f"rolling_{func}_{self.feature}_{timedelta_to_isoformat(td=self.rolling_window_size)}"
9885
99- return pd .concat ([df , rolling_df ], axis = 1 )
86+ def _compute_rolling_aggregates (self , series : pd .Series ) -> pd .DataFrame :
87+ return series .dropna ().rolling (window = self .rolling_window_size ).agg (self .aggregation_functions ) # type: ignore[return-value]
88+
@override
def fit(self, data: TimeSeriesDataset) -> None:
    """Compute and store the last valid aggregates from training data for fallback.

    For every configured aggregation function, the last non-missing rolling
    aggregate found in ``data`` is stored under its output column name. These
    stored values are used by ``transform`` to fill aggregate values that are
    still missing after forward-filling.

    Each call replaces the previously stored fallback values, so a re-fit never
    keeps values from an earlier fit.

    Args:
        data: Training dataset; must contain the configured feature column.
    """
    validate_required_columns(df=data.data, required_columns=[self.feature])

    # Build into a fresh mapping so values from an earlier fit cannot leak through.
    last_valid: dict[str, float] = {}

    # Mirror the guard in transform(): with no aggregation functions there is
    # nothing to compute, and an empty function list must not reach rolling().agg.
    if self.aggregation_functions:
        rolling_df = self._compute_rolling_aggregates(data.data[self.feature])

        for func in self.aggregation_functions:
            valid_rows = rolling_df[func].dropna()
            if not valid_rows.empty:
                last_valid[self._make_column_name(func)] = float(valid_rows.iloc[-1])

    self._last_valid_aggregates = last_valid
    self._is_fitted = True
100102
@override
def transform(self, data: TimeSeriesDataset) -> TimeSeriesDataset:
    """Add rolling aggregate features, using fallbacks for missing values.

    The input dataset is returned unchanged, after a warning is logged, when
    no aggregation functions are configured or when more than one horizon is
    configured.

    Args:
        data: Dataset containing the configured feature column.

    Returns:
        Dataset with rolling aggregate feature columns added.

    Raises:
        NotFittedError: If the transform has not been fitted.
    """
    # Configurations that cannot be handled are skipped, not treated as errors.
    skip_reason: str | None = None
    if not self.aggregation_functions:
        skip_reason = "No aggregation functions specified. Returning original data."
    elif len(self.horizons) > 1:
        skip_reason = "Multiple horizons not yet supported. Returning original data."

    if skip_reason is not None:
        self._logger.warning(skip_reason)
        return data

    validate_required_columns(df=data.data, required_columns=[self.feature])

    if not self._is_fitted:
        raise NotFittedError(self.__class__.__name__)

    # Compute rolling aggregates and fill remaining gaps from the fitted fallback values.
    with_aggregates = self._compute_and_apply_fallback(data.data)
    return data.copy_with(with_aggregates)
129+
130+ def _compute_and_apply_fallback (self , df : pd .DataFrame ) -> pd .DataFrame :
131+ rolling_df = self ._compute_rolling_aggregates (df [self .feature ])
132+ rolling_df = rolling_df .reindex (df .index ).ffill ()
133+
134+ # Rename columns and apply last valid fallback
135+ column_mapping = {func : self ._make_column_name (func ) for func in self .aggregation_functions }
136+ rolling_df = rolling_df .rename (columns = column_mapping )
137+
138+ for col in column_mapping .values ():
139+ if col in self ._last_valid_aggregates :
140+ rolling_df [col ] = rolling_df [col ].fillna (self ._last_valid_aggregates [col ]) # pyright: ignore[reportUnknownMemberType]
141+
142+ if rolling_df [col ].isna ().any ():
143+ self ._logger .warning ("Column '%s' has NaN values after fallback." , col )
144+
145+ return pd .concat ([df , rolling_df ], axis = 1 )
117146
118147 @override
119148 def features_added (self ) -> list [str ]:
120- return [
121- f"rolling_{ func } _{ self .feature } _{ timedelta_to_isoformat (self .rolling_window_size )} "
122- for func in self .aggregation_functions
123- ]
149+ return [self ._make_column_name (func ) for func in self .aggregation_functions ]
150+
@override
def __setstate__(self, state: Any) -> None:  # TODO(#799): delete after stable release
    """Restore pickled state, supplying private attributes that the state lacks.

    State that has no ``_last_valid_aggregates`` entry gets an empty mapping,
    and state that has no ``_is_fitted`` entry is marked as fitted, before the
    state is handed to the parent implementation.

    Args:
        state: Pickled state; its ``"__pydantic_private__"`` entry is updated in place.
    """
    private_attrs = state["__pydantic_private__"]
    private_attrs.setdefault("_last_valid_aggregates", {})
    private_attrs.setdefault("_is_fitted", True)
    return super().__setstate__(state)
124158
125159
126160__all__ = ["RollingAggregatesAdder" ]
0 commit comments