diff --git a/pytorch_forecasting/data/timeseries/_timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py index 043a51d1e..4895cdecb 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries.py +++ b/pytorch_forecasting/data/timeseries/_timeseries.py @@ -1860,7 +1860,7 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra ) assert len(df_index) > 0, msg - return df_index + return df_index.to_records(index=True) def filter(self, filter_func: Callable, copy: bool = True) -> TimeSeriesDataType: """Filter subsequences in dataset. @@ -1909,8 +1909,8 @@ def decoded_index(self) -> pd.DataFrame: pd.DataFrame: index that can be understood in terms of original data """ # get dataframe to filter - index_start = self.index["index_start"].to_numpy() - index_last = self.index["index_end"].to_numpy() + index_start = self.index["index_start"] + index_last = self.index["index_end"] index = ( # get group ids in order of index pd.DataFrame( @@ -2091,16 +2091,20 @@ def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]: Returns: tuple[dict[str, torch.Tensor], torch.Tensor]: x and y for model """ - index = self.index.iloc[idx] + index = self.index[idx] + # get index data + index_start = index.index_start + index_end = index.index_end + index_sequence_length = index.sequence_length # slice data based on index - idx_slice = slice(index.index_start, index.index_end + 1) + idx_slice = slice(index_start, index_end + 1) data_cont = self.data["reals"][idx_slice].clone() data_cat = self.data["categoricals"][idx_slice].clone() time = self.data["time"][idx_slice].clone() target = [d[idx_slice].clone() for d in self.data["target"]] - groups = self.data["groups"][index.index_start].clone() + groups = self.data["groups"][index_start].clone() if self.data["weight"] is None: weight = None else: @@ -2112,7 +2116,7 @@ def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]: # fill in missing values (if not all time indices are specified) sequence_length = len(time) - if sequence_length < index.sequence_length: + if sequence_length < index_sequence_length: assert ( self.allow_missing_timesteps ), "allow_missing_timesteps should be True if sequences have gaps"