diff --git a/README.md b/README.md
index 65465110..53643315 100644
--- a/README.md
+++ b/README.md
@@ -122,7 +122,7 @@ The paper references and links are all listed at the bottom of this file.
| LLM&TSFM | Time-Series.AI [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | Join waitlist |
| LLM | Time-LLM[^45] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | TEFN🧑🔧[^39] | ✅ | | | | | `2024 - arXiv` |
-| Neural Net | FITS🧑🔧[^41] | ✅ | | | | | `2024 - ICLR` |
+| Neural Net | FITS🧑🔧[^41] | ✅ | ✅ | | | | `2024 - ICLR` |
| Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | iTransformer🧑🔧[^24] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` |
diff --git a/README_zh.md b/README_zh.md
index dcbe8d92..f57321d3 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -107,7 +107,7 @@ PyPOTS当前支持多变量POTS数据的插补, 预测, 分类, 聚类以及异
| LLM&TSFM | Time-Series.AI [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | Join waitlist |
| LLM | Time-LLM[^45] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | TEFN🧑🔧[^39] | ✅ | | | | | `2024 - arXiv` |
-| Neural Net | FITS🧑🔧[^41] | ✅ | | | | | `2024 - ICLR` |
+| Neural Net | FITS🧑🔧[^41] | ✅ | ✅ | | | | `2024 - ICLR` |
| Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | iTransformer🧑🔧[^24] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` |
diff --git a/docs/index.rst b/docs/index.rst
index 47c02169..b3509807 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -135,7 +135,7 @@ The paper references are all listed at the bottom of this readme file.
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | TEFN🧑🔧 :cite:`zhan2024tefn` | ✅ | | | | | ``2024 - arXiv`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
-| Neural Net | FITS🧑🔧 :cite:`xu2024fits` | ✅ | | | | | ``2024 - ICLR`` |
+| Neural Net | FITS🧑🔧 :cite:`xu2024fits` | ✅ | ✅ | | | | ``2024 - ICLR`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | TimeMixer :cite:`wang2024timemixer` | ✅ | | | | | ``2024 - ICLR`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
diff --git a/docs/pypots.forecasting.rst b/docs/pypots.forecasting.rst
index bebf6120..eb406f79 100644
--- a/docs/pypots.forecasting.rst
+++ b/docs/pypots.forecasting.rst
@@ -10,6 +10,15 @@ pypots.forecasting.transformer
:show-inheritance:
:inherited-members:
+pypots.forecasting.fits
+------------------------------
+
+.. automodule:: pypots.forecasting.fits
+ :members:
+ :undoc-members:
+ :show-inheritance:
+ :inherited-members:
+
pypots.forecasting.csdi
------------------------------
diff --git a/pypots/forecasting/__init__.py b/pypots/forecasting/__init__.py
index 8466faca..f7aea979 100644
--- a/pypots/forecasting/__init__.py
+++ b/pypots/forecasting/__init__.py
@@ -8,9 +8,12 @@
from .bttf import BTTF
from .csdi import CSDI
from .transformer import Transformer
+from .fits import FITS
+
__all__ = [
"BTTF",
"CSDI",
"Transformer",
+ "FITS",
]
diff --git a/pypots/forecasting/fits/__init__.py b/pypots/forecasting/fits/__init__.py
new file mode 100644
index 00000000..bee750c3
--- /dev/null
+++ b/pypots/forecasting/fits/__init__.py
@@ -0,0 +1,13 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+from .model import FITS
+
+__all__ = [
+ "FITS",
+]
diff --git a/pypots/forecasting/fits/core.py b/pypots/forecasting/fits/core.py
new file mode 100644
index 00000000..f4860030
--- /dev/null
+++ b/pypots/forecasting/fits/core.py
@@ -0,0 +1,95 @@
+"""
+The core wrapper assembles the submodules of the FITS forecasting model
+and takes over the forward progress of the algorithm.
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+import torch
+import torch.nn as nn
+
+from ...nn.functional import nonstationary_norm, nonstationary_denorm
+from ...nn.functional.error import calc_mse
+from ...nn.modules.fits import BackboneFITS
+from ...nn.modules.saits import SaitsEmbedding
+
+
+class _FITS(nn.Module):
+ def __init__(
+ self,
+ n_steps: int,
+ n_features: int,
+ n_pred_steps: int,
+ n_pred_features: int,
+ cut_freq: int,
+ individual: bool,
+ apply_nonstationary_norm: bool = False,
+ ):
+ super().__init__()
+
+ self.n_pred_steps = n_pred_steps
+ self.n_pred_features = n_pred_features
+ self.apply_nonstationary_norm = apply_nonstationary_norm
+
+ self.saits_embedding = SaitsEmbedding(
+ n_features * 2,
+ n_features,
+ with_pos=False,
+ )
+ self.backbone = BackboneFITS(
+ n_steps,
+ n_features,
+ n_pred_steps,
+ cut_freq,
+ individual,
+ )
+
+ # for the forecasting task, the output projection maps from n_features to n_pred_features
+ self.output_projection = nn.Linear(n_features, n_pred_features)
+
+ def forward(self, inputs: dict) -> dict:
+ X, missing_mask = inputs["X"], inputs["missing_mask"]
+
+ if self.training:
+ X_pred, X_pred_missing_mask = inputs["X_pred"], inputs["X_pred_missing_mask"]
+ else:
+ batch_size = X.shape[0]
+ X_pred, X_pred_missing_mask = (
+ torch.zeros(batch_size, self.n_pred_steps, self.n_pred_features, device=X.device),
+ torch.ones(batch_size, self.n_pred_steps, self.n_pred_features, device=X.device),
+ )
+
+ if self.apply_nonstationary_norm:
+ # Normalization from Non-stationary Transformer
+ X, means, stdev = nonstationary_norm(X, missing_mask)
+
+ # WDU: the original FITS model isn't designed for partially-observed time series. Hence it doesn't take
+ # the missing mask into account, which means, during processing, the model doesn't know which parts of
+ # the input data are missing, and this may hurt its forecasting performance. Therefore, I apply the
+ # SAITS embedding method to project the concatenation of features and masks into a hidden space, as well as
+ # the output layer to project back from the hidden space to the original space.
+ enc_out = self.saits_embedding(X, missing_mask)
+
+ # FITS encoder processing
+ enc_out = self.backbone(enc_out)
+ if self.apply_nonstationary_norm:
+ # De-Normalization from Non-stationary Transformer
+ enc_out = nonstationary_denorm(enc_out, means, stdev)
+
+ # project back to the original data space
+ forecasting_result = self.output_projection(enc_out)
+ # the raw output has length = n_steps+n_pred_steps, we only need the last n_pred_steps
+ forecasting_result = forecasting_result[:, -self.n_pred_steps :]
+
+ results = {
+ "forecasting_data": forecasting_result,
+ }
+
+ # if in training mode, return results with losses
+ if self.training:
+ # `loss` is always the item for backward propagating to update the model
+ results["loss"] = calc_mse(X_pred, forecasting_result, X_pred_missing_mask)
+
+ return results
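The `n_features * 2` input dimension above suggests `SaitsEmbedding` concatenates the data with its missing mask before a linear projection, which is how the backbone gets told where values are missing. A minimal sketch of that idea under this assumption; `ToyMaskedEmbedding` is illustrative and not part of the PyPOTS API:

```python
import torch
import torch.nn as nn


class ToyMaskedEmbedding(nn.Module):
    """Illustrative stand-in for the SAITS embedding trick: concatenate the
    observed values with their missing mask, then project the pair back into
    the feature space so downstream modules can see the missingness pattern."""

    def __init__(self, n_features: int):
        super().__init__()
        self.proj = nn.Linear(n_features * 2, n_features)

    def forward(self, X: torch.Tensor, missing_mask: torch.Tensor) -> torch.Tensor:
        # [batch, n_steps, 2 * n_features] -> [batch, n_steps, n_features]
        return self.proj(torch.cat([X, missing_mask], dim=-1))


X = torch.randn(8, 24, 5)                    # 8 samples, 24 steps, 5 features
mask = torch.randint(0, 2, X.shape).float()  # 1 = observed, 0 = missing
print(ToyMaskedEmbedding(5)(X, mask).shape)  # torch.Size([8, 24, 5])
```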
diff --git a/pypots/forecasting/fits/data.py b/pypots/forecasting/fits/data.py
new file mode 100644
index 00000000..2ffe6b63
--- /dev/null
+++ b/pypots/forecasting/fits/data.py
@@ -0,0 +1,28 @@
+"""
+Dataset class for the forecasting model FITS.
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from typing import Union
+
+from ...data.dataset import BaseDataset
+
+
+class DatasetForFITS(BaseDataset):
+ """Dataset for FITS forecasting model."""
+
+ def __init__(
+ self,
+ data: Union[dict, str],
+ return_X_pred=True,
+ file_type: str = "hdf5",
+ ):
+ super().__init__(
+ data=data,
+ return_X_ori=False,
+ return_X_pred=return_X_pred,
+ return_y=False,
+ file_type=file_type,
+ )
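For context, a sketch of the in-memory layout this dataset class consumes, inferred from `return_X_pred=True` here and from the `X`/`X_pred` keys assembled in `model.py` below; the shapes and values are purely illustrative:

```python
import numpy as np

n_samples, n_steps, n_pred_steps, n_features = 100, 22, 2, 5

# "X" holds the observed history (NaNs mark missing values);
# "X_pred" holds the forecasting horizon, required for training and validation
train_set = {
    "X": np.random.randn(n_samples, n_steps, n_features),
    "X_pred": np.random.randn(n_samples, n_pred_steps, n_features),
}

# a test set only needs "X"; predict() builds DatasetForFITS with return_X_pred=False
test_set = {"X": np.random.randn(n_samples, n_steps, n_features)}
```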
diff --git a/pypots/forecasting/fits/model.py b/pypots/forecasting/fits/model.py
new file mode 100644
index 00000000..a3a49759
--- /dev/null
+++ b/pypots/forecasting/fits/model.py
@@ -0,0 +1,315 @@
+"""
+The implementation of FITS for the partially-observed time-series forecasting task.
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from typing import Union, Optional
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from .core import _FITS
+from .data import DatasetForFITS
+from ..base import BaseNNForecaster
+from ...data.checking import key_in_data_set
+from ...optim.adam import Adam
+from ...optim.base import Optimizer
+
+
+class FITS(BaseNNForecaster):
+ """The PyTorch implementation of the FITS forecasting model :cite:`xu2024fits`.
+
+ Parameters
+ ----------
+ n_steps :
+ The number of time steps in the time-series data sample.
+
+ n_features :
+ The number of features in the time-series data sample.
+
+ n_pred_steps :
+ The number of steps in the forecasting time series.
+
+ n_pred_features :
+ The number of features in the forecasting time series.
+
+ cut_freq :
+ The cut-off frequency for the low-pass filter applied in the frequency domain.
+
+ individual :
+ Whether to use an individual frequency-domain layer for each feature.
+
+ apply_nonstationary_norm :
+ Whether to apply the non-stationary normalization to the input data.
+
+ batch_size :
+ The batch size for training and evaluating the model.
+
+ epochs :
+ The number of epochs for training the model.
+
+ patience :
+ The patience for the early-stopping mechanism. Given a positive integer, the training process will be
+ stopped when the model does not perform better after that number of epochs.
+ Leaving it as the default None will disable the early-stopping.
+
+ train_loss_func:
+ The customized loss function designed by users for training the model.
+ If not given, will use the default loss defined in the original paper.
+
+ val_metric_func:
+ The customized metric function designed by users for validating the model.
+ If not given, will use the default MSE metric.
+
+ optimizer :
+ The optimizer for model training.
+ If not given, will use a default Adam optimizer.
+
+ num_workers :
+ The number of subprocesses to use for data loading.
+ `0` means data loading will be in the main process, i.e. there won't be subprocesses.
+
+ device :
+ The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
+ If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
+ then CPUs, considering CUDA and CPU are so far the main devices used to train ML models.
+ If given a list of devices, e.g. ['cuda:0', 'cuda:1'] or [torch.device('cuda:0'), torch.device('cuda:1')], the
+ model will be trained in parallel on the multiple devices (so far only parallel training on CUDA devices is supported).
+ Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
+
+ saving_path :
+ The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
+ training saved into a tensorboard file). Nothing will be saved if it is not given.
+
+ model_saving_strategy :
+ The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
+ No model will be saved when it is set as None.
+ The "best" strategy will only automatically save the best model after the training finished.
+ The "better" strategy will automatically save the model during training whenever the model performs
+ better than in previous epochs.
+ The "all" strategy will save every model after each epoch training.
+
+ verbose :
+ Whether to print out the training logs during the training process.
+ """
+
+ def __init__(
+ self,
+ n_steps: int,
+ n_features: int,
+ n_pred_steps: int,
+ n_pred_features: int,
+ cut_freq: int,
+ individual: bool = False,
+ apply_nonstationary_norm: bool = False,
+ batch_size: int = 32,
+ epochs: int = 100,
+ patience: Optional[int] = None,
+ train_loss_func: Optional[dict] = None,
+ val_metric_func: Optional[dict] = None,
+ optimizer: Optional[Optimizer] = Adam(),
+ num_workers: int = 0,
+ device: Optional[Union[str, torch.device, list]] = None,
+ saving_path: Optional[str] = None,
+ model_saving_strategy: Optional[str] = "best",
+ verbose: bool = True,
+ ):
+ super().__init__(
+ batch_size=batch_size,
+ epochs=epochs,
+ patience=patience,
+ train_loss_func=train_loss_func,
+ val_metric_func=val_metric_func,
+ num_workers=num_workers,
+ device=device,
+ saving_path=saving_path,
+ model_saving_strategy=model_saving_strategy,
+ verbose=verbose,
+ )
+
+ self.n_steps = n_steps
+ self.n_features = n_features
+ self.n_pred_steps = n_pred_steps
+ self.n_pred_features = n_pred_features
+ self.cut_freq = cut_freq
+ self.individual = individual
+ self.apply_nonstationary_norm = apply_nonstationary_norm
+
+ # set up the model
+ self.model = _FITS(
+ self.n_steps,
+ self.n_features,
+ self.n_pred_steps,
+ self.n_pred_features,
+ self.cut_freq,
+ self.individual,
+ self.apply_nonstationary_norm,
+ )
+ self._print_model_size()
+ self._send_model_to_given_device()
+
+ # set up the optimizer
+ self.optimizer = optimizer
+ self.optimizer.init_optimizer(self.model.parameters())
+
+ def _assemble_input_for_training(self, data: list) -> dict:
+ (
+ indices,
+ X,
+ missing_mask,
+ X_pred,
+ X_pred_missing_mask,
+ ) = self._send_data_to_given_device(data)
+
+ inputs = {
+ "X": X,
+ "missing_mask": missing_mask,
+ "X_pred": X_pred,
+ "X_pred_missing_mask": X_pred_missing_mask,
+ }
+ return inputs
+
+ def _assemble_input_for_validating(self, data: list) -> dict:
+ return self._assemble_input_for_training(data)
+
+ def _assemble_input_for_testing(self, data: list) -> dict:
+ (
+ indices,
+ X,
+ missing_mask,
+ ) = self._send_data_to_given_device(data)
+
+ inputs = {
+ "X": X,
+ "missing_mask": missing_mask,
+ }
+ return inputs
+
+ def fit(
+ self,
+ train_set: Union[dict, str],
+ val_set: Optional[Union[dict, str]] = None,
+ file_type: str = "hdf5",
+ ) -> None:
+ # Step 1: wrap the input data with classes Dataset and DataLoader
+ training_set = DatasetForFITS(
+ train_set,
+ file_type=file_type,
+ )
+ training_loader = DataLoader(
+ training_set,
+ batch_size=self.batch_size,
+ shuffle=True,
+ num_workers=self.num_workers,
+ )
+ val_loader = None
+ if val_set is not None:
+ if not key_in_data_set("X_pred", val_set):
+ raise ValueError("val_set must contain 'X_pred' for model validation.")
+ val_set = DatasetForFITS(
+ val_set,
+ file_type=file_type,
+ )
+ val_loader = DataLoader(
+ val_set,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ )
+
+ # Step 2: train the model and freeze it
+ self._train_model(training_loader, val_loader)
+ self.model.load_state_dict(self.best_model_dict)
+ self.model.eval() # set the model as eval status to freeze it.
+
+ # Step 3: save the model if necessary
+ self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best")
+
+ def predict(
+ self,
+ test_set: Union[dict, str],
+ file_type: str = "hdf5",
+ ) -> dict:
+ """
+
+ Parameters
+ ----------
+ test_set : dict or str
+ The dataset for model testing, should be a dictionary including the key 'X',
+ or a path string locating a data file.
+ If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
+ which is the time-series data for testing and can contain missing values.
+ Note that forecasting models do not need labels, so no 'y' key is required here.
+ If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
+ key-value pairs like a dict, and it has to include the key 'X'.
+
+ file_type :
+ The type of the given file if test_set is a path string.
+
+ Returns
+ -------
+ result_dict: dict
+ Prediction results in a Python Dictionary for the given samples.
+ It should be a dictionary including a key named 'forecasting'.
+
+ """
+
+ # Step 1: wrap the input data with classes Dataset and DataLoader
+ self.model.eval() # set the model as eval status to freeze it.
+ test_set = DatasetForFITS(
+ test_set,
+ return_X_pred=False,
+ file_type=file_type,
+ )
+
+ test_loader = DataLoader(
+ test_set,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ )
+ forecasting_collector = []
+
+ # Step 2: process the data with the model
+ with torch.no_grad():
+ for idx, data in enumerate(test_loader):
+ inputs = self._assemble_input_for_testing(data)
+ results = self.model(inputs)
+ forecasting_data = results["forecasting_data"]
+ forecasting_collector.append(forecasting_data)
+
+ # Step 3: output collection and return
+ forecasting_data = torch.cat(forecasting_collector).cpu().detach().numpy()
+ result_dict = {
+ "forecasting": forecasting_data, # [bz, n_pred_steps, n_features]
+ }
+ return result_dict
+
+ def forecast(
+ self,
+ test_set: Union[dict, str],
+ file_type: str = "hdf5",
+ ) -> np.ndarray:
+ """Forecast the future of the input with the trained model.
+
+ Parameters
+ ----------
+ test_set :
+ The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
+ n_features], or a path string locating a data file, e.g. h5 file.
+
+ file_type :
+ The type of the given file if test_set is a path string.
+
+ Returns
+ -------
+ array-like, shape [n_samples, n_pred_steps, n_pred_features],
+ Forecasting results.
+ """
+
+ result_dict = self.predict(test_set, file_type=file_type)
+ return result_dict["forecasting"]
diff --git a/pypots/forecasting/transformer/core.py b/pypots/forecasting/transformer/core.py
index 83d6f9c7..cda9e2fe 100644
--- a/pypots/forecasting/transformer/core.py
+++ b/pypots/forecasting/transformer/core.py
@@ -1,5 +1,5 @@
"""
-The core wrapper assembles the submodules of Transformer imputation model
+The core wrapper assembles the submodules of the Transformer forecasting model
and takes over the forward progress of the algorithm.
"""
@@ -11,7 +11,7 @@
import torch.nn as nn
from ...nn.functional.error import calc_mse
-from ...nn.modules.saits import SaitsLoss, SaitsEmbedding
+from ...nn.modules.saits import SaitsEmbedding
from ...nn.modules.transformer import TransformerEncoder, TransformerDecoder
@@ -31,8 +31,6 @@ def __init__(
d_ffn: int,
dropout: float,
attn_dropout: float,
- ORT_weight: float = 1,
- MIT_weight: float = 1,
):
super().__init__()
@@ -78,9 +76,6 @@ def __init__(
)
self.output_projection = nn.Linear(d_model, n_pred_features)
- # apply SAITS loss function to Transformer on the imputation task
- self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)
-
def forward(self, inputs: dict) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]
diff --git a/pypots/forecasting/transformer/model.py b/pypots/forecasting/transformer/model.py
index 2b3a07fe..9139ce12 100644
--- a/pypots/forecasting/transformer/model.py
+++ b/pypots/forecasting/transformer/model.py
@@ -12,11 +12,6 @@
import torch
from torch.utils.data import DataLoader
-try:
- import nni
-except ImportError:
- pass
-
from .core import _Transformer
from .data import DatasetForTransformer
from ..base import BaseNNForecaster
@@ -185,8 +180,6 @@ def __init__(
d_ffn,
dropout,
attn_dropout,
- 1,
- 1,
)
self._print_model_size()
self._send_model_to_given_device()
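With `SaitsLoss` removed, the forecasting Transformer trains on the same plain masked MSE as the new FITS forecaster (`calc_mse(X_pred, forecasting_result, X_pred_missing_mask)` in `core.py` above). A sketch of what such a masked MSE computes, written as a hypothetical re-implementation rather than PyPOTS' actual `calc_mse`:

```python
import torch


def masked_mse(predictions: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # only positions where mask == 1 (observed) contribute to the error;
    # the epsilon guards against division by zero on an all-missing batch
    return torch.sum(mask * (predictions - targets) ** 2) / (torch.sum(mask) + 1e-12)


pred = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
target = torch.tensor([[1.0, 0.0], [3.0, 5.0]])
mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]])  # the masked-out error (2 - 0)^2 is ignored
print(masked_mse(pred, target, mask))  # tensor(0.3333) = (0 + 0 + 1) / 3
```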
diff --git a/pypots/imputation/fits/model.py b/pypots/imputation/fits/model.py
index b5c9dc7a..10a20635 100644
--- a/pypots/imputation/fits/model.py
+++ b/pypots/imputation/fits/model.py
@@ -22,7 +22,7 @@
class FITS(BaseNNImputer):
- """The PyTorch implementation of the FITS model.
+ """The PyTorch implementation of the FITS imputation model.
FITS is originally proposed by Xu et al. in :cite:`xu2024fits`.
Parameters
@@ -45,6 +45,9 @@ class FITS(BaseNNImputer):
MIT_weight :
The weight for the MIT loss, the same as SAITS.
+ apply_nonstationary_norm :
+ Whether to apply the non-stationary normalization to the input data.
+
batch_size :
The batch size for training and evaluating the model.
diff --git a/tests/forecasting/fits.py b/tests/forecasting/fits.py
new file mode 100644
index 00000000..bf383a5d
--- /dev/null
+++ b/tests/forecasting/fits.py
@@ -0,0 +1,120 @@
+"""
+Test cases for the forecasting model FITS.
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+import os.path
+import unittest
+
+import numpy as np
+import pytest
+
+from pypots.forecasting import FITS
+from pypots.nn.functional import calc_mse
+from pypots.optim import Adam
+from pypots.utils.logging import logger
+from tests.global_test_config import (
+ DATA,
+ EPOCHS,
+ DEVICE,
+ N_PRED_STEPS,
+ FORECASTING_TRAIN_SET,
+ FORECASTING_VAL_SET,
+ FORECASTING_TEST_SET,
+ FORECASTING_H5_TRAIN_SET_PATH,
+ FORECASTING_H5_VAL_SET_PATH,
+ FORECASTING_H5_TEST_SET_PATH,
+ RESULT_SAVING_DIR_FOR_FORECASTING,
+ check_tb_and_model_checkpoints_existence,
+)
+
+
+class TestFITS(unittest.TestCase):
+ logger.info("Running tests for an forecasting model FITS...")
+
+ # set the log and model saving path
+ saving_path = os.path.join(RESULT_SAVING_DIR_FOR_FORECASTING, "FITS")
+ model_save_name = "saved_fits_model.pypots"
+
+ # initialize an Adam optimizer
+ optimizer = Adam(lr=0.001, weight_decay=1e-5)
+
+ # initialize a FITS model
+ fits = FITS(
+ n_steps=DATA["n_steps"] - N_PRED_STEPS,
+ n_features=DATA["n_features"],
+ n_pred_steps=N_PRED_STEPS,
+ n_pred_features=DATA["n_features"],
+ individual=False,
+ cut_freq=3,
+ epochs=EPOCHS,
+ saving_path=saving_path,
+ optimizer=optimizer,
+ device=DEVICE,
+ )
+
+ @pytest.mark.xdist_group(name="forecasting-fits")
+ def test_0_fit(self):
+ self.fits.fit(FORECASTING_TRAIN_SET, FORECASTING_VAL_SET)
+
+ @pytest.mark.xdist_group(name="forecasting-fits")
+ def test_1_forecasting(self):
+ forecasting_X = self.fits.predict(FORECASTING_TEST_SET)["forecasting"]
+ assert not np.isnan(
+ forecasting_X
+ ).any(), "The forecasting results should not contain missing values."
+ test_MSE = calc_mse(
+ forecasting_X,
+ FORECASTING_TEST_SET["X_pred"],
+ ~np.isnan(FORECASTING_TEST_SET["X_pred"]),
+ )
+ logger.info(f"FITS test_MSE: {test_MSE}")
+
+ @pytest.mark.xdist_group(name="forecasting-fits")
+ def test_2_parameters(self):
+ assert hasattr(self.fits, "model") and self.fits.model is not None
+
+ assert hasattr(self.fits, "optimizer") and self.fits.optimizer is not None
+
+ assert hasattr(self.fits, "best_loss")
+ self.assertNotEqual(self.fits.best_loss, float("inf"))
+
+ assert hasattr(self.fits, "best_model_dict") and self.fits.best_model_dict is not None
+
+ @pytest.mark.xdist_group(name="forecasting-fits")
+ def test_3_saving_path(self):
+ # whether the root saving dir exists, which should be created by save_log_into_tb_file
+ assert os.path.exists(self.saving_path), f"file {self.saving_path} does not exist"
+
+ # check if the tensorboard file and model checkpoints exist
+ check_tb_and_model_checkpoints_existence(self.fits)
+
+ # save the trained model into file, and check if the path exists
+ saved_model_path = os.path.join(self.saving_path, self.model_save_name)
+ self.fits.save(saved_model_path)
+
+ # test loading the saved model; not strictly necessary, but the loading function needs testing
+ self.fits.load(saved_model_path)
+
+ @pytest.mark.xdist_group(name="forecasting-fits")
+ def test_4_lazy_loading(self):
+ self.fits.fit(FORECASTING_H5_TRAIN_SET_PATH, FORECASTING_H5_VAL_SET_PATH)
+ forecasting_results = self.fits.predict(FORECASTING_H5_TEST_SET_PATH)
+ forecasting_X = forecasting_results["forecasting"]
+ assert not np.isnan(
+ forecasting_X
+ ).any(), "The forecasting results should not contain missing values."
+
+ test_MSE = calc_mse(
+ forecasting_X,
+ FORECASTING_TEST_SET["X_pred"],
+ ~np.isnan(FORECASTING_TEST_SET["X_pred"]),
+ )
+ logger.info(f"Lazy-loading FITS test_MSE: {test_MSE}")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/global_test_config.py b/tests/global_test_config.py
index 08a7dddc..b6151180 100644
--- a/tests/global_test_config.py
+++ b/tests/global_test_config.py
@@ -20,7 +20,7 @@
# set the number of epochs for all model training
EPOCHS = 2
# set the number of prediction steps for forecasting models
-N_PRED_STEPS = 1
+N_PRED_STEPS = 2
# tensorboard and model files saving directory
RESULT_SAVING_DIR = "testing_results"
MODEL_SAVING_DIR = f"{RESULT_SAVING_DIR}/models"