@@ -35,6 +35,7 @@ def __init__(
         downsampling_window: int,
         downsampling_method: str,
         use_future_temporal_feature: bool,
+        use_norm: bool = False,
         embed="fixed",
         freq="h",
         n_classes=None,
@@ -50,6 +51,7 @@ def __init__(
         self.downsampling_window = downsampling_window
         self.downsampling_layers = downsampling_layers
         self.downsampling_method = downsampling_method
+        self.use_norm = use_norm
         self.use_future_temporal_feature = use_future_temporal_feature
 
         assert downsampling_method in ["max", "avg", "conv"], "downsampling_method must be in ['max', 'avg', 'conv']"
@@ -74,12 +76,13 @@ def __init__(
         )
         self.preprocess = SeriesDecompositionBlock(moving_avg)
 
-        if self.channel_independence == 1:
+        if self.channel_independence:
             self.enc_embedding = DataEmbedding(1, d_model, embed, freq, dropout, with_pos=False)
         else:
             self.enc_embedding = DataEmbedding(n_features, d_model, embed, freq, dropout, with_pos=False)
 
-        self.normalize_layers = torch.nn.ModuleList([RevIN(n_features) for _ in range(downsampling_layers + 1)])
+        if self.use_norm:
+            self.normalize_layers = torch.nn.ModuleList([RevIN(n_features) for _ in range(downsampling_layers + 1)])
 
         if task_name == "long_term_forecast" or task_name == "short_term_forecast":
             self.predict_layers = torch.nn.ModuleList(
@@ -92,7 +95,7 @@ def __init__(
                 ]
             )
 
-            if self.channel_independence == 1:
+            if self.channel_independence:
                 self.projection_layer = nn.Linear(d_model, 1, bias=True)
             else:
                 self.projection_layer = nn.Linear(d_model, n_pred_features, bias=True)
@@ -117,7 +120,7 @@ def __init__(
                 ]
             )
         elif task_name == "imputation" or task_name == "anomaly_detection":
-            if self.channel_independence == 1:
+            if self.channel_independence:
                 self.projection_layer = nn.Linear(d_model, 1, bias=True)
             else:
                 self.projection_layer = nn.Linear(d_model, n_pred_features, bias=True)
@@ -137,7 +140,7 @@ def out_projection(self, dec_out, i, out_res):
         return dec_out
 
     def pre_enc(self, x_list):
-        if self.channel_independence == 1:
+        if self.channel_independence:
             return x_list, None
         else:
             out1_list = []
@@ -197,7 +200,7 @@ def __multi_scale_process_inputs(self, x_enc, x_mark_enc):
 
     def forecast(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None):
         if self.use_future_temporal_feature:
-            if self.channel_independence == 1:
+            if self.channel_independence:
                 B, T, N = x_enc.size()
                 x_mark_dec = x_mark_dec.repeat(N, 1, 1)
                 self.x_mark_dec = self.enc_embedding(None, x_mark_dec)
@@ -211,8 +214,8 @@ def forecast(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None):
         if x_mark_enc is not None:
             for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc):
                 B, T, N = x.size()
-                x = self.normalize_layers[i](x, x_mark, mode="norm")
-                if self.channel_independence == 1:
+                x = self.normalize_layers[i](x, x_mark, mode="norm") if self.use_norm else x
+                if self.channel_independence:
                     x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
                     x_mark = x_mark.repeat(N, 1, 1)
                 x_list.append(x)
@@ -223,8 +226,8 @@ def forecast(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None):
                 x_enc,
             ):
                 B, T, N = x.size()
-                x = self.normalize_layers[i](x, mode="norm")
-                if self.channel_independence == 1:
+                x = self.normalize_layers[i](x, mode="norm") if self.use_norm else x
+                if self.channel_independence:
                     x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
                 x_list.append(x)
 
@@ -248,12 +251,12 @@ def forecast(self, x_enc, x_mark_enc, x_dec=None, x_mark_dec=None):
         dec_out_list = self.future_multi_mixing(B, enc_out_list, x_list)
 
         dec_out = torch.stack(dec_out_list, dim=-1).sum(-1)
-        dec_out = self.normalize_layers[0](dec_out, mode="denorm")
+        dec_out = self.normalize_layers[0](dec_out, mode="denorm") if self.use_norm else dec_out
         return dec_out
 
     def future_multi_mixing(self, B, enc_out_list, x_list):
         dec_out_list = []
-        if self.channel_independence == 1:
+        if self.channel_independence:
             x_list = x_list[0]
             for i, enc_out in zip(range(len(x_list)), enc_out_list):
                 dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(0, 2, 1)  # align temporal dimension
@@ -310,8 +313,8 @@ def anomaly_detection(self, x_enc):
             x_enc,
         ):
             B, T, N = x.size()
-            x = self.normalize_layers[i](x, "norm")
-            if self.channel_independence == 1:
+            x = self.normalize_layers[i](x, "norm") if self.use_norm else x
+            if self.channel_independence:
                 x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
             x_list.append(x)
 
@@ -328,7 +331,7 @@ def anomaly_detection(self, x_enc):
         dec_out = self.projection_layer(enc_out_list[0])
         dec_out = dec_out.reshape(B, self.c_out, -1).permute(0, 2, 1).contiguous()
 
-        dec_out = self.normalize_layers[0](dec_out, "denorm")
+        dec_out = self.normalize_layers[0](dec_out, "denorm") if self.use_norm else dec_out
         return dec_out
 
     def imputation(self, x_enc, x_mark_enc):
@@ -341,15 +344,15 @@ def imputation(self, x_enc, x_mark_enc):
         if x_mark_enc is not None:
             for i, x, x_mark in zip(range(len(x_enc)), x_enc, x_mark_enc):
                 B, T, N = x.size()
-                if self.channel_independence == 1:
+                if self.channel_independence:
                     x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
                 x_list.append(x)
                 x_mark = x_mark.repeat(N, 1, 1)
                 x_mark_list.append(x_mark)
         else:
             for i, x in zip(range(len(x_enc)), x_enc):
                 B, T, N = x.size()
-                if self.channel_independence == 1:
+                if self.channel_independence:
                     x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
                 x_list.append(x)
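A minimal sketch of the pattern this diff introduces: the RevIN normalization layers are only instantiated and applied when use_norm is enabled, and every normalize/denormalize call falls back to the identity otherwise. SimpleRevIN below is a hypothetical stand-in for the project's RevIN layer (no affine parameters), and NormGatedBackbone is an illustrative module rather than the repository's actual class; parameter names such as downsampling_layers merely mirror the diff.

import torch
import torch.nn as nn


class SimpleRevIN(nn.Module):
    # Hypothetical stand-in for RevIN: per-instance standardization ("norm")
    # and its inverse ("denorm"). n_features is kept only for signature parity.
    def __init__(self, n_features: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x, mode="norm"):
        if mode == "norm":
            self.mean = x.mean(dim=1, keepdim=True).detach()
            self.std = torch.sqrt(x.var(dim=1, keepdim=True, unbiased=False) + self.eps).detach()
            return (x - self.mean) / self.std
        return x * self.std + self.mean  # mode == "denorm"


class NormGatedBackbone(nn.Module):
    # Illustrative module showing the gating pattern from the diff:
    # normalization layers exist only when use_norm=True, and every call
    # to them is guarded by the same flag.
    def __init__(self, n_features: int, downsampling_layers: int, use_norm: bool = False):
        super().__init__()
        self.use_norm = use_norm
        if self.use_norm:
            self.normalize_layers = nn.ModuleList(
                [SimpleRevIN(n_features) for _ in range(downsampling_layers + 1)]
            )

    def forward(self, x):
        # Normalize on the way in only when the flag is set.
        x = self.normalize_layers[0](x, mode="norm") if self.use_norm else x
        out = x  # ... multi-scale mixing and projection would happen here ...
        # De-normalize on the way out, again only when the flag is set.
        out = self.normalize_layers[0](out, mode="denorm") if self.use_norm else out
        return out


# Usage: with use_norm=False no normalization layers are registered at all,
# so the module's parameter set matches models built without normalization.
model = NormGatedBackbone(n_features=7, downsampling_layers=2, use_norm=True)
y = model(torch.randn(4, 96, 7))  # (batch, time, features)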