Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add resampling to PandasDataset. #2882

Draft
wants to merge 2 commits into
base: dev
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 42 additions & 78 deletions src/gluonts/dataset/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import numpy as np
import pandas as pd
from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin
from toolz import first

from gluonts.maybe import Maybe
from gluonts.dataset.common import DataEntry
Expand All @@ -29,6 +28,30 @@
logger = logging.getLogger(__name__)


def norm_dataframe(df, timestamp=None, freq=None, agg=np.sum):
    """
    Normalize a dataframe to be indexed by a uniform ``pd.PeriodIndex``.

    Parameters
    ----------
    df
        ``pd.DataFrame`` to normalize. Note: the index is replaced in
        place when ``timestamp`` is given or the index is not datetime-like.
    timestamp
        Optional name of the column holding timestamps; when given it
        becomes the (period) index of the dataframe.
    freq
        Target frequency. When ``None``, it is taken from ``df.index.freq``.
    agg
        Aggregation applied when resampling a non-uniform index.
        (Default: ``np.sum``)

    Returns
    -------
    pd.DataFrame
        Dataframe with a uniform ``pd.PeriodIndex`` of frequency ``freq``.

    Raises
    ------
    ValueError
        If ``freq`` is not given and cannot be obtained from the index.
    """
    if freq is None:
        # The index may lack a ``freq`` attribute entirely (e.g. a plain
        # RangeIndex), or have it set to None (a DatetimeIndex without a
        # fixed frequency); both cases mean we cannot proceed.
        freq = getattr(df.index, "freq", None)
        if freq is None:
            raise ValueError(
                "DataFrame index has no ``freq`` and no ``freq`` was passed."
            )

    if timestamp is not None:
        df.index = pd.DatetimeIndex(df[timestamp]).to_period(freq=freq)

    elif not isinstance(df.index, DatetimeIndexOpsMixin):
        # Index is not datetime-like yet; parse it into timestamps.
        df.index = pd.to_datetime(df.index)

    if not isinstance(df.index, pd.PeriodIndex):
        df = df.to_period(freq)

    if not is_uniform(df.index):
        # Fill gaps / merge duplicate periods so the index is evenly spaced.
        df = df.resample(freq).agg(agg)

    return df


@dataclass
class PandasDataset:
"""
Expand Down Expand Up @@ -66,12 +89,6 @@ class PandasDataset:
future_length
For target and past dynamic features last ``future_length``
elements are removed when iterating over the data set.
unchecked
Whether consistency checks on indexes should be skipped.
(Default: ``False``)
assume_sorted
Whether to assume that indexes are sorted by time, and skip sorting.
(Default: ``False``)
"""

