Add pypots.forecasting.TimeMixer #603

Merged 9 commits on Mar 7, 2025
4 changes: 2 additions & 2 deletions README.md
@@ -121,9 +121,9 @@ The paper references and links are all listed at the bottom of this file.
|:--------------|:---------------------------------------------------------------------------------------------------------------------------------|:--------:|:--------:|:--------:|:--------:|:--------:|:---------------------------------------------------|
| LLM&TSFM | <a href="https://time-series.ai"><img src="https://time-series.ai/static/figs/robot.svg" width="26px"> Time-Series.AI</a> [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | <a href="https://time-series.ai">Join waitlist</a> |
| LLM | Time-LLM[^45] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | TEFN🧑‍🔧[^39] | ✅ | | | | | `2024 - arXiv` |
| Neural Net | TEFN🧑‍🔧[^39] | ✅ | | | | | `2024 - arXiv` |
| Neural Net | FITS🧑‍🔧[^41] | ✅ | ✅ | | | | `2024 - ICLR` |
| Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` |
4 changes: 2 additions & 2 deletions README_zh.md
@@ -106,9 +106,9 @@ PyPOTS当前支持多变量POTS数据的插补, 预测, 分类, 聚类以及异
|:--------------|:---------------------------------------------------------------------------------------------------------------------------------|:------:|:------:|:------:|:------:|:--------:|:---------------------------------------------------|
| LLM&TSFM | <a href="https://time-series.ai"><img src="https://time-series.ai/static/figs/robot.svg" width="26px"> Time-Series.AI</a> [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | <a href="https://time-series.ai">Join waitlist</a> |
| LLM | Time-LLM[^45] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | TEFN🧑‍🔧[^39] | ✅ | | | | | `2024 - arXiv` |
| Neural Net | TEFN🧑‍🔧[^39] | ✅ | | | | | `2024 - arXiv` |
| Neural Net | FITS🧑‍🔧[^41] | ✅ | ✅ | | | | `2024 - ICLR` |
| Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` |
4 changes: 2 additions & 2 deletions docs/index.rst
@@ -133,11 +133,11 @@ The paper references are all listed at the bottom of this readme file.
+================+===========================================================+======+======+======+======+======+=======================+
| LLM | Time-LLM :cite:`jin2024timellm` | ✅ | | | | | ``2024 - ICLR`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | TEFN🧑‍🔧 :cite:`zhan2024tefn` | ✅ | | | | | ``2024 - arXiv`` |
| Neural Net | TEFN🧑‍🔧 :cite:`zhan2024tefn` | ✅ | | | | | ``2024 - arXiv`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | FITS🧑‍🔧 :cite:`xu2024fits` | ✅ | ✅ | | | | ``2024 - ICLR`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | TimeMixer :cite:`wang2024timemixer` | ✅ | | | | | ``2024 - ICLR`` |
| Neural Net | TimeMixer :cite:`wang2024timemixer` | ✅ | | | | | ``2024 - ICLR`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | iTransformer🧑‍🔧 :cite:`liu2024itransformer` | ✅ | | | | | ``2024 - ICLR`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
18 changes: 18 additions & 0 deletions docs/pypots.forecasting.rst
@@ -10,6 +10,15 @@ pypots.forecasting.transformer
    :show-inheritance:
    :inherited-members:

pypots.forecasting.tefn
------------------------------

.. automodule:: pypots.forecasting.tefn
    :members:
    :undoc-members:
    :show-inheritance:
    :inherited-members:

pypots.forecasting.fits
------------------------------

@@ -19,6 +28,15 @@ pypots.forecasting.fits
    :show-inheritance:
    :inherited-members:

pypots.forecasting.timemixer
------------------------------

.. automodule:: pypots.forecasting.timemixer
    :members:
    :undoc-members:
    :show-inheritance:
    :inherited-members:

pypots.forecasting.csdi
------------------------------

7 changes: 5 additions & 2 deletions pypots/forecasting/__init__.py
@@ -7,13 +7,16 @@

from .bttf import BTTF
from .csdi import CSDI
from .transformer import Transformer
from .fits import FITS

from .tefn import TEFN
from .timemixer import TimeMixer
from .transformer import Transformer

__all__ = [
    "BTTF",
    "CSDI",
    "Transformer",
    "FITS",
    "TEFN",
    "TimeMixer",
]
2 changes: 1 addition & 1 deletion pypots/forecasting/bttf/core.py
@@ -200,7 +200,7 @@ def BTTF_forecast(
        "start_time should be larger than -1, "
        "namely the number of the input tensor's time steps should be larger than pred_step."
    )
    assert start_time >= np.max(time_lags), "start_time should be >= max(time_lags)"
    assert start_time >= np.max(time_lags), f"start_time {start_time} should be >= max(time_lags) {np.max(time_lags)}"
    max_count = int(np.ceil(pred_step / multi_step))
    tensor_hat = np.zeros((dim1, dim2, max_count * multi_step))
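The changed assertion now interpolates both offending values into the message, so a failing check explains itself instead of forcing the user to add print statements. A standalone sketch of the behavior (the values here are illustrative, not taken from the PR):

```python
import numpy as np

# Reproduce the improved assertion in isolation; values are illustrative.
time_lags = np.array([1, 2, 24])
start_time = 10  # deliberately too small to trigger the check

try:
    assert start_time >= np.max(time_lags), f"start_time {start_time} should be >= max(time_lags) {np.max(time_lags)}"
except AssertionError as err:
    error_message = str(err)

print(error_message)  # start_time 10 should be >= max(time_lags) 24
```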

13 changes: 13 additions & 0 deletions pypots/forecasting/tefn/__init__.py
@@ -0,0 +1,13 @@
"""

"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause


from .model import TEFN

__all__ = [
    "TEFN",
]
93 changes: 93 additions & 0 deletions pypots/forecasting/tefn/core.py
@@ -0,0 +1,93 @@
"""
The core wrapper assembles the submodules of TEFN forecasting model
and takes over the forward progress of the algorithm.

"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import torch
import torch.nn as nn

from ...nn.functional import nonstationary_norm, nonstationary_denorm
from ...nn.functional.error import calc_mse
from ...nn.modules.saits import SaitsEmbedding
from ...nn.modules.tefn import BackboneTEFN


class _TEFN(nn.Module):
    def __init__(
        self,
        n_steps: int,
        n_features: int,
        n_pred_steps: int,
        n_pred_features: int,
        n_fod: int,
        apply_nonstationary_norm: bool = False,
    ):
        super().__init__()

        self.n_pred_steps = n_pred_steps
        self.n_pred_features = n_pred_features
        self.apply_nonstationary_norm = apply_nonstationary_norm

        self.saits_embedding = SaitsEmbedding(
            n_steps * 2,
            n_steps + n_pred_steps,
            with_pos=False,
        )
        self.backbone = BackboneTEFN(
            n_steps,
            n_features,
            n_pred_steps,
            n_fod,
        )

        # project the backbone output from the input feature space to the prediction feature space
        self.output_projection = nn.Linear(n_features, n_pred_features)

    def forward(self, inputs: dict) -> dict:
        X, missing_mask = inputs["X"], inputs["missing_mask"]

        if self.training:
            X_pred, X_pred_missing_mask = inputs["X_pred"], inputs["X_pred_missing_mask"]
        else:
            batch_size = X.shape[0]
            # placeholders for the prediction window during inference, created on the same device as X
            X_pred, X_pred_missing_mask = (
                torch.zeros(batch_size, self.n_pred_steps, self.n_pred_features, device=X.device),
                torch.ones(batch_size, self.n_pred_steps, self.n_pred_features, device=X.device),
            )

        if self.apply_nonstationary_norm:
            # Normalization from Non-stationary Transformer
            X, means, stdev = nonstationary_norm(X, missing_mask)

        # WDU: the original TEFN paper was not proposed for partially observed data. Hence the model doesn't take
        # the missing mask into account, which means, during processing, the model doesn't know which parts of
        # the input data are missing, and this may hurt its performance on such data. Therefore, I apply the
        # SAITS embedding method to project the concatenation of features and masks into a hidden space, as well as
        # the output layers to project back from the hidden space to the original space.
        enc_out = self.saits_embedding(X.permute(0, 2, 1), missing_mask.permute(0, 2, 1)).permute(0, 2, 1)

        # TEFN encoder processing
        enc_out = self.backbone(enc_out)
        if self.apply_nonstationary_norm:
            # De-Normalization from Non-stationary Transformer
            enc_out = nonstationary_denorm(enc_out, means, stdev)

        # project back to the original data space
        forecasting_result = self.output_projection(enc_out)
        # the raw output covers n_steps + n_pred_steps time steps; keep only the last n_pred_steps
        forecasting_result = forecasting_result[:, -self.n_pred_steps :]

        results = {
            "forecasting_data": forecasting_result,
        }

        # if in training mode, return results with losses
        if self.training:
            # `loss` is always the item for backward propagating to update the model
            results["loss"] = calc_mse(X_pred, forecasting_result, X_pred_missing_mask)

        return results
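The shape bookkeeping at the end of `forward` can be sketched outside the model with plain NumPy: the projected output spans `n_steps + n_pred_steps` time steps, and only the trailing `n_pred_steps` are kept as the forecast. All sizes below are illustrative, not taken from the PR.

```python
import numpy as np

# illustrative sizes
batch_size, n_steps, n_pred_steps, n_pred_features = 2, 24, 6, 3

# stand-in for the projected output of shape (batch, n_steps + n_pred_steps, n_pred_features)
raw_output = np.random.randn(batch_size, n_steps + n_pred_steps, n_pred_features)

# keep only the forecasting horizon, as `forecasting_result[:, -self.n_pred_steps:]` does
forecasting_result = raw_output[:, -n_pred_steps:]
print(forecasting_result.shape)  # (2, 6, 3)
```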
28 changes: 28 additions & 0 deletions pypots/forecasting/tefn/data.py
@@ -0,0 +1,28 @@
"""
Dataset class for the forecasting model TEFN.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from typing import Union

from ...data.dataset import BaseDataset


class DatasetForTEFN(BaseDataset):
    """Dataset for TEFN forecasting model."""

    def __init__(
        self,
        data: Union[dict, str],
        return_X_pred=True,
        file_type: str = "hdf5",
    ):
        super().__init__(
            data=data,
            return_X_ori=False,
            return_X_pred=return_X_pred,
            return_y=False,
            file_type=file_type,
        )
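A minimal sketch of the training dict such a dataset can be built from, assuming the convention suggested by `forward` in `core.py` that `"X"` holds the observed window and `"X_pred"` the forecasting horizon; the key names and shapes below are assumptions for illustration, not taken from this diff.

```python
import numpy as np

# illustrative sizes
n_samples, n_steps, n_pred_steps, n_features = 8, 24, 6, 3

train_set = {
    # observed input windows (assumed key name), possibly containing NaNs for missing values
    "X": np.random.randn(n_samples, n_steps, n_features),
    # ground-truth forecasting horizons aligned with each window (assumed key name)
    "X_pred": np.random.randn(n_samples, n_pred_steps, n_features),
}
print(train_set["X"].shape, train_set["X_pred"].shape)  # (8, 24, 3) (8, 6, 3)
```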