
Commit d899d02

authored Mar 7, 2025
Merge pull request #607 from WenjieDu/(fix)timemixer
Fix x and x_mark shape not consistent bug in forecasting TimeMixer
2 parents af9de8b + 1e82e65 commit d899d02

File tree: 5 files changed, +38 −40 lines changed

pypots/forecasting/timemixer/core.py
pypots/forecasting/timemixer/model.py
pypots/nn/modules/timemixer/backbone.py
tests/forecasting/timemixer.py
tests/global_test_config.py

pypots/forecasting/timemixer/core.py (+9 −13)
@@ -10,7 +10,6 @@
 import torch
 import torch.nn as nn
 
-from ...nn.functional import nonstationary_norm, nonstationary_denorm
 from ...nn.functional.error import calc_mse
 from ...nn.modules.timemixer import BackboneTimeMixer
 
@@ -33,13 +32,12 @@ def __init__(
         moving_avg: int,
         downsampling_layers: int,
         downsampling_window: int,
-        apply_nonstationary_norm: bool = False,
+        use_norm: bool = False,
     ):
         super().__init__()
 
         self.n_pred_steps = n_pred_steps
         self.n_pred_features = n_pred_features
-        self.apply_nonstationary_norm = apply_nonstationary_norm
 
         assert term in ["long", "short"], "forecasting term should be either 'long' or 'short'"
         self.model = BackboneTimeMixer(
@@ -60,13 +58,15 @@ def __init__(
             downsampling_window=downsampling_window,
             downsampling_method="avg",
             use_future_temporal_feature=False,
+            use_norm=use_norm,
         )
 
         # for the imputation task, the output dim is the same as input dim
         self.output_projection = nn.Linear(n_features, n_pred_features)
 
     def forward(self, inputs: dict) -> dict:
-        X, missing_mask = inputs["X"], inputs["missing_mask"]
+        X = inputs["X"]
+        # missing_mask = inputs["missing_mask"]
 
         if self.training:
             X_pred, X_pred_missing_mask = inputs["X_pred"], inputs["X_pred_missing_mask"]
@@ -77,16 +77,12 @@ def forward(self, inputs: dict) -> dict:
                 torch.ones(batch_size, self.n_pred_steps, self.n_pred_features),
             )
 
-        if self.apply_nonstationary_norm:
-            # Normalization from Non-stationary Transformer
-            X, means, stdev = nonstationary_norm(X, missing_mask)
-
         # TimeMixer processing
-        enc_out = self.model.forecast(X, missing_mask)
-
-        if self.apply_nonstationary_norm:
-            # De-Normalization from Non-stationary Transformer
-            enc_out = nonstationary_denorm(enc_out, means, stdev)
+        # WDU: missing_mask should not be passed into the model's forward processing, because the official
+        # implementation does not accept POTS (partially observed time series) for the forecasting task.
+        # Passing it in triggers the "x and x_mark shape not consistent" bug.
+        # enc_out = self.model.forecast(X, missing_mask)
+        enc_out = self.model.forecast(X, None)
 
         # project back to the original data space
         forecasting_result = self.output_projection(enc_out)
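
Why the fix works: the second positional argument of BackboneTimeMixer.forecast is x_mark_enc, the slot for time-covariate features, not for an observation mask. Assuming the backbone mirrors the official TimeMixer's multi-scale preprocessing (avg-pooling x while stride-slicing x_mark), the two operations disagree in length whenever an intermediate scale is odd, which is a plausible source of the "x and x_mark shape not consistent" error. A toy reproduction of that length mismatch, with illustrative shapes not taken from the commit:

    import torch

    # Hypothetical shapes: an odd-length scale with downsampling window 2.
    T, w = 7, 2
    x = torch.randn(8, T, 5)       # (batch, steps, features)
    mask = torch.ones_like(x)      # missing_mask shares x's shape

    # avg-pooling floors the length; stride-slicing ceils it
    x_down = torch.nn.AvgPool1d(w)(x.permute(0, 2, 1)).permute(0, 2, 1)
    mask_down = mask[:, ::w, :]    # what happens if the mask rides along as x_mark

    print(x_down.shape)            # torch.Size([8, 3, 5])
    print(mask_down.shape)         # torch.Size([8, 4, 5]) -- inconsistent with x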

pypots/forecasting/timemixer/model.py (+5 −7)
@@ -70,10 +70,8 @@ class TimeMixer(BaseNNForecaster):
     downsampling_window :
         The window size for downsampling.
 
-    apply_nonstationary_norm :
-        Whether to apply non-stationary normalization to the input data for TimeMixer.
-        Please refer to :cite:`liu2022nonstationary` for details about non-stationary normalization,
-        which is not the idea of the original TimeMixer paper. Hence, we make it optional and default not to use here.
+    use_norm :
+        Whether to apply RevIN to the input data for TimeMixer.
 
     batch_size :
         The batch size for training and evaluating the model.
@@ -143,7 +141,7 @@ def __init__(
         moving_avg: int = 5,
         downsampling_layers: int = 3,
         downsampling_window: int = 2,
-        apply_nonstationary_norm: bool = False,
+        use_norm: bool = False,
         batch_size: int = 32,
         epochs: int = 100,
         patience: Optional[int] = None,
@@ -184,7 +182,7 @@ def __init__(
         self.moving_avg = moving_avg
         self.downsampling_layers = downsampling_layers
         self.downsampling_window = downsampling_window
-        self.apply_nonstationary_norm = apply_nonstationary_norm
+        self.use_norm = use_norm
 
         # set up the model
         self.model = _TimeMixer(
@@ -203,7 +201,7 @@ def __init__(
             self.moving_avg,
             self.downsampling_layers,
             self.downsampling_window,
-            self.apply_nonstationary_norm,
+            self.use_norm,
         )
         self._print_model_size()
         self._send_model_to_given_device()
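
For reference, RevIN (reversible instance normalization) standardizes each series with its own statistics on the way into the model and restores those statistics on the way out. A minimal sketch of the idea, not the library's actual RevIN layer (which, as the backbone diff below shows, also accepts an extra positional argument and may carry learnable affine parameters):

    import torch

    class SimpleRevIN(torch.nn.Module):
        # Sketch only; n_features mirrors the RevIN(n_features) call in the diff
        # below but is unused in this simplified version.
        def __init__(self, n_features: int, eps: float = 1e-5):
            super().__init__()
            self.eps = eps

        def forward(self, x: torch.Tensor, mode: str = "norm") -> torch.Tensor:
            if mode == "norm":
                # keep per-instance statistics so the output can be denormalized
                self.means = x.mean(dim=1, keepdim=True)
                self.stdev = torch.sqrt(x.var(dim=1, keepdim=True, unbiased=False) + self.eps)
                return (x - self.means) / self.stdev
            return x * self.stdev + self.means  # mode == "denorm"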

pypots/nn/modules/timemixer/backbone.py (+20 −17)
@@ -35,6 +35,7 @@ def __init__(
         downsampling_window: int,
         downsampling_method: str,
         use_future_temporal_feature: bool,
+        use_norm: bool = False,
         embed="fixed",
         freq="h",
         n_classes=None,
@@ -50,6 +51,7 @@ def __init__(
         self.downsampling_window = downsampling_window
         self.downsampling_layers = downsampling_layers
         self.downsampling_method = downsampling_method
+        self.use_norm = use_norm
         self.use_future_temporal_feature = use_future_temporal_feature
 
         assert downsampling_method in ["max", "avg", "conv"], "downsampling_method must be in ['max', 'avg', 'conv']"
@@ -74,12 +76,13 @@ def __init__(
         )
         self.preprocess = SeriesDecompositionBlock(moving_avg)
 
-        if self.channel_independence == 1:
+        if self.channel_independence:
             self.enc_embedding = DataEmbedding(1, d_model, embed, freq, dropout, with_pos=False)
         else:
             self.enc_embedding = DataEmbedding(n_features, d_model, embed, freq, dropout, with_pos=False)
 
-        self.normalize_layers = torch.nn.ModuleList([RevIN(n_features) for _ in range(downsampling_layers + 1)])
+        if self.use_norm:
+            self.normalize_layers = torch.nn.ModuleList([RevIN(n_features) for _ in range(downsampling_layers + 1)])
 
         if task_name == "long_term_forecast" or task_name == "short_term_forecast":
             self.predict_layers = torch.nn.ModuleList(
@@ -92,7 +95,7 @@ def __init__(
                 ]
             )
 
-            if self.channel_independence == 1:
+            if self.channel_independence:
                 self.projection_layer = nn.Linear(d_model, 1, bias=True)
             else:
                 self.projection_layer = nn.Linear(d_model, n_pred_features, bias=True)
@@ -117,7 +120,7 @@ def __init__(
                 ]
             )
         elif task_name == "imputation" or task_name == "anomaly_detection":
-            if self.channel_independence == 1:
+            if self.channel_independence:
                 self.projection_layer = nn.Linear(d_model, 1, bias=True)
             else:
                 self.projection_layer = nn.Linear(d_model, n_pred_features, bias=True)
@@ -137,7 +140,7 @@ def out_projection(self, dec_out, i, out_res):
         return dec_out
 
     def pre_enc(self, x_list):
-        if self.channel_independence == 1:
+        if self.channel_independence:
            return x_list, None
        else:
            out1_list = []
@@ -197,7 +200,7 @@ def __multi_scale_process_inputs(self, x_enc, x_mark_enc):
 
     def forecast(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None):
         if self.use_future_temporal_feature:
-            if self.channel_independence == 1:
+            if self.channel_independence:
                 B, T, N = x_enc.size()
                 x_mark_dec = x_mark_dec.repeat(N, 1, 1)
                 self.x_mark_dec = self.enc_embedding(None, x_mark_dec)
@@ -211,8 +214,8 @@ def forecast(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None):
         if x_mark_enc is not None:
             for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc):
                 B, T, N = x.size()
-                x = self.normalize_layers[i](x, x_mark, mode="norm")
-                if self.channel_independence == 1:
+                x = self.normalize_layers[i](x, x_mark, mode="norm") if self.use_norm else x
+                if self.channel_independence:
                     x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
                     x_mark = x_mark.repeat(N, 1, 1)
                 x_list.append(x)
@@ -223,8 +226,8 @@ def forecast(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None):
                 x_enc,
             ):
                 B, T, N = x.size()
-                x = self.normalize_layers[i](x, mode="norm")
-                if self.channel_independence == 1:
+                x = self.normalize_layers[i](x, mode="norm") if self.use_norm else x
+                if self.channel_independence:
                     x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
                 x_list.append(x)

@@ -248,12 +251,12 @@ def forecast(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None):
             dec_out_list = self.future_multi_mixing(B, enc_out_list, x_list)
 
             dec_out = torch.stack(dec_out_list, dim=-1).sum(-1)
-            dec_out = self.normalize_layers[0](dec_out, mode="denorm")
+            dec_out = self.normalize_layers[0](dec_out, mode="denorm") if self.use_norm else dec_out
             return dec_out
 
     def future_multi_mixing(self, B, enc_out_list, x_list):
         dec_out_list = []
-        if self.channel_independence == 1:
+        if self.channel_independence:
             x_list = x_list[0]
             for i, enc_out in zip(range(len(x_list)), enc_out_list):
                 dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(0, 2, 1)  # align temporal dimension
@@ -310,8 +313,8 @@ def anomaly_detection(self, x_enc):
                 x_enc,
             ):
                 B, T, N = x.size()
-                x = self.normalize_layers[i](x, "norm")
-                if self.channel_independence == 1:
+                x = self.normalize_layers[i](x, "norm") if self.use_norm else x
+                if self.channel_independence:
                     x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
                 x_list.append(x)

@@ -328,7 +331,7 @@ def anomaly_detection(self, x_enc):
         dec_out = self.projection_layer(enc_out_list[0])
         dec_out = dec_out.reshape(B, self.c_out, -1).permute(0, 2, 1).contiguous()
 
-        dec_out = self.normalize_layers[0](dec_out, "denorm")
+        dec_out = self.normalize_layers[0](dec_out, "denorm") if self.use_norm else dec_out
         return dec_out
 
     def imputation(self, x_enc, x_mark_enc):
@@ -341,15 +344,15 @@ def imputation(self, x_enc, x_mark_enc):
         if x_mark_enc is not None:
             for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc):
                 B, T, N = x.size()
-                if self.channel_independence == 1:
+                if self.channel_independence:
                     x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
                 x_list.append(x)
                 x_mark = x_mark.repeat(N, 1, 1)
                 x_mark_list.append(x_mark)
         else:
             for i, x in zip(range(len(x_enc)), x_enc):
                 B, T, N = x.size()
-                if self.channel_independence == 1:
+                if self.channel_independence:
                     x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
                 x_list.append(x)
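
The channel-independence branches above all funnel through the same reshape: each of the N variates is folded into the batch dimension and treated as a univariate series of shape (T, 1). A quick illustration of what that reshape does, with toy shapes:

    import torch

    B, T, N = 2, 7, 5
    x = torch.randn(B, T, N)
    x_ci = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
    print(x_ci.shape)  # torch.Size([10, 7, 1]): N variates folded into the batch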

tests/forecasting/timemixer.py (+2 −1)
@@ -54,7 +54,8 @@ class TestTimeMixer(unittest.TestCase):
         d_model=32,
         d_ffn=32,
         moving_avg=25,
-        downsampling_window=1,
+        downsampling_window=2,
+        use_norm=True,
         dropout=0.1,
         epochs=EPOCHS,
         saving_path=saving_path,
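
Pulled together, a hedged usage sketch of the updated test configuration; it assumes TimeMixer is exported from pypots.forecasting like other PyPOTS forecasters, and any required hyperparameters not visible in this commit are omitted:

    from pypots.forecasting import TimeMixer

    model = TimeMixer(
        n_steps=14,        # N_STEPS from the updated global test config below
        n_features=5,      # N_FEATURES
        n_pred_steps=2,    # N_PRED_STEPS
        n_pred_features=5,
        term="short",      # must be "long" or "short" per the assert in core.py
        d_model=32,
        d_ffn=32,
        moving_avg=25,
        downsampling_window=2,
        use_norm=True,     # the renamed switch gating RevIN
        dropout=0.1,
        epochs=2,
        # remaining constructor arguments omitted from this sketch
    )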

tests/global_test_config.py (+2 −2)
@@ -20,8 +20,8 @@
 # set the number of epochs for all model training
 EPOCHS = 2
 # set the number of prediction steps for forecasting models
-N_STEPS = 12
-N_PRED_STEPS = 3
+N_STEPS = 14
+N_PRED_STEPS = 2
 N_FEATURES = 5
 # tensorboard and model files saving directory
 RESULT_SAVING_DIR = "testing_results"
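
As a sanity check on the new value: with downsampling_window=2 from the test above and the forecaster's default of three downsampling layers (see model.py), avg-pool downsampling shortens a 14-step input roughly as follows, assuming the length floors at each scale:

    n_steps, window, layers = 14, 2, 3
    lengths = [n_steps]
    for _ in range(layers):
        lengths.append(lengths[-1] // window)  # AvgPool1d(kernel=stride=window) floors
    print(lengths)  # [14, 7, 3, 1]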
