Add pypots.forecasting.Transformer #597

Merged
merged 6 commits into from
Mar 5, 2025
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ The paper references and links are all listed at the bottom of this file.
| Neural Net | BRITS[^3] | ✅ | | ✅ | | | `2018 - NeurIPS` |
| Neural Net | GRU-D[^4] | ✅ | | ✅ | | | `2018 - Sci. Rep.` |
| Neural Net | TCN🧑‍🔧[^35] | ✅ | | | | | `2018 - arXiv` |
| Neural Net | Transformer🧑‍🔧[^2] | ✅ | | | | | `2017 - NeurIPS` |
| Neural Net | Transformer🧑‍🔧[^2] | ✅ | | | | | `2017 - NeurIPS` |
| MF | TRMF[^44] | ✅ | | | | | `2016 - NeurIPS` |
| Naive | Lerp[^40] | ✅ | | | | | |
| Naive | LOCF/NOCB | ✅ | | | | | |
Expand Down
2 changes: 1 addition & 1 deletion README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ PyPOTS当前支持多变量POTS数据的插补, 预测, 分类, 聚类以及异
| Neural Net | BRITS[^3] | ✅ | | ✅ | | | `2018 - NeurIPS` |
| Neural Net | GRU-D[^4] | ✅ | | ✅ | | | `2018 - Sci. Rep.` |
| Neural Net | TCN🧑‍🔧[^35] | ✅ | | | | | `2018 - arXiv` |
| Neural Net | Transformer🧑‍🔧[^2] | ✅ | | | | | `2017 - NeurIPS` |
| Neural Net | Transformer🧑‍🔧[^2] | ✅ | | | | | `2017 - NeurIPS` |
| MF | TRMF[^44] | ✅ | | | | | `2016 - NeurIPS` |
| Naive | Lerp[^40] | ✅ | | | | | |
| Naive | LOCF/NOCB | ✅ | | | | | |
Expand Down
4 changes: 2 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ The paper references are all listed at the bottom of this readme file.
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | SAITS :cite:`du2023SAITS` | ✅ | | | | | ``2023 - ESWA`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| LLM | GPT4TS :cite:`zhou2024gpt4ts` | ✅ | | | | | ``2023 - NeurIPS`` |
| LLM | GPT4TS :cite:`zhou2023gpt4ts` | ✅ | | | | | ``2023 - NeurIPS`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | FreTS🧑‍🔧 :cite:`yi2023frets` | ✅ | | | | | ``2023 - NeurIPS`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
Expand Down Expand Up @@ -213,7 +213,7 @@ The paper references are all listed at the bottom of this readme file.
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | TCN🧑‍🔧 :cite:`bai2018tcn` | ✅ | | | | | ``2018 - arXiv`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | Transformer🧑‍🔧 :cite:`vaswani2017Transformer` | ✅ | | | | | ``2017 - NeurIPS`` |
| Neural Net | Transformer🧑‍🔧 :cite:`vaswani2017Transformer` | ✅ | | | | | ``2017 - NeurIPS`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| MF | TRMF :cite:`yu2016trmf` | ✅ | | | | | ``2016 - NeurIPS`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
Expand Down
18 changes: 18 additions & 0 deletions docs/pypots.forecasting.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,24 @@
pypots.forecasting package
==========================

pypots.forecasting.transformer
------------------------------

.. automodule:: pypots.forecasting.transformer
:members:
:undoc-members:
:show-inheritance:
:inherited-members:

pypots.forecasting.csdi
------------------------------

.. automodule:: pypots.forecasting.csdi
:members:
:undoc-members:
:show-inheritance:
:inherited-members:

pypots.forecasting.bttf
------------------------------

Expand Down
4 changes: 2 additions & 2 deletions pypots/data/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:

if self.return_X_pred:
X_pred = self.X_pred[idx]
- pred_missing_mask = self.X_pred[idx]
- sample.extend([X_pred, pred_missing_mask])
+ X_pred_missing_mask = self.X_pred_missing_mask[idx]
+ sample.extend([X_pred, X_pred_missing_mask])

if self.return_y:
sample.append(self.y[idx].to(torch.long))
Expand Down
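The change in this hunk matters because the forecasting loss is later weighted by `X_pred_missing_mask`: the replaced line returned the forecast target itself where its mask was expected. Below is a minimal sketch of the intended behavior using plain lists rather than tensors. `build_missing_mask` and `fetch_sample` are hypothetical helpers written for illustration, not PyPOTS functions, and the convention that 1 marks an observed value is an assumption.

```python
import math


def build_missing_mask(series):
    """Return 1.0 where a value is observed and 0.0 where it is NaN (assumed convention)."""
    return [[0.0 if math.isnan(v) else 1.0 for v in step] for step in series]


def fetch_sample(X_pred_all, X_pred_mask_all, idx):
    """Mimic the corrected branch: return the target together with its own mask."""
    X_pred = X_pred_all[idx]
    X_pred_missing_mask = X_pred_mask_all[idx]
    return [X_pred, X_pred_missing_mask]


nan = float("nan")
X_pred_all = [[[1.0, nan], [3.0, 4.0]]]  # one sample, two steps, two features
X_pred_mask_all = [build_missing_mask(sample) for sample in X_pred_all]
X_pred, X_pred_missing_mask = fetch_sample(X_pred_all, X_pred_mask_all, 0)
```

With the old line, the second returned element would have been the raw values (including NaN) instead of a 0/1 mask.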
2 changes: 2 additions & 0 deletions pypots/forecasting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

from .bttf import BTTF
from .csdi import CSDI
from .transformer import Transformer

__all__ = [
"BTTF",
"CSDI",
"Transformer",
]
13 changes: 11 additions & 2 deletions pypots/forecasting/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from ..base import BaseModel, BaseNNModel
from ..nn.functional import calc_mse, autocast
from ..nn.modules.loss import MSE
from ..utils.logging import logger

try:
Expand Down Expand Up @@ -233,6 +234,14 @@ def __init__(
verbose=verbose,
)

# set default training loss function and validation metric function if not given
if train_loss_func is None:
self.train_loss_func = MSE()
self.train_loss_func_name = self.train_loss_func.__class__.__name__
if val_metric_func is None:
self.val_metric_func = MSE()
self.val_metric_func_name = self.val_metric_func.__class__.__name__

@abstractmethod
def _assemble_input_for_training(self, data: list) -> dict:
"""Assemble the given data into a dictionary for training input.
Expand Down Expand Up @@ -344,8 +353,8 @@ def _train_model(
forecasting_mse = (
calc_mse(
results["forecasting_data"],
- inputs["X_ori"],
- inputs["indicating_mask"],
+ inputs["X_pred"],
+ inputs["X_pred_missing_mask"],
)
.sum()
.detach()
Expand Down
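The validation hunk above now scores the forecast against `X_pred`, counting only positions flagged by `X_pred_missing_mask`. A rough sketch of a masked mean squared error of that kind follows, written in plain Python. `masked_mse` is a hypothetical stand-in, not the library's `calc_mse`, and the epsilon term and the 1-means-observed convention are assumptions.

```python
def masked_mse(predictions, targets, masks):
    """Mean squared error over the positions where the mask is non-zero."""
    squared_error = 0.0
    n_observed = 0.0
    for pred_row, target_row, mask_row in zip(predictions, targets, masks):
        for p, t, m in zip(pred_row, target_row, mask_row):
            if m:  # skip missing targets, which may hold NaN or filler values
                squared_error += (p - t) ** 2
                n_observed += 1
    # a small epsilon guards against division by zero when nothing is observed
    return squared_error / (n_observed + 1e-12)


forecast = [[1.0, 2.0], [3.0, 5.0]]
target = [[1.0, 0.0], [3.0, 4.0]]
mask = [[1, 0], [1, 1]]
loss = masked_mse(forecast, target, mask)
```

Here only three positions are observed and one of them is off by 1, so the loss is about one third; the masked-out position contributes nothing even though its prediction differs from the filler target.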
13 changes: 13 additions & 0 deletions pypots/forecasting/transformer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
The package of the forecasting model Transformer.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause


from .model import Transformer

__all__ = [
"Transformer",
]
116 changes: 116 additions & 0 deletions pypots/forecasting/transformer/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
The core wrapper assembles the submodules of the Transformer forecasting model
and takes over the forward process of the algorithm.

"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import torch
import torch.nn as nn

from ...nn.functional.error import calc_mse
from ...nn.modules.saits import SaitsLoss, SaitsEmbedding
from ...nn.modules.transformer import TransformerEncoder, TransformerDecoder


class _Transformer(nn.Module):
def __init__(
self,
n_steps: int,
n_features: int,
n_pred_steps: int,
n_pred_features: int,
n_encoder_layers: int,
n_decoder_layers: int,
d_model: int,
n_heads: int,
d_k: int,
d_v: int,
d_ffn: int,
dropout: float,
attn_dropout: float,
ORT_weight: float = 1,
MIT_weight: float = 1,
):
super().__init__()

self.n_steps = n_steps
self.n_features = n_features
self.n_pred_steps = n_pred_steps
self.n_pred_features = n_pred_features

self.encoder_saits_embedding = SaitsEmbedding(
n_features * 2,
d_model,
with_pos=True,
n_max_steps=n_steps,
dropout=dropout,
)
self.decoder_saits_embedding = SaitsEmbedding(
n_pred_features * 2,  # the decoder embeds X_pred concatenated with its missing mask
d_model,
with_pos=True,
n_max_steps=n_pred_steps,
dropout=dropout,
)

self.encoder = TransformerEncoder(
n_encoder_layers,
d_model,
n_heads,
d_k,
d_v,
d_ffn,
dropout,
attn_dropout,
)
self.decoder = TransformerDecoder(
n_decoder_layers,
d_model,
n_heads,
d_k,
d_v,
d_ffn,
dropout,
attn_dropout,
)
self.output_projection = nn.Linear(d_model, n_pred_features)

# the SAITS loss is instantiated here, but forward() below computes the training loss with calc_mse
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

if self.training:
X_pred, X_pred_missing_mask = inputs["X_pred"], inputs["X_pred_missing_mask"]
else:
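batch_size = X.shape[0]
# build the placeholders on the same device as X so inference also works off-CPU
X_pred, X_pred_missing_mask = (
torch.zeros(batch_size, self.n_pred_steps, self.n_pred_features, device=X.device),
torch.ones(batch_size, self.n_pred_steps, self.n_pred_features, device=X.device),
)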

# apply the SAITS embedding strategy, concatenate X and missing mask for input
input_X = self.encoder_saits_embedding(X, missing_mask)
# Transformer encoder processing
enc_output, _ = self.encoder(input_X)
input_X_pred = self.decoder_saits_embedding(X_pred, X_pred_missing_mask)
# Transformer decoder processing
dec_output, _, _ = self.decoder(input_X_pred, enc_output)
# project the representation from the d_model-dimensional space to the original data space for output
forecasting_result = self.output_projection(dec_output)

# assemble the results as a dictionary for return
results = {
"forecasting_data": forecasting_result,
}

# if in training mode, return results with losses
if self.training:
# `loss` is always the item for backward propagating to update the model
results["loss"] = calc_mse(forecasting_result, X_pred, X_pred_missing_mask)

return results
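Two details of the forward pass above can be illustrated without tensors: the embedding step concatenates the values with their mask along the feature axis, and outside training the decoder receives all-zero targets with an all-one mask. The sketch below mirrors both with nested lists for a single sample. `concat_with_mask` and `inference_placeholders` are hypothetical names used only for this illustration.

```python
def concat_with_mask(X, mask):
    """Concatenate values and mask along the feature axis, step by step."""
    return [x_step + m_step for x_step, m_step in zip(X, mask)]


def inference_placeholders(n_pred_steps, n_pred_features):
    """All-zero targets with an all-one mask, as in the non-training branch."""
    X_pred = [[0.0] * n_pred_features for _ in range(n_pred_steps)]
    X_pred_missing_mask = [[1.0] * n_pred_features for _ in range(n_pred_steps)]
    return X_pred, X_pred_missing_mask


n_pred_steps, n_pred_features = 3, 2
X_pred, X_pred_missing_mask = inference_placeholders(n_pred_steps, n_pred_features)
decoder_input = concat_with_mask(X_pred, X_pred_missing_mask)
```

Each step of `decoder_input` has `2 * n_pred_features` columns, which is the width the decoder-side embedding has to accept before projecting to `d_model`.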
28 changes: 28 additions & 0 deletions pypots/forecasting/transformer/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
Dataset class for the forecasting model Transformer.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from typing import Union

from ...data.dataset import BaseDataset


class DatasetForTransformer(BaseDataset):
"""Dataset for Transformer forecasting model."""

def __init__(
self,
data: Union[dict, str],
return_X_pred=True,
file_type: str = "hdf5",
):
super().__init__(
data=data,
return_X_ori=False,
return_X_pred=return_X_pred,
return_y=False,
file_type=file_type,
)
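`DatasetForTransformer` asks the base dataset to return `X_pred` alongside `X`, so each sample pairs a history window with the horizon to forecast. One way such pairs might be prepared from a long series is sketched below. `split_history_and_horizon` is a hypothetical helper, and the dict keys `X` and `X_pred` are an assumption modeled on the attribute names in this diff rather than a confirmed input schema.

```python
def split_history_and_horizon(series, n_steps, n_pred_steps):
    """Slide a window over `series` and split each window into (history, horizon)."""
    window = n_steps + n_pred_steps
    X, X_pred = [], []
    for start in range(len(series) - window + 1):
        X.append(series[start : start + n_steps])
        X_pred.append(series[start + n_steps : start + window])
    return {"X": X, "X_pred": X_pred}


series = [[float(t)] for t in range(6)]  # six time steps, one feature
data = split_history_and_horizon(series, n_steps=3, n_pred_steps=2)
```

A stride of one step is used here for simplicity; a larger stride would produce fewer, less overlapping samples.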