Skip to content

Commit 1756521

Browse files
authored
Merge pull request #603 from WenjieDu/(feat)forecasting_timemixer
Add `pypots.forecasting.TimeMixer`
2 parents 0b8a893 + e45654f commit 1756521

File tree

16 files changed

+709
-52
lines changed

16 files changed

+709
-52
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ The paper references and links are all listed at the bottom of this file.
123123
| LLM | Time-LLM[^45] || | | | | `2024 - ICLR` |
124124
| Neural Net | TEFN🧑‍🔧[^39] ||| | | | `2024 - arXiv` |
125125
| Neural Net | FITS🧑‍🔧[^41] ||| | | | `2024 - ICLR` |
126-
| Neural Net | TimeMixer[^37] || | | | | `2024 - ICLR` |
126+
| Neural Net | TimeMixer[^37] || | | | | `2024 - ICLR` |
127127
| Neural Net | iTransformer🧑‍🔧[^24] || | | | | `2024 - ICLR` |
128128
| Neural Net | ModernTCN[^38] || | | | | `2024 - ICLR` |
129129
| Neural Net | ImputeFormer🧑‍🔧[^34] || | | | | `2024 - KDD` |

README_zh.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ PyPOTS当前支持多变量POTS数据的插补, 预测, 分类, 聚类以及异
108108
| LLM | Time-LLM[^45] || | | | | `2024 - ICLR` |
109109
| Neural Net | TEFN🧑‍🔧[^39] ||| | | | `2024 - arXiv` |
110110
| Neural Net | FITS🧑‍🔧[^41] ||| | | | `2024 - ICLR` |
111-
| Neural Net | TimeMixer[^37] || | | | | `2024 - ICLR` |
111+
| Neural Net | TimeMixer[^37] || | | | | `2024 - ICLR` |
112112
| Neural Net | iTransformer🧑‍🔧[^24] || | | | | `2024 - ICLR` |
113113
| Neural Net | ModernTCN[^38] || | | | | `2024 - ICLR` |
114114
| Neural Net | ImputeFormer🧑‍🔧[^34] || | | | | `2024 - KDD` |

docs/index.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ The paper references are all listed at the bottom of this readme file.
137137
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
138138
| Neural Net | FITS🧑‍🔧 :cite:`xu2024fits` ||| | | | ``2024 - ICLR`` |
139139
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
140-
| Neural Net | TimeMixer :cite:`wang2024timemixer` || | | | | ``2024 - ICLR`` |
140+
| Neural Net | TimeMixer :cite:`wang2024timemixer` || | | | | ``2024 - ICLR`` |
141141
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
142142
| Neural Net | iTransformer🧑‍🔧 :cite:`liu2024itransformer` || | | | | ``2024 - ICLR`` |
143143
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+

docs/pypots.forecasting.rst

+9
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@ pypots.forecasting.fits
2828
:show-inheritance:
2929
:inherited-members:
3030

31+
pypots.forecasting.timemixer
32+
------------------------------
33+
34+
.. automodule:: pypots.forecasting.timemixer
35+
:members:
36+
:undoc-members:
37+
:show-inheritance:
38+
:inherited-members:
39+
3140
pypots.forecasting.csdi
3241
------------------------------
3342

pypots/forecasting/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .csdi import CSDI
1010
from .fits import FITS
1111
from .tefn import TEFN
12+
from .timemixer import TimeMixer
1213
from .transformer import Transformer
1314

1415
__all__ = [
@@ -17,4 +18,5 @@
1718
"Transformer",
1819
"FITS",
1920
"TEFN",
21+
"TimeMixer",
2022
]

