
Commit af9de8b

Authored on Mar 7, 2025
Merge pull request #605 from WenjieDu/(feat)forecasting_gpt4ts
Add `pypots.forecasting.GPT4TS`
2 parents 401cb43 + c10e210 commit af9de8b

File tree: 19 files changed, +649 −25 lines

README.md (+1, -1)

@@ -128,7 +128,7 @@ The paper references and links are all listed at the bottom of this file.
 | Neural Net | ModernTCN[^38] || | | | | `2024 - ICLR` |
 | Neural Net | ImputeFormer🧑‍🔧[^34] || | | | | `2024 - KDD` |
 | Neural Net | SAITS[^1] || | | | | `2023 - ESWA` |
-| LLM | GPT4TS[^46] || | | | | `2023 - NeurIPS` |
+| LLM | GPT4TS[^46] || | | | | `2023 - NeurIPS` |
 | Neural Net | FreTS🧑‍🔧[^23] || | | | | `2023 - NeurIPS` |
 | Neural Net | Koopa🧑‍🔧[^29] || | | | | `2023 - NeurIPS` |
 | Neural Net | Crossformer🧑‍🔧[^16] || | | | | `2023 - ICLR` |

README_zh.md (+1, -1)

@@ -113,7 +113,7 @@ PyPOTS当前支持多变量POTS数据的插补, 预测, 分类, 聚类以及异常
 | Neural Net | ModernTCN[^38] || | | | | `2024 - ICLR` |
 | Neural Net | ImputeFormer🧑‍🔧[^34] || | | | | `2024 - KDD` |
 | Neural Net | SAITS[^1] || | | | | `2023 - ESWA` |
-| LLM | GPT4TS[^46] || | | | | `2023 - NeurIPS` |
+| LLM | GPT4TS[^46] || | | | | `2023 - NeurIPS` |
 | Neural Net | FreTS🧑‍🔧[^23] || | | | | `2023 - NeurIPS` |
 | Neural Net | Koopa🧑‍🔧[^29] || | | | | `2023 - NeurIPS` |
 | Neural Net | Crossformer🧑‍🔧[^16] || | | | | `2023 - ICLR` |

docs/index.rst (+1, -1)

@@ -147,7 +147,7 @@ The paper references are all listed at the bottom of this readme file.
 +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
 | Neural Net | SAITS :cite:`du2023SAITS` || | | | | ``2023 - ESWA`` |
 +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
-| LLM | GPT4TS :cite:`zhou2023gpt4ts` || | | | | ``2023 - NeurIPS`` |
+| LLM | GPT4TS :cite:`zhou2023gpt4ts` || | | | | ``2023 - NeurIPS`` |
 +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
 | Neural Net | FreTS🧑‍🔧 :cite:`yi2023frets` || | | | | ``2023 - NeurIPS`` |
 +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+

docs/pypots.forecasting.rst (+9)

@@ -19,6 +19,15 @@ pypots.forecasting.timellm
     :show-inheritance:
     :inherited-members:
 
+pypots.forecasting.gpt4ts
+------------------------------
+
+.. automodule:: pypots.forecasting.gpt4ts
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :inherited-members:
+
 pypots.forecasting.tefn
 ------------------------------

pypots/forecasting/__init__.py (+2)

@@ -8,6 +8,7 @@
 from .bttf import BTTF
 from .csdi import CSDI
 from .fits import FITS
+from .gpt4ts import GPT4TS
 from .tefn import TEFN
 from .timellm import TimeLLM
 from .timemixer import TimeMixer
@@ -21,4 +22,5 @@
     "TEFN",
     "TimeMixer",
     "TimeLLM",
+    "GPT4TS",
 ]

pypots/forecasting/gpt4ts/__init__.py (new file, +13)

"""

"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .model import GPT4TS

__all__ = [
    "GPT4TS",
]

pypots/forecasting/gpt4ts/core.py (new file, +86)

"""
The core wrapper assembles the submodules of GPT4TS forecasting model
and takes over the forward progress of the algorithm.

"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Callable

import torch
import torch.nn as nn

from ...nn.functional.error import calc_mse
from ...nn.modules.gpt4ts import BackboneGPT4TS


class _GPT4TS(nn.Module):
    def __init__(
        self,
        n_steps: int,
        n_features: int,
        n_pred_steps: int,
        n_pred_features: int,
        term: str,
        n_layers: int,
        patch_size: int,
        patch_stride: int,
        train_gpt_mlp: bool,
        d_ffn: int,
        dropout: float,
        embed: str,
        freq: str,
        loss_func: Callable = calc_mse,
    ):
        super().__init__()

        assert term in ["long", "short"], "forecasting term should be either 'long' or 'short'"
        self.n_pred_steps = n_pred_steps
        self.n_pred_features = n_pred_features
        self.loss_func = loss_func

        self.backbone = BackboneGPT4TS(
            term + "_term_forecast",
            n_steps,
            n_features,
            n_pred_steps,
            n_pred_features,
            n_layers,
            patch_size,
            patch_stride,
            train_gpt_mlp,
            d_ffn,
            dropout,
            embed,
            freq,
        )

    def forward(self, inputs: dict) -> dict:
        X, missing_mask = inputs["X"], inputs["missing_mask"]

        if self.training:
            X_pred, X_pred_missing_mask = inputs["X_pred"], inputs["X_pred_missing_mask"]
        else:
            batch_size = X.shape[0]
            X_pred, X_pred_missing_mask = (
                torch.zeros(batch_size, self.n_pred_steps, self.n_pred_features),
                torch.ones(batch_size, self.n_pred_steps, self.n_pred_features),
            )

        # GPT4TS backbone processing
        forecasting_result = self.backbone(X, missing_mask)
        # the raw output has length = n_steps+n_pred_steps, we only need the last n_pred_steps
        forecasting_result = forecasting_result[:, -self.n_pred_steps :]

        results = {
            "forecasting_data": forecasting_result,
        }

        # if in training mode, return results with losses
        if self.training:
            # `loss` is always the item for backward propagating to update the model
            results["loss"] = self.loss_func(X_pred, forecasting_result, X_pred_missing_mask)

        return results
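The loss call above follows PyPOTS' masked-error convention: `X_pred_missing_mask` selects which entries of the prediction horizon actually count toward the error. A minimal sketch of what a masked MSE such as `calc_mse` presumably computes (illustrative only; the real implementation lives in `pypots.nn.functional.error`):

import torch

def masked_mse(predictions: torch.Tensor, targets: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    # Count squared error only where masks == 1 (observed entries),
    # then average over the number of observed entries.
    # Hypothetical re-implementation for illustration; use calc_mse in practice.
    return torch.sum((predictions - targets) ** 2 * masks) / (torch.sum(masks) + 1e-12)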

pypots/forecasting/gpt4ts/data.py (new file, +28)

"""
Dataset class for the forecasting model GPT4TS.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union

from ...data.dataset import BaseDataset


class DatasetForGPT4TS(BaseDataset):
    """Dataset for GPT4TS forecasting model."""

    def __init__(
        self,
        data: Union[dict, str],
        return_X_pred=True,
        file_type: str = "hdf5",
    ):
        super().__init__(
            data=data,
            return_X_ori=False,
            return_X_pred=return_X_pred,
            return_y=False,
            file_type=file_type,
        )
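Here `return_X_pred=True` means each sample yields both the observed window and the horizon to forecast. A hedged sketch of the dictionaries this dataset class consumes (key names and shapes inferred from this diff and the test fixtures, not from separate documentation):

import numpy as np

n_samples, n_steps, n_pred_steps, n_features = 64, 16, 8, 5

# Train/validation data carry both the lookback window "X" and the target horizon "X_pred";
# NaNs mark missing values. fit() below rejects a val_set without "X_pred".
train_set = {
    "X": np.random.randn(n_samples, n_steps, n_features),
    "X_pred": np.random.randn(n_samples, n_pred_steps, n_features),
}

# Test data only needs the lookback window (predict() builds the dataset with return_X_pred=False).
test_set = {"X": np.random.randn(n_samples, n_steps, n_features)}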

pypots/forecasting/gpt4ts/model.py (new file, +351)

"""
The implementation of GPT4TS for the partially-observed time-series forecasting task.

"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union, Optional

import numpy as np
import torch
from torch.utils.data import DataLoader

from .core import _GPT4TS
from .data import DatasetForGPT4TS
from ..base import BaseNNForecaster
from ...data.checking import key_in_data_set
from ...optim.adam import Adam
from ...optim.base import Optimizer


class GPT4TS(BaseNNForecaster):
    """The PyTorch implementation of the GPT4TS forecasting model :cite:`zhou2023gpt4ts`.

    Parameters
    ----------
    n_steps :
        The number of time steps in the time-series data sample.

    n_features :
        The number of features in the time-series data sample.

    n_pred_steps :
        The number of steps in the forecasting time series.

    n_pred_features :
        The number of features in the forecasting time series.

    term :
        The forecasting term, which can be either 'long' or 'short'.

    patch_size :
        The size of the patch for the patching mechanism.

    patch_stride :
        The stride for the patching mechanism.

    n_layers :
        The number of hidden layers to use in GPT2.

    train_gpt_mlp :
        Whether to train the MLP in GPT2 during tuning.

    d_ffn :
        The hidden size of the feed-forward network.

    dropout :
        The dropout rate for the model.

    embed :
        The embedding method for the model.

    freq :
        The frequency of the time-series data.

    batch_size :
        The batch size for training and evaluating the model.

    epochs :
        The number of epochs for training the model.

    patience :
        The patience for the early-stopping mechanism. Given a positive integer, the training process will be
        stopped when the model does not perform better after that number of epochs.
        Leaving it as the default None disables early stopping.

    train_loss_func :
        The customized loss function designed by users for training the model.
        If not given, will use the default loss as claimed in the original paper.

    val_metric_func :
        The customized metric function designed by users for validating the model.
        If not given, will use the default MSE metric.

    optimizer :
        The optimizer for model training.
        If not given, will use a default Adam optimizer.

    num_workers :
        The number of subprocesses to use for data loading.
        `0` means data loading will be in the main process, i.e. there won't be subprocesses.

    device :
        The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
        If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
        then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
        If given a list of devices, e.g. ['cuda:0', 'cuda:1'] or [torch.device('cuda:0'), torch.device('cuda:1')],
        the model will be trained in parallel on the multiple devices (so far only parallel training on CUDA
        devices is supported). Other devices like Google TPU and Apple Silicon accelerator MPS may be added
        in the future.

    saving_path :
        The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
        training into a tensorboard file). Will not save if not given.

    model_saving_strategy :
        The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
        No model will be saved when it is set as None.
        The "best" strategy will only automatically save the best model after the training finished.
        The "better" strategy will automatically save the model during training whenever the model performs
        better than in previous epochs.
        The "all" strategy will save every model after each epoch training.

    verbose :
        Whether to print out the training logs during the training process.
    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        n_pred_steps: int,
        n_pred_features: int,
        term: str,
        patch_size: int,
        patch_stride: int,
        n_layers: int,
        train_gpt_mlp: bool,
        d_ffn: int,
        dropout: float,
        embed: str = "fixed",
        freq="h",
        batch_size: int = 32,
        epochs: int = 100,
        patience: Optional[int] = None,
        train_loss_func: Optional[dict] = None,
        val_metric_func: Optional[dict] = None,
        optimizer: Optional[Optimizer] = Adam(),
        num_workers: int = 0,
        device: Optional[Union[str, torch.device, list]] = None,
        saving_path: Optional[str] = None,
        model_saving_strategy: Optional[str] = "best",
        verbose: bool = True,
    ):
        super().__init__(
            batch_size=batch_size,
            epochs=epochs,
            patience=patience,
            train_loss_func=train_loss_func,
            val_metric_func=val_metric_func,
            num_workers=num_workers,
            device=device,
            enable_amp=True,
            saving_path=saving_path,
            model_saving_strategy=model_saving_strategy,
            verbose=verbose,
        )

        self.n_steps = n_steps
        self.n_features = n_features
        self.n_pred_steps = n_pred_steps
        self.n_pred_features = n_pred_features
        self.term = term
        self.n_layers = n_layers
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.train_gpt_mlp = train_gpt_mlp
        self.d_ffn = d_ffn
        self.dropout = dropout
        self.embed = embed
        self.freq = freq

        # set up the model
        self.model = _GPT4TS(
            self.n_steps,
            self.n_features,
            self.n_pred_steps,
            self.n_pred_features,
            self.term,
            self.n_layers,
            self.patch_size,
            self.patch_stride,
            self.train_gpt_mlp,
            self.d_ffn,
            self.dropout,
            self.embed,
            self.freq,
        )
        self._print_model_size()
        self._send_model_to_given_device()

        # set up the optimizer
        self.optimizer = optimizer
        self.optimizer.init_optimizer(self.model.parameters())

    def _assemble_input_for_training(self, data: list) -> dict:
        (
            indices,
            X,
            missing_mask,
            X_pred,
            X_pred_missing_mask,
        ) = self._send_data_to_given_device(data)

        inputs = {
            "X": X,
            "missing_mask": missing_mask,
            "X_pred": X_pred,
            "X_pred_missing_mask": X_pred_missing_mask,
        }
        return inputs

    def _assemble_input_for_validating(self, data: list) -> dict:
        return self._assemble_input_for_training(data)

    def _assemble_input_for_testing(self, data: list) -> dict:
        (
            indices,
            X,
            missing_mask,
        ) = self._send_data_to_given_device(data)

        inputs = {
            "X": X,
            "missing_mask": missing_mask,
        }
        return inputs

    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "hdf5",
    ) -> None:
        # Step 1: wrap the input data with classes Dataset and DataLoader
        training_set = DatasetForGPT4TS(
            train_set,
            file_type=file_type,
        )
        training_loader = DataLoader(
            training_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
        val_loader = None
        if val_set is not None:
            if not key_in_data_set("X_pred", val_set):
                raise ValueError("val_set must contain 'X_pred' for model validation.")
            val_set = DatasetForGPT4TS(
                val_set,
                file_type=file_type,
            )
            val_loader = DataLoader(
                val_set,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers,
            )

        # Step 2: train the model and freeze it
        self._train_model(training_loader, val_loader)
        self.model.load_state_dict(self.best_model_dict)
        self.model.eval()  # set the model as eval status to freeze it.

        # Step 3: save the model if necessary
        self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

    def predict(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> dict:
        """Make forecasts for the given data with the trained model.

        Parameters
        ----------
        test_set : dict or str
            The dataset for model testing, should be a dictionary including the key 'X',
            or a path string locating a data file.
            If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
            which is the time-series data for forecasting and can contain missing values.
            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict, and it has to include the key 'X'.

        file_type :
            The type of the given file if test_set is a path string.

        Returns
        -------
        result_dict : dict
            Prediction results in a Python Dictionary for the given samples.
            It should be a dictionary including a key named 'forecasting'.

        """

        # Step 1: wrap the input data with classes Dataset and DataLoader
        self.model.eval()  # set the model as eval status to freeze it.
        test_set = DatasetForGPT4TS(
            test_set,
            return_X_pred=False,
            file_type=file_type,
        )

        test_loader = DataLoader(
            test_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
        forecasting_collector = []

        # Step 2: process the data with the model
        with torch.no_grad():
            for idx, data in enumerate(test_loader):
                inputs = self._assemble_input_for_testing(data)
                results = self.model(inputs)
                forecasting_data = results["forecasting_data"]
                forecasting_collector.append(forecasting_data)

        # Step 3: output collection and return
        forecasting_data = torch.cat(forecasting_collector).cpu().detach().numpy()
        result_dict = {
            "forecasting": forecasting_data,  # [bz, n_pred_steps, n_features]
        }
        return result_dict

    def forecast(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> np.ndarray:
        """Forecast the future of the input with the trained model.

        Parameters
        ----------
        test_set :
            The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
            n_features], or a path string locating a data file, e.g. an h5 file.

        file_type :
            The type of the given file if test_set is a path string.

        Returns
        -------
        array-like, shape [n_samples, n_pred_steps, n_features],
            Forecasting results.
        """

        result_dict = self.predict(test_set, file_type=file_type)
        return result_dict["forecasting"]
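Putting the new API together, a minimal usage sketch (hyperparameter values are borrowed from the test case added below in tests/forecasting/llms/gpt4ts.py, and train_set/test_set are the toy dictionaries from the data.py note above; tune everything for real data):

from pypots.forecasting import GPT4TS
from pypots.optim import Adam

model = GPT4TS(
    n_steps=16,
    n_features=5,
    n_pred_steps=8,
    n_pred_features=5,
    term="short",        # must be either "long" or "short"
    patch_size=16,
    patch_stride=8,
    n_layers=2,          # number of GPT-2 layers to keep
    train_gpt_mlp=True,
    d_ffn=128,
    dropout=0.1,
    epochs=10,
    optimizer=Adam(lr=1e-3),
)
model.fit(train_set)                    # optionally pass a val_set, which must contain "X_pred"
forecasting = model.forecast(test_set)  # ndarray of shape [n_samples, n_pred_steps, n_features]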

pypots/imputation/gpt4ts/core.py (+4, -4)

@@ -54,18 +54,18 @@ def __init__(
     def forward(self, inputs: dict) -> dict:
         X, missing_mask = inputs["X"], inputs["missing_mask"]
 
-        # TimesMixer processing
-        dec_out = self.backbone(X, mask=missing_mask)
+        # GPT4TS backbone processing
+        reconstruction = self.backbone(X, mask=missing_mask)
 
-        imputed_data = missing_mask * X + (1 - missing_mask) * dec_out
+        imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
         results = {
             "imputed_data": imputed_data,
         }
 
         # if in training mode, return results with losses
         if self.training:
             # `loss` is always the item for backward propagating to update the model
-            loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"])
+            loss = self.loss_func(reconstruction, inputs["X_ori"], inputs["indicating_mask"])
             results["loss"] = loss
 
         return results
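The blending line in this hunk is the standard imputation convention: observed values pass through untouched and only the gaps take the backbone's reconstruction. On a toy tensor:

import torch

X = torch.tensor([[1.0, 0.0, 3.0]])             # 0.0 is a placeholder for a missing entry
missing_mask = torch.tensor([[1.0, 0.0, 1.0]])  # 1 = observed, 0 = missing
reconstruction = torch.tensor([[9.0, 2.5, 9.0]])

imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
print(imputed_data)  # tensor([[1.0, 2.5, 3.0]]) -- observed values kept, gap filled with 2.5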

pypots/imputation/gpt4ts/model.py (+19, -10)

@@ -132,26 +132,35 @@ def __init__(
             val_metric_func=val_metric_func,
             num_workers=num_workers,
             device=device,
+            enable_amp=True,
             saving_path=saving_path,
             model_saving_strategy=model_saving_strategy,
             verbose=verbose,
         )
 
         self.n_steps = n_steps
         self.n_features = n_features
+        self.n_layers = n_layers
+        self.patch_size = patch_size
+        self.patch_stride = patch_stride
+        self.train_gpt_mlp = train_gpt_mlp
+        self.d_ffn = d_ffn
+        self.dropout = dropout
+        self.embed = embed
+        self.freq = freq
 
         # set up the model
         self.model = _GPT4TS(
-            n_steps,
-            n_features,
-            n_layers,
-            patch_size,
-            patch_stride,
-            train_gpt_mlp,
-            d_ffn,
-            dropout,
-            embed,
-            freq,
+            self.n_steps,
+            self.n_features,
+            self.n_layers,
+            self.patch_size,
+            self.patch_stride,
+            self.train_gpt_mlp,
+            self.d_ffn,
+            self.dropout,
+            self.embed,
+            self.freq,
         )
         self._send_model_to_given_device()
         self._print_model_size()
pypots/nn/modules/gpt4ts/backbone.py (+2, -2)

@@ -141,7 +141,7 @@ def forecast(
         # enc_out = rearrange(enc_out, 'b m n p -> b n (m p)')
 
         dec_out = self.gpt2(inputs_embeds=enc_out).last_hidden_state
-        dec_out = dec_out[:, :, : self.d_ff]
+        dec_out = dec_out[:, :, : self.d_ffn]
         # dec_out = dec_out.reshape(B, -1)
 
         # dec_out = self.ln(dec_out)

@@ -179,7 +179,7 @@ def anomaly_detection(
 
         outputs = self.gpt2(inputs_embeds=enc_out).last_hidden_state
 
-        outputs = outputs[:, :, : self.d_ff]
+        outputs = outputs[:, :, : self.d_ffn]
         # outputs = self.ln_proj(outputs)
         dec_out = self.out_layer(outputs)
tests/forecasting/csdi.py (+1, -1)

@@ -34,7 +34,7 @@
 
 
 class TestCSDI(unittest.TestCase):
-    logger.info("Running tests for an forecasting model CSDI...")
+    logger.info("Running tests for a forecasting model CSDI...")
 
     # set the log and model saving path
     saving_path = os.path.join(RESULT_SAVING_DIR_FOR_FORECASTING, "CSDI")

tests/forecasting/fits.py (+1, -1)

@@ -33,7 +33,7 @@
 
 
 class TestFITS(unittest.TestCase):
-    logger.info("Running tests for an forecasting model FITS...")
+    logger.info("Running tests for a forecasting model FITS...")
 
     # set the log and model saving path
     saving_path = os.path.join(RESULT_SAVING_DIR_FOR_FORECASTING, "FITS")

tests/forecasting/llms/gpt4ts.py (new file, +126)

"""
Test cases for GPT4TS forecasting model.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


import os.path
import unittest

import numpy as np
import pytest

from pypots.forecasting import GPT4TS
from pypots.nn.functional import calc_mse
from pypots.optim import Adam
from pypots.utils.logging import logger
from tests.global_test_config import (
    DATA,
    EPOCHS,
    DEVICE,
    N_PRED_STEPS,
    FORECASTING_TRAIN_SET,
    FORECASTING_VAL_SET,
    FORECASTING_TEST_SET,
    FORECASTING_H5_TRAIN_SET_PATH,
    FORECASTING_H5_VAL_SET_PATH,
    FORECASTING_H5_TEST_SET_PATH,
    RESULT_SAVING_DIR_FOR_FORECASTING,
    check_tb_and_model_checkpoints_existence,
)


class TestGPT4TS(unittest.TestCase):
    logger.info("Running tests for a forecasting model GPT4TS...")

    # set the log and model saving path
    saving_path = os.path.join(RESULT_SAVING_DIR_FOR_FORECASTING, "GPT4TS")
    model_save_name = "saved_gpt4ts_model.pypots"

    # initialize an Adam optimizer
    optimizer = Adam(lr=0.001, weight_decay=1e-5)

    # initialize a GPT4TS model
    gpt4ts = GPT4TS(
        n_steps=DATA["n_steps"] - N_PRED_STEPS,
        n_features=DATA["n_features"],
        n_pred_steps=N_PRED_STEPS,
        n_pred_features=DATA["n_features"],
        term="short",
        patch_size=DATA["n_steps"],
        patch_stride=8,
        n_layers=2,
        train_gpt_mlp=True,
        d_ffn=128,
        dropout=0.1,
        batch_size=8,
        epochs=EPOCHS,
        saving_path=saving_path,
        optimizer=optimizer,
        device=DEVICE,
    )

    @pytest.mark.xdist_group(name="forecasting-gpt4ts")
    def test_0_fit(self):
        self.gpt4ts.fit(FORECASTING_TRAIN_SET, FORECASTING_VAL_SET)

    @pytest.mark.xdist_group(name="forecasting-gpt4ts")
    def test_1_forecasting(self):
        forecasting_X = self.gpt4ts.predict(FORECASTING_TEST_SET)["forecasting"]
        assert not np.isnan(
            forecasting_X
        ).any(), "Output has missing values in the forecasting results that should not be."
        test_MSE = calc_mse(
            forecasting_X,
            FORECASTING_TEST_SET["X_pred"],
            ~np.isnan(FORECASTING_TEST_SET["X_pred"]),
        )
        logger.info(f"GPT4TS test_MSE: {test_MSE}")

    @pytest.mark.xdist_group(name="forecasting-gpt4ts")
    def test_2_parameters(self):
        assert hasattr(self.gpt4ts, "model") and self.gpt4ts.model is not None

        assert hasattr(self.gpt4ts, "optimizer") and self.gpt4ts.optimizer is not None

        assert hasattr(self.gpt4ts, "best_loss")
        self.assertNotEqual(self.gpt4ts.best_loss, float("inf"))

        assert hasattr(self.gpt4ts, "best_model_dict") and self.gpt4ts.best_model_dict is not None

    @pytest.mark.xdist_group(name="forecasting-gpt4ts")
    def test_3_saving_path(self):
        # whether the root saving dir exists, which should be created by save_log_into_tb_file
        assert os.path.exists(self.saving_path), f"file {self.saving_path} does not exist"

        # check if the tensorboard file and model checkpoints exist
        check_tb_and_model_checkpoints_existence(self.gpt4ts)

        # save the trained model into file, and check if the path exists
        saved_model_path = os.path.join(self.saving_path, self.model_save_name)
        self.gpt4ts.save(saved_model_path)

        # test loading the saved model, not necessary, but need to test
        self.gpt4ts.load(saved_model_path)

    @pytest.mark.xdist_group(name="forecasting-gpt4ts")
    def test_4_lazy_loading(self):
        self.gpt4ts.fit(FORECASTING_H5_TRAIN_SET_PATH, FORECASTING_H5_VAL_SET_PATH)
        forecasting_results = self.gpt4ts.predict(FORECASTING_H5_TEST_SET_PATH)
        forecasting_X = forecasting_results["forecasting"]
        assert not np.isnan(
            forecasting_X
        ).any(), "Output has missing values in the forecasting results that should not be."

        test_MSE = calc_mse(
            forecasting_X,
            FORECASTING_TEST_SET["X_pred"],
            ~np.isnan(FORECASTING_TEST_SET["X_pred"]),
        )
        logger.info(f"Lazy-loading GPT4TS test_MSE: {test_MSE}")


if __name__ == "__main__":
    unittest.main()

tests/forecasting/llms/timellm.py (+1, -1)

@@ -33,7 +33,7 @@
 
 
 class TestTimeLLM(unittest.TestCase):
-    logger.info("Running tests for an forecasting model TimeLLM...")
+    logger.info("Running tests for a forecasting model TimeLLM...")
 
     # set the log and model saving path
     saving_path = os.path.join(RESULT_SAVING_DIR_FOR_FORECASTING, "TimeLLM")

tests/forecasting/tefn.py (+1, -1)

@@ -33,7 +33,7 @@
 
 
 class TestTEFN(unittest.TestCase):
-    logger.info("Running tests for an forecasting model TEFN...")
+    logger.info("Running tests for a forecasting model TEFN...")
 
     # set the log and model saving path
     saving_path = os.path.join(RESULT_SAVING_DIR_FOR_FORECASTING, "TEFN")

tests/forecasting/timemixer.py (+1, -1)

@@ -33,7 +33,7 @@
 
 
 class TestTimeMixer(unittest.TestCase):
-    logger.info("Running tests for an forecasting model TimeMixer...")
+    logger.info("Running tests for a forecasting model TimeMixer...")
 
     # set the log and model saving path
     saving_path = os.path.join(RESULT_SAVING_DIR_FOR_FORECASTING, "TimeMixer")

tests/forecasting/transformer.py (+1, -1)

@@ -33,7 +33,7 @@
 
 
 class TestTransformer(unittest.TestCase):
-    logger.info("Running tests for an forecasting model Transformer...")
+    logger.info("Running tests for a forecasting model Transformer...")
 
     # set the log and model saving path
    saving_path = os.path.join(RESULT_SAVING_DIR_FOR_FORECASTING, "Transformer")
