From a645adf6f2a38dc9c7503f982f824020292ae098 Mon Sep 17 00:00:00 2001 From: Jasper Zschiegner Date: Tue, 23 May 2023 12:58:31 +0200 Subject: [PATCH 1/2] Add resampling to PandasDataset. --- src/gluonts/dataset/pandas.py | 100 ++++++++++++---------------------- 1 file changed, 34 insertions(+), 66 deletions(-) diff --git a/src/gluonts/dataset/pandas.py b/src/gluonts/dataset/pandas.py index 61202022ba..68d625bd60 100644 --- a/src/gluonts/dataset/pandas.py +++ b/src/gluonts/dataset/pandas.py @@ -29,6 +29,27 @@ logger = logging.getLogger(__name__) +def norm_dataframe(df, timestamp=None, freq=None, agg=np.sum): + if freq is None: + if not hasattr(df.index, "freq"): + raise ValueError( + "DataFrame index has no ``freq`` and not ``freq`` was passed." + ) + + freq = df.index.freq + + if timestamp is not None: + df.index = pd.DatetimeIndex(df[timestamp]).to_period(freq=freq) + + elif not isinstance(df.index, pd.PeriodIndex): + df = df.to_period(freq) + + if not is_uniform(df.index): + df = df.resample(freq).agg(agg) + + return df + + @dataclass class PandasDataset: """ @@ -66,12 +87,6 @@ class PandasDataset: future_length For target and past dynamic features last ``future_length`` elements are removed when iterating over the data set. - unchecked - Whether consistency checks on indexes should be skipped. - (Default: ``False``) - assume_sorted - Whether to assume that indexes are sorted by time, and skip sorting. - (Default: ``False``) """ dataframes: InitVar[ @@ -93,8 +108,6 @@ class PandasDataset: freq: Optional[str] = None static_features: InitVar[Optional[pd.DataFrame]] = None future_length: int = 0 - unchecked: bool = False - assume_sorted: bool = False dtype: Type = np.float32 _data_entries: SizedIterable = field(init=False) _static_reals: pd.DataFrame = field(init=False) @@ -111,13 +124,6 @@ def __post_init__(self, dataframes, static_features): self._data_entries = StarMap(self._pair_to_dataentry, pairs) - if self.freq is None: - assert ( - self.timestamp is None - ), "You need to provide `freq` along with `timestamp`" - - self.freq = infer_freq(first(pairs)[1].index) - static_features = Maybe(static_features).unwrap_or_else(pd.DataFrame) object_columns = static_features.select_dtypes( @@ -159,36 +165,17 @@ def num_past_feat_dynamic_real(self) -> int: @property def static_cardinalities(self): - return self._static_cats.max(axis=1).values + 1 + return self._static_cats.max(axis=1).to_numpy() + 1 def _pair_to_dataentry(self, item_id, df) -> DataEntry: if isinstance(df, pd.Series): df = df.to_frame(name=self.target) - if self.timestamp: - df.index = pd.DatetimeIndex(df[self.timestamp]).to_period( - freq=self.freq - ) - - if not isinstance(df.index, pd.PeriodIndex): - df = df.to_period(freq=self.freq) + df = norm_dataframe(df, freq=self.freq, timestamp=self.timestamp) - if not self.assume_sorted: - df.sort_index(inplace=True) + entry = {"start": df.index[0]} - if not self.unchecked: - assert is_uniform(df.index), ( - "Dataframe index is not uniformly spaced. " - "If your dataframe contains data from multiple series in the " - 'same column ("long" format), consider constructing the ' - "dataset with `PandasDataset.from_long_dataframe` instead." - ) - - entry = { - "start": df.index[0], - } - - target = df[self.target].values + target = df[self.target].to_numpy() target = target[: len(target) - self.future_length] entry["target"] = target.T @@ -196,16 +183,18 @@ def _pair_to_dataentry(self, item_id, df) -> DataEntry: entry["item_id"] = item_id if self.num_feat_static_cat > 0: - entry["feat_static_cat"] = self._static_cats[item_id].values + entry["feat_static_cat"] = self._static_cats[item_id].to_numpy() if self.num_feat_static_real > 0: - entry["feat_static_real"] = self._static_reals[item_id].values + entry["feat_static_real"] = self._static_reals[item_id].to_numpy() if self.num_feat_dynamic_real > 0: - entry["feat_dynamic_real"] = df[self.feat_dynamic_real].values.T + entry["feat_dynamic_real"] = ( + df[self.feat_dynamic_real].to_numpy().T + ) if self.num_past_feat_dynamic_real > 0: - past_feat_dynamic_real = df[self.past_feat_dynamic_real].values + past_feat_dynamic_real = df[self.past_feat_dynamic_real].to_numpy() past_feat_dynamic_real = past_feat_dynamic_real[ : len(past_feat_dynamic_real) - self.future_length ] @@ -215,7 +204,6 @@ def _pair_to_dataentry(self, item_id, df) -> DataEntry: def __iter__(self): yield from self._data_entries - self.unchecked = True def __len__(self) -> int: return len(self._data_entries) @@ -282,19 +270,13 @@ def from_long_dataframe( Dataset containing series data from the given long dataframe. """ if timestamp is not None: - logger.info(f"Indexing data by '{timestamp}'.") dataframe.index = pd.to_datetime(dataframe[timestamp]) - - if not isinstance(dataframe.index, DatetimeIndexOpsMixin): - logger.info("Converting index into DatetimeIndex.") + elif not isinstance(dataframe.index, DatetimeIndexOpsMixin): dataframe.index = pd.to_datetime(dataframe.index) if static_feature_columns is not None: - logger.info( - f"Collecting features from columns {static_feature_columns}." - ) other_static_features = ( - dataframe[[item_id] + static_feature_columns] + dataframe[[item_id, *static_feature_columns]] .drop_duplicates() .set_index(item_id) ) @@ -324,20 +306,6 @@ def pair_with_item_id(obj: Union[tuple, pd.DataFrame, pd.Series]): raise ValueError("input must be a pair, or a pandas Series or DataFrame.") -def infer_freq(index: pd.Index) -> str: - if isinstance(index, pd.PeriodIndex): - return index.freqstr - - freq = pd.infer_freq(index) - # pandas likes to infer the `start of x` frequency, however when doing - # df.to_period("S"), it fails, so we avoid using it. It's enough to - # remove the trailing S, e.g `MS` -> `M - if len(freq) > 1 and freq.endswith("S"): - return freq[:-1] - - return freq - - def is_uniform(index: pd.PeriodIndex) -> bool: """ Check if ``index`` contains monotonically increasing periods, evenly spaced @@ -350,5 +318,5 @@ def is_uniform(index: pd.PeriodIndex) -> bool: >>> is_uniform(pd.DatetimeIndex(ts).to_period("2H")) False """ - + # Note: ``np.all(np.diff(df.index) == df.index.freq)`` is ~1000x slower. return cast(bool, np.all(np.diff(index.asi8) == index.freq.n)) From ae91b8f42fa9f592ecb278894d90dc5b96882822 Mon Sep 17 00:00:00 2001 From: Jasper Zschiegner Date: Tue, 23 May 2023 15:30:49 +0200 Subject: [PATCH 2/2] Simplify from_long_dataframe. --- src/gluonts/dataset/pandas.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/gluonts/dataset/pandas.py b/src/gluonts/dataset/pandas.py index 68d625bd60..5004c1f454 100644 --- a/src/gluonts/dataset/pandas.py +++ b/src/gluonts/dataset/pandas.py @@ -20,7 +20,6 @@ import numpy as np import pandas as pd from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin -from toolz import first from gluonts.maybe import Maybe from gluonts.dataset.common import DataEntry @@ -41,7 +40,10 @@ def norm_dataframe(df, timestamp=None, freq=None, agg=np.sum): if timestamp is not None: df.index = pd.DatetimeIndex(df[timestamp]).to_period(freq=freq) - elif not isinstance(df.index, pd.PeriodIndex): + elif not isinstance(df.index, DatetimeIndexOpsMixin): + df.index = pd.to_datetime(df.index) + + if not isinstance(df.index, pd.PeriodIndex): df = df.to_period(freq) if not is_uniform(df.index): @@ -122,8 +124,6 @@ def __post_init__(self, dataframes, static_features): assert isinstance(dataframes, SizedIterable) pairs = Map(pair_with_item_id, dataframes) - self._data_entries = StarMap(self._pair_to_dataentry, pairs) - static_features = Maybe(static_features).unwrap_or_else(pd.DataFrame) object_columns = static_features.select_dtypes( @@ -147,6 +147,8 @@ def __post_init__(self, dataframes, static_features): .T ) + self._data_entries = list(StarMap(self._pair_to_dataentry, pairs)) + @property def num_feat_static_cat(self) -> int: return len(self._static_cats) @@ -227,7 +229,6 @@ def from_long_dataframe( cls, dataframe: pd.DataFrame, item_id: str, - timestamp: Optional[str] = None, static_feature_columns: Optional[list[str]] = None, static_features: pd.DataFrame = pd.DataFrame(), **kwargs, @@ -269,10 +270,8 @@ def from_long_dataframe( PandasDataset Dataset containing series data from the given long dataframe. """ - if timestamp is not None: - dataframe.index = pd.to_datetime(dataframe[timestamp]) - elif not isinstance(dataframe.index, DatetimeIndexOpsMixin): - dataframe.index = pd.to_datetime(dataframe.index) + + other_static_features = pd.DataFrame() if static_feature_columns is not None: other_static_features = ( @@ -283,14 +282,9 @@ def from_long_dataframe( assert len(other_static_features) == len( dataframe[item_id].unique() ) - else: - other_static_features = pd.DataFrame() - - logger.info(f"Grouping data by '{item_id}'; this may take some time.") - pairs = list(dataframe.groupby(item_id)) return cls( - dataframes=pairs, + dataframes=dataframe.groupby(item_id), static_features=pd.concat( [static_features, other_static_features], axis=1 ), @@ -301,8 +295,10 @@ def from_long_dataframe( def pair_with_item_id(obj: Union[tuple, pd.DataFrame, pd.Series]): if isinstance(obj, tuple) and len(obj) == 2: return obj + if isinstance(obj, (pd.DataFrame, pd.Series)): return (None, obj) + raise ValueError("input must be a pair, or a pandas Series or DataFrame.")