Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions pytorch_forecasting/data/timeseries/_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1860,7 +1860,7 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra
)
assert len(df_index) > 0, msg

return df_index
return df_index.to_records(index=True)

def filter(self, filter_func: Callable, copy: bool = True) -> TimeSeriesDataType:
"""Filter subsequences in dataset.
Expand Down Expand Up @@ -1909,8 +1909,8 @@ def decoded_index(self) -> pd.DataFrame:
pd.DataFrame: index that can be understood in terms of original data
"""
# get dataframe to filter
index_start = self.index["index_start"].to_numpy()
index_last = self.index["index_end"].to_numpy()
index_start = self.index["index_start"]
index_last = self.index["index_end"]
index = (
# get group ids in order of index
pd.DataFrame(
Expand Down Expand Up @@ -2091,16 +2091,20 @@ def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
Returns:
tuple[dict[str, torch.Tensor], torch.Tensor]: x and y for model
"""
index = self.index.iloc[idx]
index = self.index[idx]

# get index data
index_start = index.index_start
index_end = index.index_end
index_sequence_length = index.sequence_length
# slice data based on index
idx_slice = slice(index.index_start, index.index_end + 1)
idx_slice = slice(index_start, index_end + 1)

data_cont = self.data["reals"][idx_slice].clone()
data_cat = self.data["categoricals"][idx_slice].clone()
time = self.data["time"][idx_slice].clone()
target = [d[idx_slice].clone() for d in self.data["target"]]
groups = self.data["groups"][index.index_start].clone()
groups = self.data["groups"][index_start].clone()
if self.data["weight"] is None:
weight = None
else:
Expand All @@ -2112,7 +2116,7 @@ def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]:

# fill in missing values (if not all time indices are specified)
sequence_length = len(time)
if sequence_length < index.sequence_length:
if sequence_length < index_sequence_length:
assert (
self.allow_missing_timesteps
), "allow_missing_timesteps should be True if sequences have gaps"
Expand Down
Loading