pypots/forecasting/bttf/core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def BTTF_forecast(
200200
"start_time should be larger than -1, "
201201
"namely the number of the input tensor's time steps should be larger than pred_step."
202202
)
203-
assert start_time >= np.max(time_lags), "start_time should be >= max(time_lags)"
203+
assert start_time >= np.max(time_lags), f"start_time {start_time} should be >= max(time_lags) {np.max(time_lags)}"
204204
max_count = int(np.ceil(pred_step / multi_step))
205205
tensor_hat = np.zeros((dim1, dim2, max_count * multi_step))
206206

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""
2+
3+
"""
4+
5+
# Created by Wenjie Du <[email protected]>
6+
# License: BSD-3-Clause
7+
8+
9+
from .model import TimeMixer
10+
11+
__all__ = [
12+
"TimeMixer",
13+
]

pypots/forecasting/timemixer/core.py

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""
2+
The core wrapper assembles the submodules of TimeMixer forecasting model
3+
and takes over the forward progress of the algorithm.
4+
5+
"""
6+
7+
# Created by Wenjie Du <[email protected]>
8+
# License: BSD-3-Clause
9+
10+
import torch
11+
import torch.nn as nn
12+
13+
from ...nn.functional import nonstationary_norm, nonstationary_denorm
14+
from ...nn.functional.error import calc_mse
15+
from ...nn.modules.timemixer import BackboneTimeMixer
16+
17+
18+
class _TimeMixer(nn.Module):
19+
def __init__(
20+
self,
21+
n_steps: int,
22+
n_features: int,
23+
n_pred_steps: int,
24+
n_pred_features: int,
25+
term: str,
26+
n_layers: int,
27+
d_model: int,
28+
d_ffn: int,
29+
dropout: float,
30+
top_k: int,
31+
channel_independence: bool,
32+
decomp_method: str,
33+
moving_avg: int,
34+
downsampling_layers: int,
35+
downsampling_window: int,
36+
apply_nonstationary_norm: bool = False,
37+
):
38+
super().__init__()
39+
40+
self.n_pred_steps = n_pred_steps
41+
self.n_pred_features = n_pred_features
42+
self.apply_nonstationary_norm = apply_nonstationary_norm
43+
44+
assert term in ["long", "short"], "forecasting term should be either 'long' or 'short'"
45+
self.model = BackboneTimeMixer(
46+
task_name=term + "_term_forecast",
47+
n_steps=n_steps,
48+
n_features=n_features,
49+
n_pred_steps=n_pred_steps,
50+
n_pred_features=n_pred_features,
51+
n_layers=n_layers,
52+
d_model=d_model,
53+
d_ffn=d_ffn,
54+
dropout=dropout,
55+
channel_independence=channel_independence,
56+
decomp_method=decomp_method,
57+
top_k=top_k,
58+
moving_avg=moving_avg,
59+
downsampling_layers=downsampling_layers,
60+
downsampling_window=downsampling_window,
61+
downsampling_method="avg",
62+
use_future_temporal_feature=False,
63+
)
64+
65+
# for the imputation task, the output dim is the same as input dim
66+
self.output_projection = nn.Linear(n_features, n_pred_features)
67+
68+
def forward(self, inputs: dict) -> dict:
69+
X, missing_mask = inputs["X"], inputs["missing_mask"]
70+
71+
if self.training:
72+
X_pred, X_pred_missing_mask = inputs["X_pred"], inputs["X_pred_missing_mask"]
73+
else:
74+
batch_size = X.shape[0]
75+
X_pred, X_pred_missing_mask = (
76+
torch.zeros(batch_size, self.n_pred_steps, self.n_pred_features),
77+
torch.ones(batch_size, self.n_pred_steps, self.n_pred_features),
78+
)
79+
80+
if self.apply_nonstationary_norm:
81+
# Normalization from Non-stationary Transformer
82+
X, means, stdev = nonstationary_norm(X, missing_mask)
83+
84+
# TimesMixer processing
85+
enc_out = self.model.forecast(X, missing_mask)
86+
87+
if self.apply_nonstationary_norm:
88+
# De-Normalization from Non-stationary Transformer
89+
enc_out = nonstationary_denorm(enc_out, means, stdev)
90+
91+
# project back the original data space
92+
forecasting_result = self.output_projection(enc_out)
93+
# the raw output has length = n_steps+n_pred_steps, we only need the last n_pred_steps
94+
forecasting_result = forecasting_result[:, -self.n_pred_steps :]
95+
96+
results = {
97+
"forecasting_data": forecasting_result,
98+
}
99+
100+
# if in training mode, return results with losses
101+
if self.training:
102+
# `loss` is always the item for backward propagating to update the model
103+
results["loss"] = calc_mse(X_pred, forecasting_result, X_pred_missing_mask)
104+
105+
return results

pypots/forecasting/timemixer/data.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""
2+
Dataset class for the forecasting model TimeMixer.
3+
"""
4+
5+
# Created by Wenjie Du <[email protected]>
6+
# License: BSD-3-Clause
7+
8+
from typing import Union
9+
10+
from ...data.dataset import BaseDataset
11+
12+
13+
class DatasetForTimeMixer(BaseDataset):
14+
"""Dataset for TimeMixer forecasting model."""
15+
16+
def __init__(
17+
self,
18+
data: Union[dict, str],
19+
return_X_pred=True,
20+
file_type: str = "hdf5",
21+
):
22+
super().__init__(
23+
data=data,
24+
return_X_ori=False,
25+
return_X_pred=return_X_pred,
26+
return_y=False,
27+
file_type=file_type,
28+
)

0 commit comments

Comments
 (0)