-
-
Notifications
You must be signed in to change notification settings - Fork 120
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #504 from WenjieDu/dev
Add ModernTCN
- Loading branch information
Showing
14 changed files
with
1,222 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
""" | ||
The package of the partially-observed time-series imputation model ModernTCN. | ||
Refer to the paper | ||
`Donghao Luo, and Xue Wang. | ||
ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis. | ||
In The Twelfth International Conference on Learning Representations. 2024. | ||
<https://openreview.net/pdf?id=vpJMJerXHU>`_ | ||
Notes | ||
----- | ||
This implementation is inspired by the official one https://github.com/luodhhh/ModernTCN | ||
""" | ||
|
||
# Created by Wenjie Du <[email protected]> | ||
# License: BSD-3-Clause | ||
|
||
|
||
from .model import ModernTCN | ||
|
||
__all__ = [ | ||
"ModernTCN", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
""" | ||
The core wrapper assembles the submodules of ModernTCN imputation model | ||
and takes over the forward progress of the algorithm. | ||
""" | ||
|
||
# Created by Wenjie Du <[email protected]> | ||
# License: BSD-3-Clause | ||
|
||
import torch.nn as nn | ||
|
||
from ...nn.functional import nonstationary_norm, nonstationary_denorm | ||
from ...nn.modules.moderntcn import BackboneModernTCN | ||
from ...nn.modules.patchtst.layers import FlattenHead | ||
from ...utils.metrics import calc_mse | ||
|
||
|
||
class _ModernTCN(nn.Module): | ||
def __init__( | ||
self, | ||
n_steps, | ||
n_features, | ||
patch_size, | ||
patch_stride, | ||
downsampling_ratio, | ||
ffn_ratio, | ||
num_blocks: list, | ||
large_size: list, | ||
small_size: list, | ||
dims: list, | ||
small_kernel_merged: bool = False, | ||
backbone_dropout: float = 0.1, | ||
head_dropout: float = 0.1, | ||
use_multi_scale: bool = True, | ||
individual: bool = False, | ||
apply_nonstationary_norm: bool = False, | ||
): | ||
super().__init__() | ||
|
||
self.apply_nonstationary_norm = apply_nonstationary_norm | ||
|
||
self.backbone = BackboneModernTCN( | ||
n_steps, | ||
n_features, | ||
n_features, | ||
patch_size, | ||
patch_stride, | ||
downsampling_ratio, | ||
ffn_ratio, | ||
num_blocks, | ||
large_size, | ||
small_size, | ||
dims, | ||
small_kernel_merged, | ||
backbone_dropout, | ||
head_dropout, | ||
use_multi_scale, | ||
individual, | ||
) | ||
|
||
# for the imputation task, the output dim is the same as input dim | ||
self.projection = FlattenHead( | ||
self.backbone.head_nf, | ||
n_steps, | ||
n_features, | ||
head_dropout, | ||
individual, | ||
) | ||
|
||
def forward(self, inputs: dict, training: bool = True) -> dict: | ||
X, missing_mask = inputs["X"], inputs["missing_mask"] | ||
|
||
if self.apply_nonstationary_norm: | ||
# Normalization from Non-stationary Transformer | ||
X, means, stdev = nonstationary_norm(X, missing_mask) | ||
|
||
in_X = X.permute(0, 2, 1) | ||
in_X = self.backbone(in_X) | ||
reconstruction = self.projection(in_X) | ||
reconstruction = reconstruction.permute(0, 2, 1) | ||
|
||
if self.apply_nonstationary_norm: | ||
# De-Normalization from Non-stationary Transformer | ||
reconstruction = nonstationary_denorm(reconstruction, means, stdev) | ||
|
||
imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction | ||
results = { | ||
"imputed_data": imputed_data, | ||
} | ||
|
||
# if in training mode, return results with losses | ||
if training: | ||
loss = calc_mse(reconstruction, inputs["X_ori"], inputs["indicating_mask"]) | ||
results["loss"] = loss | ||
|
||
return results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
""" | ||
Dataset class for ModernTCN. | ||
""" | ||
|
||
# Created by Wenjie Du <[email protected]> | ||
# License: BSD-3-Clause | ||
|
||
from typing import Union | ||
|
||
from ..saits.data import DatasetForSAITS | ||
|
||
|
||
class DatasetForModernTCN(DatasetForSAITS): | ||
"""Actually ModernTCN uses the same data strategy as SAITS, needs MIT for training.""" | ||
|
||
def __init__( | ||
self, | ||
data: Union[dict, str], | ||
return_X_ori: bool, | ||
return_y: bool, | ||
file_type: str = "hdf5", | ||
rate: float = 0.2, | ||
): | ||
super().__init__(data, return_X_ori, return_y, file_type, rate) |
Oops, something went wrong.