"""
The implementation of GPT4TS for the partially-observed time-series forecasting task.

"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union, Optional

import numpy as np
import torch
from torch.utils.data import DataLoader

from .core import _GPT4TS
from .data import DatasetForGPT4TS
from ..base import BaseNNForecaster
from ...data.checking import key_in_data_set
from ...optim.adam import Adam
from ...optim.base import Optimizer


class GPT4TS(BaseNNForecaster):
    """The PyTorch implementation of the GPT4TS forecasting model :cite:`zhou2023gpt4ts`.

    Parameters
    ----------
    n_steps :
        The number of time steps in the time-series data sample.

    n_features :
        The number of features in the time-series data sample.

    n_pred_steps :
        The number of steps in the forecasting time series.

    n_pred_features :
        The number of features in the forecasting time series.

    term :
        The forecasting term, which can be either 'long' or 'short'.

    patch_size :
        The size of the patch for the patching mechanism.

    patch_stride :
        The stride for the patching mechanism.

    n_layers :
        The number of hidden layers to use in GPT2.

    train_gpt_mlp :
        Whether to train the MLP in GPT2 during tuning.

    d_ffn :
        The hidden size of the feed-forward network.

    dropout :
        The dropout rate for the model.

    embed :
        The embedding method for the model.

    freq :
        The frequency of the time-series data.

    batch_size :
        The batch size for training and evaluating the model.

    epochs :
        The number of epochs for training the model.

    patience :
        The patience for the early-stopping mechanism. Given a positive integer, training will be
        stopped when the model has not performed better for that number of epochs.
        Leaving it as the default None disables early stopping.

    train_loss_func :
        The customized loss function designed by users for training the model.
        If not given, will use the default loss as claimed in the original paper.

    val_metric_func :
        The customized metric function designed by users for validating the model.
        If not given, will use the default MSE metric.

    optimizer :
        The optimizer for model training.
        If not given, will use a default Adam optimizer.

    num_workers :
        The number of subprocesses to use for data loading.
        `0` means data loading will be in the main process, i.e. there won't be subprocesses.

    device :
        The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
        If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
        then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
        If given a list of devices, e.g. ['cuda:0', 'cuda:1'] or [torch.device('cuda:0'), torch.device('cuda:1')],
        the model will be trained in parallel on the multiple devices (so far, parallel training is only supported
        on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.

    saving_path :
        The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
        training into a tensorboard file). Will not save if not given.

    model_saving_strategy :
        The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
        No model will be saved when it is set as None.
        The "best" strategy will only automatically save the best model after training finishes.
        The "better" strategy will automatically save the model during training whenever the model performs
        better than in previous epochs.
        The "all" strategy will save the model after every training epoch.

    verbose :
        Whether to print out the training logs during the training process.
    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        n_pred_steps: int,
        n_pred_features: int,
        term: str,
        patch_size: int,
        patch_stride: int,
        n_layers: int,
        train_gpt_mlp: bool,
        d_ffn: int,
        dropout: float,
        embed: str = "fixed",
        freq: str = "h",
        batch_size: int = 32,
        epochs: int = 100,
        patience: Optional[int] = None,
        train_loss_func: Optional[dict] = None,
        val_metric_func: Optional[dict] = None,
        optimizer: Optional[Optimizer] = Adam(),
        num_workers: int = 0,
        device: Optional[Union[str, torch.device, list]] = None,
        saving_path: Optional[str] = None,
        model_saving_strategy: Optional[str] = "best",
        verbose: bool = True,
    ):
        super().__init__(
            batch_size=batch_size,
            epochs=epochs,
            patience=patience,
            train_loss_func=train_loss_func,
            val_metric_func=val_metric_func,
            num_workers=num_workers,
            device=device,
            enable_amp=True,
            saving_path=saving_path,
            model_saving_strategy=model_saving_strategy,
            verbose=verbose,
        )

        self.n_steps = n_steps
        self.n_features = n_features
        self.n_pred_steps = n_pred_steps
        self.n_pred_features = n_pred_features
        self.term = term
        self.n_layers = n_layers
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.train_gpt_mlp = train_gpt_mlp
        self.d_ffn = d_ffn
        self.dropout = dropout
        self.embed = embed
        self.freq = freq

        # set up the model
        self.model = _GPT4TS(
            self.n_steps,
            self.n_features,
            self.n_pred_steps,
            self.n_pred_features,
            self.term,
            self.n_layers,
            self.patch_size,
            self.patch_stride,
            self.train_gpt_mlp,
            self.d_ffn,
            self.dropout,
            self.embed,
            self.freq,
        )
        self._print_model_size()
        self._send_model_to_given_device()

        # set up the optimizer
        self.optimizer = optimizer
        self.optimizer.init_optimizer(self.model.parameters())

    def _assemble_input_for_training(self, data: list) -> dict:
        (
            indices,
            X,
            missing_mask,
            X_pred,
            X_pred_missing_mask,
        ) = self._send_data_to_given_device(data)

        inputs = {
            "X": X,
            "missing_mask": missing_mask,
            "X_pred": X_pred,
            "X_pred_missing_mask": X_pred_missing_mask,
        }
        return inputs

    def _assemble_input_for_validating(self, data: list) -> dict:
        return self._assemble_input_for_training(data)

    def _assemble_input_for_testing(self, data: list) -> dict:
        (
            indices,
            X,
            missing_mask,
        ) = self._send_data_to_given_device(data)

        inputs = {
            "X": X,
            "missing_mask": missing_mask,
        }
        return inputs

    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "hdf5",
    ) -> None:
        # Step 1: wrap the input data with classes Dataset and DataLoader
        training_set = DatasetForGPT4TS(
            train_set,
            file_type=file_type,
        )
        training_loader = DataLoader(
            training_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
        val_loader = None
        if val_set is not None:
            if not key_in_data_set("X_pred", val_set):
                raise ValueError("val_set must contain 'X_pred' for model validation.")
            val_set = DatasetForGPT4TS(
                val_set,
                file_type=file_type,
            )
            val_loader = DataLoader(
                val_set,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers,
            )

        # Step 2: train the model and freeze it
        self._train_model(training_loader, val_loader)
        self.model.load_state_dict(self.best_model_dict)
        self.model.eval()  # set the model as eval status to freeze it.

        # Step 3: save the model if necessary
        self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")

    def predict(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> dict:
        """Make forecasting predictions with the trained model for the given test samples.

        Parameters
        ----------
        test_set :
            The dataset for model testing, should be a dictionary including the key 'X',
            or a path string locating a data file.
            If it is a dict, X should be array-like with shape [n_samples, sequence length (n_steps), n_features],
            which is the time-series data for testing and can contain missing values.
            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
            key-value pairs like a dict, and it has to include the key 'X'.

        file_type :
            The type of the given file if test_set is a path string.

        Returns
        -------
        result_dict : dict
            Prediction results in a Python Dictionary for the given samples.
            It should be a dictionary including a key named 'forecasting'.

        """

        # Step 1: wrap the input data with classes Dataset and DataLoader
        self.model.eval()  # set the model as eval status to freeze it.
        test_set = DatasetForGPT4TS(
            test_set,
            return_X_pred=False,
            file_type=file_type,
        )

        test_loader = DataLoader(
            test_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
        forecasting_collector = []

        # Step 2: process the data with the model
        with torch.no_grad():
            for idx, data in enumerate(test_loader):
                inputs = self._assemble_input_for_testing(data)
                results = self.model(inputs)
                forecasting_data = results["forecasting_data"]
                forecasting_collector.append(forecasting_data)

        # Step 3: output collection and return
        forecasting_data = torch.cat(forecasting_collector).cpu().detach().numpy()
        result_dict = {
            "forecasting": forecasting_data,  # [n_samples, n_pred_steps, n_pred_features]
        }
        return result_dict

    def forecast(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> np.ndarray:
        """Forecast the future of the input with the trained model.

        Parameters
        ----------
        test_set :
            The data samples for testing, should be array-like with shape [n_samples, sequence length (n_steps),
            n_features], or a path string locating a data file, e.g. an h5 file.

        file_type :
            The type of the given file if test_set is a path string.

        Returns
        -------
        array-like, shape [n_samples, n_pred_steps, n_pred_features],
            Forecasting results.
        """

        result_dict = self.predict(test_set, file_type=file_type)
        return result_dict["forecasting"]
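

if __name__ == "__main__":
    # A minimal usage sketch, not a definitive recipe: it assumes this module is executed as part of
    # its package (the relative imports above require that), that DatasetForGPT4TS accepts dataset
    # dicts with the keys "X" and "X_pred" used below (the same key checked in ``fit``), and that all
    # shapes and hyperparameters here are purely illustrative.
    rng = np.random.default_rng(2024)

    # 64 training samples: 96 observed steps with ~10% missing values, 24 steps to forecast, 7 features.
    train_X = rng.standard_normal((64, 96, 7)).astype(np.float32)
    train_X[rng.random(train_X.shape) < 0.1] = np.nan
    train_X_pred = rng.standard_normal((64, 24, 7)).astype(np.float32)
    test_X = rng.standard_normal((8, 96, 7)).astype(np.float32)

    model = GPT4TS(
        n_steps=96,
        n_features=7,
        n_pred_steps=24,
        n_pred_features=7,
        term="long",
        patch_size=16,
        patch_stride=8,
        n_layers=3,
        train_gpt_mlp=False,
        d_ffn=128,
        dropout=0.1,
        batch_size=16,
        epochs=1,
    )
    model.fit(train_set={"X": train_X, "X_pred": train_X_pred})

    # forecast() wraps predict() and returns the "forecasting" array directly.
    forecasting = model.forecast(test_set={"X": test_X})
    print(forecasting.shape)  # expected: (8, 24, 7)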