dataframes: InitVar[
Expand All @@ -93,8 +110,6 @@ class PandasDataset:
freq: Optional[str] = None
static_features: InitVar[Optional[pd.DataFrame]] = None
future_length: int = 0
unchecked: bool = False
assume_sorted: bool = False
dtype: Type = np.float32
_data_entries: SizedIterable = field(init=False)
_static_reals: pd.DataFrame = field(init=False)
Expand All @@ -109,15 +124,6 @@ def __post_init__(self, dataframes, static_features):
assert isinstance(dataframes, SizedIterable)
pairs = Map(pair_with_item_id, dataframes)

self._data_entries = StarMap(self._pair_to_dataentry, pairs)

if self.freq is None:
assert (
self.timestamp is None
), "You need to provide `freq` along with `timestamp`"

self.freq = infer_freq(first(pairs)[1].index)

static_features = Maybe(static_features).unwrap_or_else(pd.DataFrame)

object_columns = static_features.select_dtypes(
Expand All @@ -141,6 +147,8 @@ def __post_init__(self, dataframes, static_features):
.T
)

self._data_entries = list(StarMap(self._pair_to_dataentry, pairs))

@property
def num_feat_static_cat(self) -> int:
return len(self._static_cats)
Expand All @@ -159,53 +167,36 @@ def num_past_feat_dynamic_real(self) -> int:

@property
def static_cardinalities(self):
return self._static_cats.max(axis=1).values + 1
return self._static_cats.max(axis=1).to_numpy() + 1

def _pair_to_dataentry(self, item_id, df) -> DataEntry:
if isinstance(df, pd.Series):
df = df.to_frame(name=self.target)

if self.timestamp:
df.index = pd.DatetimeIndex(df[self.timestamp]).to_period(
freq=self.freq
)

if not isinstance(df.index, pd.PeriodIndex):
df = df.to_period(freq=self.freq)
df = norm_dataframe(df, freq=self.freq, timestamp=self.timestamp)

if not self.assume_sorted:
df.sort_index(inplace=True)
entry = {"start": df.index[0]}

if not self.unchecked:
assert is_uniform(df.index), (
"Dataframe index is not uniformly spaced. "
"If your dataframe contains data from multiple series in the "
'same column ("long" format), consider constructing the '
"dataset with `PandasDataset.from_long_dataframe` instead."
)

entry = {
"start": df.index[0],
}

target = df[self.target].values
target = df[self.target].to_numpy()
target = target[: len(target) - self.future_length]
entry["target"] = target.T

if item_id is not None:
entry["item_id"] = item_id

if self.num_feat_static_cat > 0:
entry["feat_static_cat"] = self._static_cats[item_id].values
entry["feat_static_cat"] = self._static_cats[item_id].to_numpy()

if self.num_feat_static_real > 0:
entry["feat_static_real"] = self._static_reals[item_id].values
entry["feat_static_real"] = self._static_reals[item_id].to_numpy()

if self.num_feat_dynamic_real > 0:
entry["feat_dynamic_real"] = df[self.feat_dynamic_real].values.T
entry["feat_dynamic_real"] = (
df[self.feat_dynamic_real].to_numpy().T
)

if self.num_past_feat_dynamic_real > 0:
past_feat_dynamic_real = df[self.past_feat_dynamic_real].values
past_feat_dynamic_real = df[self.past_feat_dynamic_real].to_numpy()
past_feat_dynamic_real = past_feat_dynamic_real[
: len(past_feat_dynamic_real) - self.future_length
]
Expand All @@ -215,7 +206,6 @@ def _pair_to_dataentry(self, item_id, df) -> DataEntry:

def __iter__(self):
yield from self._data_entries
self.unchecked = True

def __len__(self) -> int:
return len(self._data_entries)
Expand All @@ -239,7 +229,6 @@ def from_long_dataframe(
cls,
dataframe: pd.DataFrame,
item_id: str,
timestamp: Optional[str] = None,
static_feature_columns: Optional[list[str]] = None,
static_features: pd.DataFrame = pd.DataFrame(),
**kwargs,
Expand Down Expand Up @@ -281,34 +270,21 @@ def from_long_dataframe(
PandasDataset
Dataset containing series data from the given long dataframe.
"""
if timestamp is not None:
logger.info(f"Indexing data by '{timestamp}'.")
dataframe.index = pd.to_datetime(dataframe[timestamp])

if not isinstance(dataframe.index, DatetimeIndexOpsMixin):
logger.info("Converting index into DatetimeIndex.")
dataframe.index = pd.to_datetime(dataframe.index)
other_static_features = pd.DataFrame()

if static_feature_columns is not None:
logger.info(
f"Collecting features from columns {static_feature_columns}."
)
other_static_features = (
dataframe[[item_id] + static_feature_columns]
dataframe[[item_id, *static_feature_columns]]
.drop_duplicates()
.set_index(item_id)
)
assert len(other_static_features) == len(
dataframe[item_id].unique()
)
else:
other_static_features = pd.DataFrame()

logger.info(f"Grouping data by '{item_id}'; this may take some time.")
pairs = list(dataframe.groupby(item_id))

return cls(
dataframes=pairs,
dataframes=dataframe.groupby(item_id),
static_features=pd.concat(
[static_features, other_static_features], axis=1
),
Expand All @@ -319,23 +295,11 @@ def from_long_dataframe(
def pair_with_item_id(obj: Union[tuple, pd.DataFrame, pd.Series]):
    """
    Coerce ``obj`` into an ``(item_id, data)`` pair.

    A 2-tuple is assumed to already be such a pair and is returned as-is;
    a bare ``DataFrame`` or ``Series`` is paired with ``None`` as item id.
    Anything else is rejected.
    """
    already_paired = isinstance(obj, tuple) and len(obj) == 2
    if already_paired:
        return obj

    if not isinstance(obj, (pd.DataFrame, pd.Series)):
        raise ValueError(
            "input must be a pair, or a pandas Series or DataFrame."
        )

    return (None, obj)


def infer_freq(index: pd.Index) -> str:
    """
    Infer the frequency string of a time index.

    Parameters
    ----------
    index
        A ``pd.PeriodIndex`` (whose declared frequency is returned) or a
        datetime-like index whose frequency is inferred from the values.

    Returns
    -------
    str
        The frequency string, with any trailing "start of period" ``S``
        stripped (e.g. ``MS`` -> ``M``).

    Raises
    ------
    ValueError
        If the frequency cannot be inferred from the index values.
    """
    if isinstance(index, pd.PeriodIndex):
        return index.freqstr

    freq = pd.infer_freq(index)
    if freq is None:
        # ``pd.infer_freq`` returns None when no frequency fits the
        # values; fail with a clear message instead of crashing on
        # ``len(None)`` below.
        raise ValueError("Could not infer frequency from index values.")

    # pandas likes to infer the `start of x` frequency, however when doing
    # df.to_period("<x>S"), it fails, so we avoid using it. It's enough to
    # remove the trailing S, e.g `MS` -> `M`
    if len(freq) > 1 and freq.endswith("S"):
        return freq[:-1]

    return freq
raise ValueError("input must be a pair, or a pandas Series or DataFrame.")


def is_uniform(index: pd.PeriodIndex) -> bool:
Expand All @@ -350,5 +314,5 @@ def is_uniform(index: pd.PeriodIndex) -> bool:
>>> is_uniform(pd.DatetimeIndex(ts).to_period("2H"))
False
"""

# Note: ``np.all(np.diff(df.index) == df.index.freq)`` is ~1000x slower.
return cast(bool, np.all(np.diff(index.asi8) == index.freq.n))