Merge pull request #253 from WenjieDu/dev
Adding TimesNet, refactoring code, and updating docs
WenjieDu authored Dec 1, 2023
2 parents 5647c81 + efca871 commit c597bb5
Showing 33 changed files with 1,906 additions and 1,030 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -190,6 +190,7 @@ This functionality is implemented with the [Microsoft NNI](https://github.com/mi
| **Type** | **Abbr.** | **Full name of the algorithm/model** | **Year** |
| Neural Net | SAITS | Self-Attention-based Imputation for Time Series [^1] | 2023 |
| Neural Net | Transformer | Attention is All you Need [^2];<br>Self-Attention-based Imputation for Time Series [^1];<br><sub>Note: proposed in [^2], and re-implemented as an imputation model in [^1].</sub> | 2017 |
| Neural Net | TimesNet | Temporal 2D-Variation Modeling for General Time Series Analysis [^14] | 2023 |
| Neural Net | CSDI | Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation [^12] | 2021 |
| Neural Net | US-GAN | Unsupervised GAN for Multivariate Time Series Imputation [^10] | 2021 |
| Neural Net | GP-VAE | Gaussian Process Variational Autoencoder [^11] | 2020 |
@@ -302,6 +303,7 @@ PyPOTS community is open, transparent, and surely friendly. Let's work together
[^11]: Fortuin, V., Baranchuk, D., Raetsch, G. & Mandt, S. (2020). [GP-VAE: Deep Probabilistic Time Series Imputation](https://proceedings.mlr.press/v108/fortuin20a.html). *AISTATS 2020*.
[^12]: Tashiro, Y., Song, J., Song, Y., & Ermon, S. (2021). [CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation](https://proceedings.neurips.cc/paper/2021/hash/cfe8504bda37b575c70ee1a8276f3486-Abstract.html). *NeurIPS 2021*.
[^13]: Rubin, D. B. (1976). [Inference and missing data](https://academic.oup.com/biomet/article-abstract/63/3/581/270932). *Biometrika*.
[^14]: Wu, H., Hu, T., Liu, Y., Zhou, H., Wang, J., & Long, M. (2023). [TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis](https://openreview.net/forum?id=ju_Uqw384Oq). *ICLR 2023*.
<details>
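The "Temporal 2D-Variation Modeling" behind the new TimesNet entry works by detecting a series' dominant periods with an FFT and folding the 1D sequence into a 2D tensor (number of periods × period length), so that intra-period and inter-period variations can both be captured by 2D convolutions. A minimal sketch of that folding step in PyTorch; `detect_periods` and `fold_to_2d` are illustrative names, not functions from PyPOTS or the TimesNet codebase:

```python
import torch


def detect_periods(x: torch.Tensor, k: int = 2) -> torch.Tensor:
    """Return the k dominant periods of x ([batch, length, channels]) via FFT amplitudes."""
    spectrum = torch.fft.rfft(x, dim=1)
    amplitudes = spectrum.abs().mean(dim=0).mean(dim=-1)  # average over batch and channels
    amplitudes[0] = 0                                      # ignore the DC component
    _, top_freqs = torch.topk(amplitudes, k)
    return x.shape[1] // top_freqs                         # frequency index -> period length


def fold_to_2d(x: torch.Tensor, period: int) -> torch.Tensor:
    """Reshape [batch, length, channels] into [batch, channels, n_periods, period]."""
    batch, length, channels = x.shape
    n_periods = (length + period - 1) // period
    pad = n_periods * period - length
    if pad > 0:  # zero-pad so the length divides evenly into whole periods
        x = torch.cat(
            [x, torch.zeros(batch, pad, channels, dtype=x.dtype, device=x.device)], dim=1
        )
    return x.reshape(batch, n_periods, period, channels).permute(0, 3, 1, 2)
```

In the full model each folded tensor is processed by an Inception-style 2D convolution block, and the per-period outputs are fused by averaging weighted with the corresponding FFT amplitudes.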
3 changes: 2 additions & 1 deletion docs/index.rst
@@ -167,8 +167,9 @@ This functionality is implemented with the `Microsoft NNI <https://github.com/mi
============================== ================ ======================================================================================== ====== =========
Task Type Algorithm Year Reference
============================== ================ ======================================================================================== ====== =========
Imputation Neural Net SAITS (Self-Attention-based Imputation for Time Series) 2022 :cite:`du2023SAITS`
Imputation Neural Net SAITS (Self-Attention-based Imputation for Time Series) 2023 :cite:`du2023SAITS`
Imputation Neural Net Transformer 2017 :cite:`vaswani2017Transformer`, :cite:`du2023SAITS`
Imputation Neural Net TimesNet 2023 :cite:`wu2023timesnet`
Imputation Neural Net US-GAN (Unsupervised GAN for Multivariate Time Series Imputation) 2021 :cite:`miao2021SSGAN`
Imputation Neural Net CSDI (Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation) 2021 :cite:`tashiro2021csdi`
Imputation Neural Net GP-VAE (Gaussian Process Variational Autoencoder) 2020 :cite:`fortuin2020GPVAEDeep`
8 changes: 8 additions & 0 deletions docs/references.bib
@@ -458,3 +458,11 @@ @article{rubin1976missing
volume = {63},
year = {1976}
}

@inproceedings{wu2023timesnet,
title={{TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis}},
author={Haixu Wu and Tengge Hu and Yong Liu and Hang Zhou and Jianmin Wang and Mingsheng Long},
booktitle={The Eleventh International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=ju_Uqw384Oq}
}
65 changes: 32 additions & 33 deletions pypots/classification/brits/modules/core.py
@@ -89,37 +89,36 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
ret_b = self._reverse(self.rits_b(inputs, "backward"))

classification_pred = (ret_f["prediction"] + ret_b["prediction"]) / 2
if not training:
# if not in training mode, return the classification result only
return {"classification_pred": classification_pred}

ret_f["classification_loss"] = F.nll_loss(
torch.log(ret_f["prediction"]), inputs["label"]
)
ret_b["classification_loss"] = F.nll_loss(
torch.log(ret_b["prediction"]), inputs["label"]
)
consistency_loss = self._get_consistency_loss(
ret_f["imputed_data"], ret_b["imputed_data"]
)
classification_loss = (
ret_f["classification_loss"] + ret_b["classification_loss"]
) / 2
reconstruction_loss = (
ret_f["reconstruction_loss"] + ret_b["reconstruction_loss"]
) / 2

loss = (
consistency_loss
+ reconstruction_loss * self.reconstruction_weight
+ classification_loss * self.classification_weight
)

results = {
"classification_pred": classification_pred,
"consistency_loss": consistency_loss,
"classification_loss": classification_loss,
"reconstruction_loss": reconstruction_loss,
"loss": loss,
}
results = {"classification_pred": classification_pred}

# if in training mode, return results with losses
if training:
ret_f["classification_loss"] = F.nll_loss(
torch.log(ret_f["prediction"]), inputs["label"]
)
ret_b["classification_loss"] = F.nll_loss(
torch.log(ret_b["prediction"]), inputs["label"]
)
consistency_loss = self._get_consistency_loss(
ret_f["imputed_data"], ret_b["imputed_data"]
)
classification_loss = (
ret_f["classification_loss"] + ret_b["classification_loss"]
) / 2
reconstruction_loss = (
ret_f["reconstruction_loss"] + ret_b["reconstruction_loss"]
) / 2

results["consistency_loss"] = consistency_loss
results["classification_loss"] = classification_loss
results["reconstruction_loss"] = reconstruction_loss

# `loss` is always the item for backward propagating to update the model
loss = (
consistency_loss
+ reconstruction_loss * self.reconstruction_weight
+ classification_loss * self.classification_weight
)
results["loss"] = loss

return results
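The same control-flow refactor appears in the GRU-D, Raindrop, CRLI, and VaDER modules below: the forward pass first builds a results dict holding the prediction, then attaches losses only when `training` is true, so inference skips loss computation entirely. A minimal sketch of the pattern, assuming hypothetical `_predict` and `_compute_loss` helpers that stand in for each model's own logic:

```python
def forward(self, inputs: dict, training: bool = True) -> dict:
    prediction = self._predict(inputs)                 # hypothetical helper
    results = {"classification_pred": prediction}

    if training:
        # losses are only needed for backpropagation, so they are skipped at inference
        results["loss"] = self._compute_loss(prediction, inputs)  # hypothetical helper

    return results
```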
20 changes: 8 additions & 12 deletions pypots/classification/grud/modules/core.py
@@ -91,18 +91,14 @@ def forward(self, inputs: dict, training: bool = True) -> dict:

logits = self.classifier(hidden_state)
classification_pred = torch.softmax(logits, dim=1)
results = {"classification_pred": classification_pred}

if not training:
# if not in training mode, return the classification result only
return {"classification_pred": classification_pred}
# if in training mode, return results with losses
if training:
classification_loss = F.nll_loss(
torch.log(classification_pred), inputs["label"]
)
results["loss"] = classification_loss

torch.log(classification_pred)
classification_loss = F.nll_loss(
torch.log(classification_pred), inputs["label"]
)

results = {
"classification_pred": classification_pred,
"loss": classification_loss,
}
return results
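One note on the loss computed above: chaining `torch.softmax`, `torch.log`, and `F.nll_loss` is mathematically the same as calling `F.cross_entropy` on the raw logits, which is the numerically safer single call. A quick self-contained check:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)             # [batch, n_classes]
labels = torch.tensor([0, 2, 1, 2])

# what the module above computes: softmax -> log -> NLL
loss_chained = F.nll_loss(torch.log(torch.softmax(logits, dim=1)), labels)
# equivalent single call working directly on the logits
loss_ce = F.cross_entropy(logits, labels)

assert torch.allclose(loss_chained, loss_ce, atol=1e-6)
```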
19 changes: 7 additions & 12 deletions pypots/classification/raindrop/modules/core.py
@@ -262,18 +262,13 @@ def classify(self, inputs: dict) -> torch.Tensor:

def forward(self, inputs, training=True):
classification_pred = self.classify(inputs)
if not training:
# if not in training mode, return the classification result only
return {"classification_pred": classification_pred}
results = {"classification_pred": classification_pred}

classification_loss = F.nll_loss(
torch.log(classification_pred), inputs["label"]
)

results = {
"prediction": classification_pred,
"loss": classification_loss
# 'distance': distance,
}
# if in training mode, return results with losses
if training:
classification_loss = F.nll_loss(
torch.log(classification_pred), inputs["label"]
)
results["loss"] = classification_loss

return results
4 changes: 2 additions & 2 deletions pypots/clustering/crli/model.py
@@ -270,7 +270,7 @@ def _train_model(
with torch.no_grad():
for idx, data in enumerate(val_loader):
inputs = self._assemble_input_for_validating(data)
results = self.model.forward(inputs, return_loss=True)
results = self.model.forward(inputs, training=True)
epoch_val_loss_G_collector.append(
results["generation_loss"].sum().item()
)
@@ -424,7 +424,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
inputs = self.model.forward(inputs, return_loss=False)
inputs = self.model.forward(inputs, training=False)
clustering_latent_collector.append(inputs["fcn_latent"])
if return_latent_vars:
imputation_collector.append(inputs["imputation_latent"])
5 changes: 3 additions & 2 deletions pypots/clustering/crli/modules/core.py
@@ -58,7 +58,7 @@ def forward(
self,
inputs: dict,
training_object: str = "generator",
return_loss: bool = True,
training: bool = True,
) -> dict:
X = inputs["X"]
missing_mask = inputs["missing_mask"]
@@ -76,7 +76,7 @@
inputs["fcn_latent"] = fcn_latent

# return results directly, skip loss calculation to reduce inference time
if not return_loss:
if not training:
return inputs

if training_object == "discriminator":
@@ -106,4 +106,5 @@
l_kmeans = torch.trace(HTH) - torch.trace(FTHTHF) # k-means loss
loss_gene = l_G + l_pre + l_rec + l_kmeans * self.lambda_kmeans
losses["generation_loss"] = loss_gene

return losses
139 changes: 69 additions & 70 deletions pypots/clustering/vader/modules/core.py
@@ -173,18 +173,15 @@ def forward(
stddev_tilde,
) = self.get_results(X, missing_mask)

if not training and not pretrain:
results = {
"mu_tilde": mu_tilde,
"stddev_tilde": stddev_tilde,
"mu": mu_c,
"var": var_c,
"phi": phi_c,
"z": z,
"imputation_latent": X_reconstructed,
}
# if only run clustering, then no need to calculate loss
return results
results = {
"mu_tilde": mu_tilde,
"stddev_tilde": stddev_tilde,
"mu": mu_c,
"var": var_c,
"phi": phi_c,
"z": z,
"imputation_latent": X_reconstructed,
}

# calculate the reconstruction loss
unscaled_reconstruction_loss = cal_mse(X_reconstructed, X, missing_mask)
@@ -194,66 +191,68 @@
* self.d_input
/ missing_mask.sum()
)

if pretrain:
results = {"loss": reconstruction_loss, "z": z}
results["loss"] = reconstruction_loss
return results

# calculate the latent loss
var_tilde = torch.exp(stddev_tilde)
stddev_c = torch.log(var_c + self.eps)
log_2pi = torch.log(torch.tensor([2 * torch.pi], device=device))
log_phi_c = torch.log(phi_c + self.eps)

batch_size = z.shape[0]

ii, jj = torch.meshgrid(
torch.arange(self.n_clusters, dtype=torch.int64, device=device),
torch.arange(batch_size, dtype=torch.int64, device=device),
indexing="ij",
)
ii = ii.flatten()
jj = jj.flatten()

lsc_b = stddev_c.index_select(dim=0, index=ii)
mc_b = mu_c.index_select(dim=0, index=ii)
sc_b = var_c.index_select(dim=0, index=ii)
z_b = z.index_select(dim=0, index=jj)
log_pdf_z = -0.5 * (lsc_b + log_2pi + torch.square(z_b - mc_b) / sc_b)
log_pdf_z = log_pdf_z.reshape([batch_size, self.n_clusters, self.d_mu_stddev])

log_p = log_phi_c + log_pdf_z.sum(dim=2)
lse_p = log_p.logsumexp(dim=1, keepdim=True)
log_gamma_c = log_p - lse_p
gamma_c = torch.exp(log_gamma_c)

term1 = torch.log(var_c + self.eps)
st_b = var_tilde.index_select(dim=0, index=jj)
sc_b = var_c.index_select(dim=0, index=ii)
term2 = torch.reshape(
st_b / (sc_b + self.eps), [batch_size, self.n_clusters, self.d_mu_stddev]
)
mt_b = mu_tilde.index_select(dim=0, index=jj)
mc_b = mu_c.index_select(dim=0, index=ii)
term3 = torch.reshape(
torch.square(mt_b - mc_b) / (sc_b + self.eps),
[batch_size, self.n_clusters, self.d_mu_stddev],
)

latent_loss1 = 0.5 * torch.sum(
gamma_c * torch.sum(term1 + term2 + term3, dim=2), dim=1
)
latent_loss2 = -torch.sum(gamma_c * (log_phi_c - log_gamma_c), dim=1)
latent_loss3 = -0.5 * torch.sum(1 + stddev_tilde, dim=1)

latent_loss1 = latent_loss1.mean()
latent_loss2 = latent_loss2.mean()
latent_loss3 = latent_loss3.mean()
latent_loss = latent_loss1 + latent_loss2 + latent_loss3

results = {
"loss": reconstruction_loss + self.alpha * latent_loss,
"z": z,
"imputation_latent": X_reconstructed,
}
# if in training mode, return results with losses
if training:
# calculate the latent loss for model training
var_tilde = torch.exp(stddev_tilde)
stddev_c = torch.log(var_c + self.eps)
log_2pi = torch.log(torch.tensor([2 * torch.pi], device=device))
log_phi_c = torch.log(phi_c + self.eps)

batch_size = z.shape[0]

ii, jj = torch.meshgrid(
torch.arange(self.n_clusters, dtype=torch.int64, device=device),
torch.arange(batch_size, dtype=torch.int64, device=device),
indexing="ij",
)
ii = ii.flatten()
jj = jj.flatten()

lsc_b = stddev_c.index_select(dim=0, index=ii)
mc_b = mu_c.index_select(dim=0, index=ii)
sc_b = var_c.index_select(dim=0, index=ii)
z_b = z.index_select(dim=0, index=jj)
log_pdf_z = -0.5 * (lsc_b + log_2pi + torch.square(z_b - mc_b) / sc_b)
log_pdf_z = log_pdf_z.reshape(
[batch_size, self.n_clusters, self.d_mu_stddev]
)

log_p = log_phi_c + log_pdf_z.sum(dim=2)
lse_p = log_p.logsumexp(dim=1, keepdim=True)
log_gamma_c = log_p - lse_p
gamma_c = torch.exp(log_gamma_c)

term1 = torch.log(var_c + self.eps)
st_b = var_tilde.index_select(dim=0, index=jj)
sc_b = var_c.index_select(dim=0, index=ii)
term2 = torch.reshape(
st_b / (sc_b + self.eps),
[batch_size, self.n_clusters, self.d_mu_stddev],
)
mt_b = mu_tilde.index_select(dim=0, index=jj)
mc_b = mu_c.index_select(dim=0, index=ii)
term3 = torch.reshape(
torch.square(mt_b - mc_b) / (sc_b + self.eps),
[batch_size, self.n_clusters, self.d_mu_stddev],
)

latent_loss1 = 0.5 * torch.sum(
gamma_c * torch.sum(term1 + term2 + term3, dim=2), dim=1
)
latent_loss2 = -torch.sum(gamma_c * (log_phi_c - log_gamma_c), dim=1)
latent_loss3 = -0.5 * torch.sum(1 + stddev_tilde, dim=1)

latent_loss1 = latent_loss1.mean()
latent_loss2 = latent_loss2.mean()
latent_loss3 = latent_loss3.mean()
latent_loss = latent_loss1 + latent_loss2 + latent_loss3

results["loss"] = reconstruction_loss + self.alpha * latent_loss

return results
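For reference, the three `latent_loss` terms in the training branch above correspond, per sample (the code then averages over the batch), to the usual GMM-VAE latent loss used by VaDER, where `gamma_c` are the cluster responsibilities, `phi_c` the mixture weights, (`mu_c`, `var_c`) the cluster Gaussians, and (`mu_tilde`, `stddev_tilde`) the encoder output with `var_tilde = exp(stddev_tilde)`; this is a reading of the code above, not a statement from the commit:

$$
\begin{aligned}
\texttt{latent\_loss1} &= \tfrac{1}{2}\sum_{c}\gamma_{c}\sum_{j}\Big[\log\sigma_{c,j}^{2} + \frac{\tilde{\sigma}_{j}^{2}}{\sigma_{c,j}^{2}} + \frac{(\tilde{\mu}_{j}-\mu_{c,j})^{2}}{\sigma_{c,j}^{2}}\Big] \\
\texttt{latent\_loss2} &= -\sum_{c}\gamma_{c}\big(\log\phi_{c}-\log\gamma_{c}\big) \\
\texttt{latent\_loss3} &= -\tfrac{1}{2}\sum_{j}\big(1+\log\tilde{\sigma}_{j}^{2}\big) \\
\texttt{loss} &= \texttt{reconstruction\_loss} + \alpha\,\big(\texttt{latent\_loss1}+\texttt{latent\_loss2}+\texttt{latent\_loss3}\big)
\end{aligned}
$$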
4 changes: 3 additions & 1 deletion pypots/imputation/__init__.py
@@ -6,17 +6,19 @@
# License: BSD-3-Clause

from .brits import BRITS
from .csdi import CSDI
from .gpvae import GPVAE
from .locf import LOCF
from .mrnn import MRNN
from .saits import SAITS
from .timesnet import TimesNet
from .transformer import Transformer
from .usgan import USGAN
from .csdi import CSDI

__all__ = [
"SAITS",
"Transformer",
"TimesNet",
"BRITS",
"MRNN",
"LOCF",
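With the new export in place, TimesNet can be driven through the same high-level API as the other PyPOTS imputation models. The sketch below is hypothetical: the `fit`/`impute` calls follow the usual PyPOTS pattern, but the constructor arguments are assumptions about this release's signature and should be checked against the documentation:

```python
import numpy as np
from pypots.imputation import TimesNet  # export added in this commit

# toy data: [n_samples, n_steps, n_features] with artificial missingness
X = np.random.randn(128, 48, 37)
X[X < -1.5] = np.nan

# NOTE: the hyperparameter names below are assumptions, not a verified signature
model = TimesNet(
    n_steps=48,
    n_features=37,
    n_layers=2,
    top_k=3,
    d_model=64,
    d_ffn=128,
    n_kernels=5,
    epochs=10,
)
model.fit({"X": X})               # PyPOTS models train on a dict keyed by "X"
imputed = model.impute({"X": X})  # assumed to return the imputed array
```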
