From aa68c8339e7928644a63321925bd341884a108e9 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 6 Mar 2025 11:05:17 +0800 Subject: [PATCH 1/5] feat: add FITS forecasting model; --- pypots/forecasting/__init__.py | 3 + pypots/forecasting/fits/__init__.py | 13 ++ pypots/forecasting/fits/core.py | 95 +++++++++ pypots/forecasting/fits/data.py | 28 +++ pypots/forecasting/fits/model.py | 320 ++++++++++++++++++++++++++++ 5 files changed, 459 insertions(+) create mode 100644 pypots/forecasting/fits/__init__.py create mode 100644 pypots/forecasting/fits/core.py create mode 100644 pypots/forecasting/fits/data.py create mode 100644 pypots/forecasting/fits/model.py diff --git a/pypots/forecasting/__init__.py b/pypots/forecasting/__init__.py index 8466faca..f7aea979 100644 --- a/pypots/forecasting/__init__.py +++ b/pypots/forecasting/__init__.py @@ -8,9 +8,12 @@ from .bttf import BTTF from .csdi import CSDI from .transformer import Transformer +from .fits import FITS + __all__ = [ "BTTF", "CSDI", "Transformer", + "FITS", ] diff --git a/pypots/forecasting/fits/__init__.py b/pypots/forecasting/fits/__init__.py new file mode 100644 index 00000000..bee750c3 --- /dev/null +++ b/pypots/forecasting/fits/__init__.py @@ -0,0 +1,13 @@ +""" +The package of the partially-observed time-series forecasting model FITS. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .model import FITS + +__all__ = [ + "FITS", +] diff --git a/pypots/forecasting/fits/core.py b/pypots/forecasting/fits/core.py new file mode 100644 index 00000000..f4860030 --- /dev/null +++ b/pypots/forecasting/fits/core.py @@ -0,0 +1,95 @@ +""" +The core wrapper assembles the submodules of FITS forecasting model +and takes over the forward progress of the algorithm. 
+ +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch +import torch.nn as nn + +from ...nn.functional import nonstationary_norm, nonstationary_denorm +from ...nn.functional.error import calc_mse +from ...nn.modules.fits import BackboneFITS +from ...nn.modules.saits import SaitsEmbedding + + +class _FITS(nn.Module): + def __init__( + self, + n_steps: int, + n_features: int, + n_pred_steps: int, + n_pred_features: int, + cut_freq: int, + individual: bool, + apply_nonstationary_norm: bool = False, + ): + super().__init__() + + self.n_pred_steps = n_pred_steps + self.n_pred_features = n_pred_features + self.apply_nonstationary_norm = apply_nonstationary_norm + + self.saits_embedding = SaitsEmbedding( + n_features * 2, + n_features, + with_pos=False, + ) + self.backbone = BackboneFITS( + n_steps, + n_features, + n_pred_steps, + cut_freq, + individual, + ) + + # for the forecasting task, project the backbone output from n_features to n_pred_features + self.output_projection = nn.Linear(n_features, n_pred_features) + + def forward(self, inputs: dict) -> dict: + X, missing_mask = inputs["X"], inputs["missing_mask"] + + if self.training: + X_pred, X_pred_missing_mask = inputs["X_pred"], inputs["X_pred_missing_mask"] + else: + batch_size = X.shape[0] + X_pred, X_pred_missing_mask = ( + torch.zeros(batch_size, self.n_pred_steps, self.n_pred_features), + torch.ones(batch_size, self.n_pred_steps, self.n_pred_features), + ) + + if self.apply_nonstationary_norm: + # Normalization from Non-stationary Transformer + X, means, stdev = nonstationary_norm(X, missing_mask) + + # WDU: the original FITS paper isn't proposed for imputation task. Hence the model doesn't take + # the missing mask into account, which means, in the process, the model doesn't know which part of + # the input data is missing, and this may hurt the model's forecasting performance. 
Therefore, I apply the + # SAITS embedding method to project the concatenation of features and masks into a hidden space, as well as + # the output layers to project back from the hidden space to the original space. + enc_out = self.saits_embedding(X, missing_mask) + + # FITS encoder processing + enc_out = self.backbone(enc_out) + if self.apply_nonstationary_norm: + # De-Normalization from Non-stationary Transformer + enc_out = nonstationary_denorm(enc_out, means, stdev) + + # project back the original data space + forecasting_result = self.output_projection(enc_out) + # the raw output has length = n_steps+n_pred_steps, we only need the last n_pred_steps + forecasting_result = forecasting_result[:, -self.n_pred_steps :] + + results = { + "forecasting_data": forecasting_result, + } + + # if in training mode, return results with losses + if self.training: + # `loss` is always the item for backward propagating to update the model + results["loss"] = calc_mse(X_pred, forecasting_result, X_pred_missing_mask) + + return results diff --git a/pypots/forecasting/fits/data.py b/pypots/forecasting/fits/data.py new file mode 100644 index 00000000..2ffe6b63 --- /dev/null +++ b/pypots/forecasting/fits/data.py @@ -0,0 +1,28 @@ +""" +Dataset class for the forecasting model FITS. 
+""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ...data.dataset import BaseDataset + + +class DatasetForFITS(BaseDataset): + """Dataset for FITS forecasting model.""" + + def __init__( + self, + data: Union[dict, str], + return_X_pred=True, + file_type: str = "hdf5", + ): + super().__init__( + data=data, + return_X_ori=False, + return_X_pred=return_X_pred, + return_y=False, + file_type=file_type, + ) diff --git a/pypots/forecasting/fits/model.py b/pypots/forecasting/fits/model.py new file mode 100644 index 00000000..433b6789 --- /dev/null +++ b/pypots/forecasting/fits/model.py @@ -0,0 +1,320 @@ +""" +The implementation of Transformer for the partially-observed time-series forecasting task. + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +try: + import nni +except ImportError: + pass + +from .core import _FITS +from .data import DatasetForFITS +from ..base import BaseNNForecaster +from ...data.checking import key_in_data_set +from ...optim.adam import Adam +from ...optim.base import Optimizer + + +class FITS(BaseNNForecaster): + """The PyTorch implementation of the FITS forecasting model :cite:`xu2024fits`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + n_pred_steps : + The number of steps in the forecasting time series. + + n_pred_features : + The number of features in the forecasting time series. + + cut_freq : + The cut-off frequency for the Fourier transformation. + + individual : + Whether to use individual Fourier transformation for each feature. + + apply_nonstationary_norm : + Whether to apply the non-stationary normalization to the input data. + + batch_size : + The batch size for training and evaluating the model. 
+ + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. 
+ The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. + + verbose : + Whether to print out the training logs during the training process. + """ + + def __init__( + self, + n_steps: int, + n_features: int, + n_pred_steps: int, + n_pred_features: int, + cut_freq: int, + individual: bool = False, + apply_nonstationary_norm: bool = False, + batch_size: int = 32, + epochs: int = 100, + patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: Optional[str] = None, + model_saving_strategy: Optional[str] = "best", + verbose: bool = True, + ): + super().__init__( + batch_size=batch_size, + epochs=epochs, + patience=patience, + train_loss_func=train_loss_func, + val_metric_func=val_metric_func, + num_workers=num_workers, + device=device, + saving_path=saving_path, + model_saving_strategy=model_saving_strategy, + verbose=verbose, + ) + + self.n_steps = n_steps + self.n_features = n_features + self.n_pred_steps = n_pred_steps + self.n_pred_features = n_pred_features + self.cut_freq = cut_freq + self.individual = individual + self.apply_nonstationary_norm = apply_nonstationary_norm + + # set up the model + self.model = _FITS( + self.n_steps, + self.n_features, + self.n_pred_steps, + self.n_pred_features, + self.cut_freq, + self.individual, + self.apply_nonstationary_norm, + ) + self._print_model_size() + self._send_model_to_given_device() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + 
+ missing_mask, + X_pred, + X_pred_missing_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_pred": X_pred, + "X_pred_missing_mask": X_pred_missing_mask, + } + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "hdf5", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForFITS( + train_set, + file_type=file_type, + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not key_in_data_set("X_pred", val_set): + raise ValueError("val_set must contain 'X_pred' for model validation.") + val_set = DatasetForFITS( + val_set, + file_type=file_type, + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> dict: + """ + + Parameters + ---------- + test_set : dict or str + The dataset for model testing, should be a dictionary including the key 'X', + or a path string locating a data file. 
+ If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features], + which is the time-series data to forecast from and can contain missing values. 'X_pred' is not + required here because the dataset is built with return_X_pred=False. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include the key 'X'. + + file_type : + The type of the given file if test_set is a path string. + + Returns + ------- + result_dict: dict + Prediction results in a Python Dictionary for the given samples. + It should be a dictionary including a key named 'forecasting'. + + """ + + # Step 1: wrap the input data with classes Dataset and DataLoader + self.model.eval() # set the model as eval status to freeze it. + test_set = DatasetForFITS( + test_set, + return_X_pred=False, + file_type=file_type, + ) + + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + forecasting_collector = [] + + # Step 2: process the data with the model + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model(inputs) + forecasting_data = results["forecasting_data"] + forecasting_collector.append(forecasting_data) + + # Step 3: output collection and return + forecasting_data = torch.cat(forecasting_collector).cpu().detach().numpy() + result_dict = { + "forecasting": forecasting_data, # [bz, n_pred_steps, n_pred_features] + } + return result_dict + + def forecast( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> np.ndarray: + """Forecast the future of the input with the trained model. + + Parameters + ---------- + test_set : + The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps), + n_features], or a path string locating a data file, e.g. h5 file. 
+ + file_type : + The type of the given file if test_set is a path string. + + Returns + ------- + array-like, shape [n_samples, n_pred_steps, n_pred_features], + Forecasting results. + """ + + result_dict = self.predict(test_set, file_type=file_type) + return result_dict["forecasting"] From 050057d7e9878084711bd8753282d89f9e215fe8 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 6 Mar 2025 11:18:32 +0800 Subject: [PATCH 2/5] refactor: remove not used part in Transformer forecasting model core; --- pypots/forecasting/transformer/core.py | 9 ++------- pypots/forecasting/transformer/model.py | 2 -- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/pypots/forecasting/transformer/core.py b/pypots/forecasting/transformer/core.py index 83d6f9c7..cda9e2fe 100644 --- a/pypots/forecasting/transformer/core.py +++ b/pypots/forecasting/transformer/core.py @@ -1,5 +1,5 @@ """ -The core wrapper assembles the submodules of Transformer imputation model +The core wrapper assembles the submodules of Transformer forecasting model and takes over the forward progress of the algorithm. 
""" @@ -11,7 +11,7 @@ import torch.nn as nn from ...nn.functional.error import calc_mse -from ...nn.modules.saits import SaitsLoss, SaitsEmbedding +from ...nn.modules.saits import SaitsEmbedding from ...nn.modules.transformer import TransformerEncoder, TransformerDecoder @@ -31,8 +31,6 @@ def __init__( d_ffn: int, dropout: float, attn_dropout: float, - ORT_weight: float = 1, - MIT_weight: float = 1, ): super().__init__() @@ -78,9 +76,6 @@ def __init__( ) self.output_projection = nn.Linear(d_model, n_pred_features) - # apply SAITS loss function to Transformer on the imputation task - self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] diff --git a/pypots/forecasting/transformer/model.py b/pypots/forecasting/transformer/model.py index 2b3a07fe..c658d090 100644 --- a/pypots/forecasting/transformer/model.py +++ b/pypots/forecasting/transformer/model.py @@ -185,8 +185,6 @@ def __init__( d_ffn, dropout, attn_dropout, - 1, - 1, ) self._print_model_size() self._send_model_to_given_device() From 6b9ea1e675f23ac50dcf407b7aa21df90b9fe6d5 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 6 Mar 2025 11:27:04 +0800 Subject: [PATCH 3/5] docs: update docs to add FITS forecasting model; --- README.md | 2 +- README_zh.md | 2 +- docs/index.rst | 2 +- docs/pypots.forecasting.rst | 9 +++++++++ pypots/forecasting/fits/model.py | 2 +- pypots/imputation/fits/model.py | 5 ++++- 6 files changed, 17 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 65465110..53643315 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,7 @@ The paper references and links are all listed at the bottom of this file. 
| LLM&TSFM | Time-Series.AI [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | Join waitlist | | LLM | Time-LLM[^45] | ✅ | | | | | `2024 - ICLR` | | Neural Net | TEFN🧑‍🔧[^39] | ✅ | | | | | `2024 - arXiv` | -| Neural Net | FITS🧑‍🔧[^41] | ✅ | | | | | `2024 - ICLR` | +| Neural Net | FITS🧑‍🔧[^41] | ✅ | ✅ | | | | `2024 - ICLR` | | Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` | | Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` | | Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` | diff --git a/README_zh.md b/README_zh.md index dcbe8d92..f57321d3 100644 --- a/README_zh.md +++ b/README_zh.md @@ -107,7 +107,7 @@ PyPOTS当前支持多变量POTS数据的插补, 预测, 分类, 聚类以及异 | LLM&TSFM | Time-Series.AI [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | Join waitlist | | LLM | Time-LLM[^45] | ✅ | | | | | `2024 - ICLR` | | Neural Net | TEFN🧑‍🔧[^39] | ✅ | | | | | `2024 - arXiv` | -| Neural Net | FITS🧑‍🔧[^41] | ✅ | | | | | `2024 - ICLR` | +| Neural Net | FITS🧑‍🔧[^41] | ✅ | ✅ | | | | `2024 - ICLR` | | Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` | | Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` | | Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` | diff --git a/docs/index.rst b/docs/index.rst index 47c02169..b3509807 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -135,7 +135,7 @@ The paper references are all listed at the bottom of this readme file. 
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ | Neural Net | TEFN🧑‍🔧 :cite:`zhan2024tefn` | ✅ | | | | | ``2024 - arXiv`` | +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ -| Neural Net | FITS🧑‍🔧 :cite:`xu2024fits` | ✅ | | | | | ``2024 - ICLR`` | +| Neural Net | FITS🧑‍🔧 :cite:`xu2024fits` | ✅ | ✅ | | | | ``2024 - ICLR`` | +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ | Neural Net | TimeMixer :cite:`wang2024timemixer` | ✅ | | | | | ``2024 - ICLR`` | +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ diff --git a/docs/pypots.forecasting.rst b/docs/pypots.forecasting.rst index bebf6120..eb406f79 100644 --- a/docs/pypots.forecasting.rst +++ b/docs/pypots.forecasting.rst @@ -10,6 +10,15 @@ pypots.forecasting.transformer :show-inheritance: :inherited-members: +pypots.forecasting.fits +------------------------------ + +.. automodule:: pypots.forecasting.fits + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + pypots.forecasting.csdi ------------------------------ diff --git a/pypots/forecasting/fits/model.py b/pypots/forecasting/fits/model.py index 433b6789..886dc034 100644 --- a/pypots/forecasting/fits/model.py +++ b/pypots/forecasting/fits/model.py @@ -1,5 +1,5 @@ """ -The implementation of Transformer for the partially-observed time-series forecasting task. +The implementation of FITS for the partially-observed time-series forecasting task. 
""" diff --git a/pypots/imputation/fits/model.py b/pypots/imputation/fits/model.py index b5c9dc7a..10a20635 100644 --- a/pypots/imputation/fits/model.py +++ b/pypots/imputation/fits/model.py @@ -22,7 +22,7 @@ class FITS(BaseNNImputer): - """The PyTorch implementation of the FITS model. + """The PyTorch implementation of the FITS imputation model. FITS is originally proposed by Xu et al. in :cite:`xu2024fits`. Parameters @@ -45,6 +45,9 @@ class FITS(BaseNNImputer): MIT_weight : The weight for the MIT loss, the same as SAITS. + apply_nonstationary_norm : + Whether to apply the non-stationary normalization to the input data. + batch_size : The batch size for training and evaluating the model. From d5fe85cbaa85e1e49060b36de09354b373bd8cc4 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 6 Mar 2025 11:50:44 +0800 Subject: [PATCH 4/5] test: add testing cases for FITS forecasting model; --- tests/forecasting/fits.py | 120 ++++++++++++++++++++++++++++++++++++ tests/global_test_config.py | 2 +- 2 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 tests/forecasting/fits.py diff --git a/tests/forecasting/fits.py b/tests/forecasting/fits.py new file mode 100644 index 00000000..bf383a5d --- /dev/null +++ b/tests/forecasting/fits.py @@ -0,0 +1,120 @@ +""" +Test cases for FITS forecasting model. 
+""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.forecasting import FITS +from pypots.nn.functional import calc_mse +from pypots.optim import Adam +from pypots.utils.logging import logger +from tests.global_test_config import ( + DATA, + EPOCHS, + DEVICE, + N_PRED_STEPS, + FORECASTING_TRAIN_SET, + FORECASTING_VAL_SET, + FORECASTING_TEST_SET, + FORECASTING_H5_TRAIN_SET_PATH, + FORECASTING_H5_VAL_SET_PATH, + FORECASTING_H5_TEST_SET_PATH, + RESULT_SAVING_DIR_FOR_FORECASTING, + check_tb_and_model_checkpoints_existence, +) + + +class TestFITS(unittest.TestCase): + logger.info("Running tests for an forecasting model FITS...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_FORECASTING, "FITS") + model_save_name = "saved_fits_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a FITS model + fits = FITS( + n_steps=DATA["n_steps"] - N_PRED_STEPS, + n_features=DATA["n_features"], + n_pred_steps=N_PRED_STEPS, + n_pred_features=DATA["n_features"], + individual=False, + cut_freq=3, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="forecasting-fits") + def test_0_fit(self): + self.fits.fit(FORECASTING_TRAIN_SET, FORECASTING_VAL_SET) + + @pytest.mark.xdist_group(name="forecasting-fits") + def test_1_forecasting(self): + forecasting_X = self.fits.predict(FORECASTING_TEST_SET)["forecasting"] + assert not np.isnan( + forecasting_X + ).any(), "Output has missing values in the forecasting results that should not be." 
+ test_MSE = calc_mse( + forecasting_X, + FORECASTING_TEST_SET["X_pred"], + ~np.isnan(FORECASTING_TEST_SET["X_pred"]), + ) + logger.info(f"FITS test_MSE: {test_MSE}") + + @pytest.mark.xdist_group(name="forecasting-fits") + def test_2_parameters(self): + assert hasattr(self.fits, "model") and self.fits.model is not None + + assert hasattr(self.fits, "optimizer") and self.fits.optimizer is not None + + assert hasattr(self.fits, "best_loss") + self.assertNotEqual(self.fits.best_loss, float("inf")) + + assert hasattr(self.fits, "best_model_dict") and self.fits.best_model_dict is not None + + @pytest.mark.xdist_group(name="forecasting-fits") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists(self.saving_path), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.fits) + + # save the trained model into file, and check if the path exists + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.fits.save(saved_model_path) + + # test loading the saved model, not necessary, but need to test + self.fits.load(saved_model_path) + + @pytest.mark.xdist_group(name="forecasting-fits") + def test_4_lazy_loading(self): + self.fits.fit(FORECASTING_H5_TRAIN_SET_PATH, FORECASTING_H5_VAL_SET_PATH) + forecasting_results = self.fits.predict(FORECASTING_H5_TEST_SET_PATH) + forecasting_X = forecasting_results["forecasting"] + assert not np.isnan( + forecasting_X + ).any(), "Output has missing values in the forecasting results that should not be." 
+ + test_MSE = calc_mse( + forecasting_X, + FORECASTING_TEST_SET["X_pred"], + ~np.isnan(FORECASTING_TEST_SET["X_pred"]), + ) + logger.info(f"Lazy-loading FITS test_MSE: {test_MSE}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/global_test_config.py b/tests/global_test_config.py index 08a7dddc..b6151180 100644 --- a/tests/global_test_config.py +++ b/tests/global_test_config.py @@ -20,7 +20,7 @@ # set the number of epochs for all model training EPOCHS = 2 # set the number of prediction steps for forecasting models -N_PRED_STEPS = 1 +N_PRED_STEPS = 2 # tensorboard and model files saving directory RESULT_SAVING_DIR = "testing_results" MODEL_SAVING_DIR = f"{RESULT_SAVING_DIR}/models" From d2aba1b0a48fd524c96f3d95ae1dd3754b146029 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 6 Mar 2025 11:55:03 +0800 Subject: [PATCH 5/5] refactor: remove imported but not used nni; --- pypots/forecasting/fits/model.py | 5 ----- pypots/forecasting/transformer/model.py | 5 ----- 2 files changed, 10 deletions(-) diff --git a/pypots/forecasting/fits/model.py b/pypots/forecasting/fits/model.py index 886dc034..a3a49759 100644 --- a/pypots/forecasting/fits/model.py +++ b/pypots/forecasting/fits/model.py @@ -12,11 +12,6 @@ import torch from torch.utils.data import DataLoader -try: - import nni -except ImportError: - pass - from .core import _FITS from .data import DatasetForFITS from ..base import BaseNNForecaster diff --git a/pypots/forecasting/transformer/model.py b/pypots/forecasting/transformer/model.py index c658d090..9139ce12 100644 --- a/pypots/forecasting/transformer/model.py +++ b/pypots/forecasting/transformer/model.py @@ -12,11 +12,6 @@ import torch from torch.utils.data import DataLoader -try: - import nni -except ImportError: - pass - from .core import _Transformer from .data import DatasetForTransformer from ..base import BaseNNForecaster