Add pypots.forecasting.FITS #600

Merged
5 commits merged on Mar 6, 2025
2 changes: 1 addition & 1 deletion README.md
@@ -122,7 +122,7 @@ The paper references and links are all listed at the bottom of this file.
| LLM&TSFM | <a href="https://time-series.ai"><img src="https://time-series.ai/static/figs/robot.svg" width="26px"> Time-Series.AI</a> [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | <a href="https://time-series.ai">Join waitlist</a> |
| LLM | Time-LLM[^45] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | TEFN🧑‍🔧[^39] | ✅ | | | | | `2024 - arXiv` |
| Neural Net | FITS🧑‍🔧[^41] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | FITS🧑‍🔧[^41] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` |
2 changes: 1 addition & 1 deletion README_zh.md
@@ -107,7 +107,7 @@ PyPOTS currently supports imputation, forecasting, classification, clustering, and anomaly detection on multivariate POTS data
| LLM&TSFM | <a href="https://time-series.ai"><img src="https://time-series.ai/static/figs/robot.svg" width="26px"> Time-Series.AI</a> [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | <a href="https://time-series.ai">Join waitlist</a> |
| LLM | Time-LLM[^45] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | TEFN🧑‍🔧[^39] | ✅ | | | | | `2024 - arXiv` |
| Neural Net | FITS🧑‍🔧[^41] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | FITS🧑‍🔧[^41] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` |
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -135,7 +135,7 @@ The paper references are all listed at the bottom of this readme file.
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | TEFN🧑‍🔧 :cite:`zhan2024tefn` | ✅ | | | | | ``2024 - arXiv`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | FITS🧑‍🔧 :cite:`xu2024fits` | ✅ | | | | | ``2024 - ICLR`` |
| Neural Net | FITS🧑‍🔧 :cite:`xu2024fits` | ✅ | | | | | ``2024 - ICLR`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | TimeMixer :cite:`wang2024timemixer` | ✅ | | | | | ``2024 - ICLR`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
9 changes: 9 additions & 0 deletions docs/pypots.forecasting.rst
@@ -10,6 +10,15 @@ pypots.forecasting.transformer
    :show-inheritance:
    :inherited-members:

pypots.forecasting.fits
------------------------------

.. automodule:: pypots.forecasting.fits
    :members:
    :undoc-members:
    :show-inheritance:
    :inherited-members:

pypots.forecasting.csdi
------------------------------

3 changes: 3 additions & 0 deletions pypots/forecasting/__init__.py
@@ -8,9 +8,12 @@
from .bttf import BTTF
from .csdi import CSDI
from .transformer import Transformer
from .fits import FITS


__all__ = [
    "BTTF",
    "CSDI",
    "Transformer",
    "FITS",
]
13 changes: 13 additions & 0 deletions pypots/forecasting/fits/__init__.py
@@ -0,0 +1,13 @@
"""

"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause


from .model import FITS

__all__ = [
    "FITS",
]
95 changes: 95 additions & 0 deletions pypots/forecasting/fits/core.py
@@ -0,0 +1,95 @@
"""
The core wrapper assembles the submodules of FITS forecasting model
and takes over the forward progress of the algorithm.

"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import torch
import torch.nn as nn

from ...nn.functional import nonstationary_norm, nonstationary_denorm
from ...nn.functional.error import calc_mse
from ...nn.modules.fits import BackboneFITS
from ...nn.modules.saits import SaitsEmbedding


class _FITS(nn.Module):
def __init__(
self,
n_steps: int,
n_features: int,
n_pred_steps: int,
n_pred_features: int,
cut_freq: int,
individual: bool,
apply_nonstationary_norm: bool = False,
):
super().__init__()

self.n_pred_steps = n_pred_steps
self.n_pred_features = n_pred_features
self.apply_nonstationary_norm = apply_nonstationary_norm

self.saits_embedding = SaitsEmbedding(
n_features * 2,
n_features,
with_pos=False,
)
self.backbone = BackboneFITS(
n_steps,
n_features,
n_pred_steps,
cut_freq,
individual,
)

# for the imputation task, the output dim is the same as input dim
self.output_projection = nn.Linear(n_features, n_pred_features)

def forward(self, inputs: dict) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

if self.training:
X_pred, X_pred_missing_mask = inputs["X_pred"], inputs["X_pred_missing_mask"]
else:
batch_size = X.shape[0]
X_pred, X_pred_missing_mask = (
torch.zeros(batch_size, self.n_pred_steps, self.n_pred_features),
torch.ones(batch_size, self.n_pred_steps, self.n_pred_features),
)

if self.apply_nonstationary_norm:
# Normalization from Non-stationary Transformer
X, means, stdev = nonstationary_norm(X, missing_mask)

# WDU: the original FITS paper isn't proposed for imputation task. Hence the model doesn't take
# the missing mask into account, which means, in the process, the model doesn't know which part of
# the input data is missing, and this may hurt the model's imputation performance. Therefore, I apply the
# SAITS embedding method to project the concatenation of features and masks into a hidden space, as well as
# the output layers to project back from the hidden space to the original space.
enc_out = self.saits_embedding(X, missing_mask)

# FITS encoder processing
enc_out = self.backbone(enc_out)
if self.apply_nonstationary_norm:
# De-Normalization from Non-stationary Transformer
enc_out = nonstationary_denorm(enc_out, means, stdev)

# project back the original data space
forecasting_result = self.output_projection(enc_out)
# the raw output has length = n_steps+n_pred_steps, we only need the last n_pred_steps
forecasting_result = forecasting_result[:, -self.n_pred_steps :]

results = {
"forecasting_data": forecasting_result,
}

# if in training mode, return results with losses
if self.training:
# `loss` is always the item for backward propagating to update the model
results["loss"] = calc_mse(X_pred, forecasting_result, X_pred_missing_mask)

return results
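The training loss above masks the MSE with `X_pred_missing_mask`, so only observed entries of the prediction window contribute. As a rough NumPy sketch of such a masked MSE (only `calc_mse`'s call signature appears in the diff; the body below is an assumption about its behavior):

```python
import numpy as np

def masked_mse(predictions, targets, mask):
    # Mean squared error computed only over observed (mask == 1) entries,
    # mirroring how the FITS wrapper weights its forecasting loss.
    squared_error = (predictions - targets) ** 2 * mask
    return squared_error.sum() / (mask.sum() + 1e-12)

pred = np.array([[1.0, 2.0], [3.0, 4.0]])
target = np.array([[1.0, 0.0], [3.0, 8.0]])
mask = np.array([[1.0, 0.0], [1.0, 1.0]])  # one entry is unobserved
print(masked_mse(pred, target, mask))  # squared errors 0, (ignored), 0, 16 -> 16/3
```

With masking, the large error on the unobserved entry is excluded from both the numerator and the averaging denominator.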
28 changes: 28 additions & 0 deletions pypots/forecasting/fits/data.py
@@ -0,0 +1,28 @@
"""
Dataset class for the forecasting model FITS.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from typing import Union

from ...data.dataset import BaseDataset


class DatasetForFITS(BaseDataset):
"""Dataset for FITS forecasting model."""

def __init__(
self,
data: Union[dict, str],
return_X_pred=True,
file_type: str = "hdf5",
):
super().__init__(
data=data,
return_X_ori=False,
return_X_pred=return_X_pred,
return_y=False,
file_type=file_type,
)