feat: add classification model GRU-D;
Showing 8 changed files with 361 additions and 54 deletions.
README.md
@@ -37,6 +37,7 @@ Install with the latest code on GitHub:
 | Imputation | Neural Network | SAITS: Self-Attention-based Imputation for Time Series | 2022 | [^1] |
 | Imputation | Neural Network | Transformer | 2017 | [^2] [^1] |
 | Imputation,<br>Classification | Neural Network | BRITS: Bidirectional Recurrent Imputation for Time Series | 2018 | [^3] |
+| Classification | Neural Network | GRU-D | 2018 | [^4] |

 ---
 ‼️ PyPOTS is currently under development. If you like it and look forward to its growth, <ins>please give PyPOTS a star and watch it to keep you posted on its progress and to let me know that its development is meaningful</ins>. If you have any feedback, or want to contribute ideas/suggestions or share time-series related algorithms/papers, please join PyPOTS community and <a alt='GitHub Discussions' href='https://github.com/WenjieDu/PyPOTS/discussions'><img align='center' src='https://img.shields.io/badge/Chat-in_Discussions-green?logo=github&color=60A98D'></a>, or [drop me an email](mailto:[email protected]).
@@ -45,4 +46,5 @@ Thank you all for your attention! 😃

 [^1]: Du, W., Cote, D., & Liu, Y. (2022). SAITS: Self-Attention-based Imputation for Time Series. ArXiv, abs/2202.08516.
 [^2]: Vaswani, A., Shazeer, N.M., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., & Polosukhin, I. (2017). Attention is All you Need. NeurIPS 2017.
 [^3]: Cao, W., Wang, D., Li, J., Zhou, H., Li, L., & Li, Y. (2018). BRITS: Bidirectional Recurrent Imputation for Time Series. NeurIPS 2018.
+[^4]: Che, Z., Purushotham, S., Cho, K., Sontag, D.A., & Liu, Y. (2018). Recurrent Neural Networks for Multivariate Time Series with Missing Values. Scientific Reports, 8.
pypots/classification/__init__.py
@@ -5,6 +5,5 @@
 # Created by Wenjie Du <[email protected]>
 # License: GPL-v3

-from .brits import (
-    BRITS
-)
+from .brits import BRITS
+from .grud import GRUD
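
With this re-export in place, the new classifier can be imported straight from the subpackage. A minimal check, assuming the package layout shown in this commit:

from pypots.classification import BRITS, GRUD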
pypots/classification/grud.py (new file)

@@ -0,0 +1,207 @@
""" | ||
PyTorch GRU-D model. | ||
""" | ||
|
||
# Created by Wenjie Du <[email protected]> | ||
# License: GLP-v3 | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.utils.data import DataLoader | ||
|
||
from pypots.classification.base import BaseNNClassifier | ||
from pypots.data import DatasetForGRUD | ||
from pypots.imputation.brits import TemporalDecay | ||
|
||
|
||
class _GRUD(nn.Module): | ||
def __init__(self, seq_len, n_features, rnn_hidden_size, n_classes, device=None): | ||
super(_GRUD, self).__init__() | ||
self.seq_len = seq_len | ||
self.n_features = n_features | ||
self.rnn_hidden_size = rnn_hidden_size | ||
self.n_classes = n_classes | ||
self.device = device | ||
|
||
# create models | ||
self.rnn_cell = nn.GRUCell(self.n_features * 2 + self.rnn_hidden_size, self.rnn_hidden_size) | ||
self.temp_decay_h = TemporalDecay(input_size=self.n_features, output_size=self.rnn_hidden_size, diag=False) | ||
self.temp_decay_x = TemporalDecay(input_size=self.n_features, output_size=self.n_features, diag=True) | ||
self.classifier = nn.Linear(self.rnn_hidden_size, self.n_classes) | ||
|
||
def classify(self, inputs): | ||
values = inputs['X'] | ||
masks = inputs['missing_mask'] | ||
deltas = inputs['deltas'] | ||
empirical_mean = inputs['empirical_mean'] | ||
X_filledLOCF = inputs['X_filledLOCF'] | ||
|
||
hidden_state = torch.zeros((values.size()[0], self.rnn_hidden_size), device=self.device) | ||
|
||
for t in range(self.seq_len): | ||
# for data, [batch, time, features] | ||
x = values[:, t, :] # values | ||
m = masks[:, t, :] # mask | ||
d = deltas[:, t, :] # delta, time gap | ||
x_filledLOCF = X_filledLOCF[:, t, :] | ||
|
||
gamma_h = self.temp_decay_h(d) | ||
gamma_x = self.temp_decay_x(d) | ||
hidden_state = hidden_state * gamma_h | ||
|
||
x_h = gamma_x * x_filledLOCF + (1 - gamma_x) * empirical_mean | ||
x_replaced = m * x + (1 - m) * x_h | ||
inputs = torch.cat([x_replaced, hidden_state, m], dim=1) | ||
hidden_state = self.rnn_cell(inputs, hidden_state) | ||
|
||
logits = self.classifier(hidden_state) | ||
prediction = torch.softmax(logits, dim=1) | ||
# print(f'logits: {logits}, logits.shape: {logits.shape}') | ||
# print(f'prediction: {prediction}') | ||
return prediction | ||
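    # The loop above follows the GRU-D recurrence of Che et al. (2018):
    # TemporalDecay produces gamma = exp(-max(0, W @ delta + b)) from the
    # time gaps; each missing value is imputed with a decayed mix of its
    # last observation and its empirical mean,
    #     x_hat = m * x + (1 - m) * (gamma_x * x_LOCF + (1 - gamma_x) * x_mean),
    # and the hidden state is decayed by gamma_h before every GRU step.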

    def forward(self, inputs):
        """ Forward processing of GRU-D.

        Parameters
        ----------
        inputs : dict,
            The input data.

        Returns
        -------
        dict,
            A dictionary containing all results.
        """
        prediction = self.classify(inputs)
        # NLL over the log-probabilities, i.e. cross-entropy on the predicted distribution
        classification_loss = F.nll_loss(torch.log(prediction), inputs['label'])
        results = {
            'prediction': prediction,
            'loss': classification_loss
        }
        return results


class GRUD(BaseNNClassifier):
    """ GRU-D implementation of BaseNNClassifier.

    Attributes
    ----------
    model : object,
        The underlying GRU-D model.
    optimizer : object,
        The optimizer for model training.
    data_loader : object,
        The data loader for dataset loading.

    Parameters
    ----------
    rnn_hidden_size : int,
        The size of the RNN hidden state.
    n_classes : int,
        The number of classes to predict.
    learning_rate : float in (0,1),
        The learning rate parameter for the optimizer.
    weight_decay : float in (0,1),
        The weight decay parameter for the optimizer.
    epochs : int,
        The number of training epochs.
    patience : int,
        The number of epochs with loss non-decreasing before early stopping the training.
    batch_size : int,
        The batch size of the training input.
    device :
        The device the model runs on.
    """

    def __init__(self,
                 rnn_hidden_size,
                 n_classes,
                 learning_rate=1e-3,
                 epochs=100,
                 patience=10,
                 batch_size=32,
                 weight_decay=1e-5,
                 device=None):
        super(GRUD, self).__init__(n_classes, learning_rate, epochs, patience, batch_size, weight_decay, device)

        self.rnn_hidden_size = rnn_hidden_size

    def fit(self, train_X, train_y, val_X=None, val_y=None):
        """ Fit the model on the given training data.

        Parameters
        ----------
        train_X : array, shape [n_samples, sequence length (time steps), n_features],
            Time-series vectors.
        train_y : array,
            Classification labels.

        Returns
        -------
        self : object,
            Trained model.
        """
        assert len(train_X.shape) == 3, f'train_X should have 3 dimensions [n_samples, seq_len, n_features], ' \
                                        f'while train_X.shape={train_X.shape}'
        if val_X is not None:
            assert len(val_X.shape) == 3, f'val_X should have 3 dimensions [n_samples, seq_len, n_features], ' \
                                          f'while val_X.shape={val_X.shape}'

        _, seq_len, n_features = train_X.shape
        self.model = _GRUD(seq_len, n_features, self.rnn_hidden_size, self.n_classes, self.device)
        self.model = self.model.to(self.device)
        training_set = DatasetForGRUD(train_X, train_y)
        training_loader = DataLoader(training_set, batch_size=self.batch_size, shuffle=True)

        if val_X is None:
            self._train_model(training_loader)
        else:
            val_set = DatasetForGRUD(val_X, val_y)
            val_loader = DataLoader(val_set, batch_size=self.batch_size, shuffle=False)
            self._train_model(training_loader, val_loader)

        self.model.load_state_dict(self.best_model_dict)
        self.model.eval()  # set the model as eval status to freeze it.
        return self

    def input_data_processing(self, data):
        # fetch data
        indices, X, X_filledLOCF, missing_mask, deltas, empirical_mean, label = \
            map(lambda x: x.to(self.device), data)
        # assemble input data
        inputs = {
            'indices': indices,
            'X': X,
            'X_filledLOCF': X_filledLOCF,
            'missing_mask': missing_mask,
            'deltas': deltas,
            'empirical_mean': empirical_mean,
            'label': label,
        }
        return inputs

    def classify(self, X):
        self.model.eval()  # set the model as eval status to freeze it.
        test_set = DatasetForGRUD(X)
        test_loader = DataLoader(test_set, batch_size=self.batch_size, shuffle=False)
        prediction_collector = []

        with torch.no_grad():
            for idx, data in enumerate(test_loader):
                # cannot reuse input_data_processing() here because test data carries no labels
                indices, X, X_filledLOCF, missing_mask, deltas, empirical_mean = \
                    map(lambda x: x.to(self.device), data)
                # assemble input data
                inputs = {
                    'indices': indices,
                    'X': X,
                    'X_filledLOCF': X_filledLOCF,
                    'missing_mask': missing_mask,
                    'deltas': deltas,
                    'empirical_mean': empirical_mean,
                }

                prediction = self.model.classify(inputs)
                prediction_collector.append(prediction)

        predictions = torch.cat(prediction_collector)
        return predictions.cpu().detach().numpy()
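
For context, a minimal usage sketch of the new classifier, based only on the API added in this commit. The synthetic data is illustrative, and it assumes (not shown in this diff) that DatasetForGRUD derives missing masks, LOCF fills, time gaps, and empirical means from NaN entries in the input:

import numpy as np
from pypots.classification import GRUD

# toy data: 100 samples, 24 time steps, 5 features, ~20% of values missing at random
train_X = np.random.randn(100, 24, 5)
train_X[np.random.rand(*train_X.shape) < 0.2] = np.nan  # assumption: missingness is NaN-encoded
train_y = np.random.randint(0, 2, size=100)  # binary class labels

model = GRUD(rnn_hidden_size=32, n_classes=2, epochs=10)
model.fit(train_X, train_y)
proba = model.classify(train_X)  # array of shape [n_samples, n_classes] with class probabilities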