diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml
index 7070b00c..08db4458 100644
--- a/.github/workflows/testing.yml
+++ b/.github/workflows/testing.yml
@@ -37,14 +37,14 @@ jobs:
run: |
coverage run --source=pypots -m pytest
- - name: Write the LCOV report
- run: |
- coverage lcov
-
- - name: Submit report
- uses: coverallsapp/github-action@master
- env:
- NODE_COVERALLS_DEBUG: 1
- with:
- github-token: ${{ secrets.ACCESS_TOKEN }}
- path-to-lcov: 'coverage.lcov'
\ No newline at end of file
+# - name: Write the LCOV report
+# run: |
+# coverage lcov
+#
+# - name: Submit report
+# uses: coverallsapp/github-action@master
+# env:
+# NODE_COVERALLS_DEBUG: 1
+# with:
+# github-token: ${{ secrets.ACCESS_TOKEN }}
+# path-to-lcov: 'coverage.lcov'
\ No newline at end of file
diff --git a/README.md b/README.md
index d62756a1..a5103502 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,7 @@ Install with the latest code on GitHub:
| Imputation | Neural Network | SAITS: Self-Attention-based Imputation for Time Series | 2022 | [^1] |
| Imputation | Neural Network | Transformer | 2017 | [^2] [^1] |
| Imputation, Classification | Neural Network | BRITS: Bidirectional Recurrent Imputation for Time Series | 2018 | [^3] |
+| Classification | Neural Network | GRU-D | 2018 | [^4] |
---
‼️ PyPOTS is currently under development. If you like it and look forward to its growth, please give PyPOTS a star and watch it to stay posted on its progress and to let me know that its development is meaningful. If you have any feedback, or want to contribute ideas/suggestions or share time-series related algorithms/papers, please join the PyPOTS community, or [drop me an email](mailto:wenjay.du@gmail.com).
@@ -45,4 +46,5 @@ Thank you all for your attention! 😃
[^1]: Du, W., Cote, D., & Liu, Y. (2022). SAITS: Self-Attention-based Imputation for Time Series. ArXiv, abs/2202.08516.
[^2]: Vaswani, A., Shazeer, N.M., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., & Polosukhin, I. (2017). Attention is All you Need. NeurIPS 2017.
-[^3]: Cao, W., Wang, D., Li, J., Zhou, H., Li, L., & Li, Y. (2018). BRITS: Bidirectional Recurrent Imputation for Time Series. NeurIPS 2018.
\ No newline at end of file
+[^3]: Cao, W., Wang, D., Li, J., Zhou, H., Li, L., & Li, Y. (2018). BRITS: Bidirectional Recurrent Imputation for Time Series. NeurIPS 2018.
+[^4]: Che, Z., Purushotham, S., Cho, K., Sontag, D.A., & Liu, Y. (2018). Recurrent Neural Networks for Multivariate Time Series with Missing Values. Scientific Reports, 8.
\ No newline at end of file
diff --git a/pypots/classification/__init__.py b/pypots/classification/__init__.py
index 93822711..6b9a6225 100644
--- a/pypots/classification/__init__.py
+++ b/pypots/classification/__init__.py
@@ -5,6 +5,5 @@
# Created by Wenjie Du
# License: GPL-v3
-from .brits import (
- BRITS
-)
+from .brits import BRITS
+from .grud import GRUD
diff --git a/pypots/classification/grud.py b/pypots/classification/grud.py
new file mode 100644
index 00000000..45c26e56
--- /dev/null
+++ b/pypots/classification/grud.py
@@ -0,0 +1,207 @@
+"""
+PyTorch GRU-D model.
+"""
+
+# Created by Wenjie Du
+# License: GPL-v3
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+
+from pypots.classification.base import BaseNNClassifier
+from pypots.data import DatasetForGRUD
+from pypots.imputation.brits import TemporalDecay
+
+
+class _GRUD(nn.Module):
+ def __init__(self, seq_len, n_features, rnn_hidden_size, n_classes, device=None):
+ super(_GRUD, self).__init__()
+ self.seq_len = seq_len
+ self.n_features = n_features
+ self.rnn_hidden_size = rnn_hidden_size
+ self.n_classes = n_classes
+ self.device = device
+
+ # create models
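+        # The GRU cell consumes the concatenation [x_replaced, hidden_state, m],
+        # hence its input size is n_features * 2 + rnn_hidden_size.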
+ self.rnn_cell = nn.GRUCell(self.n_features * 2 + self.rnn_hidden_size, self.rnn_hidden_size)
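+        # temp_decay_h decays the hidden state over time gaps; temp_decay_x (diagonal)
+        # decays each feature's last observation toward its empirical mean.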
+ self.temp_decay_h = TemporalDecay(input_size=self.n_features, output_size=self.rnn_hidden_size, diag=False)
+ self.temp_decay_x = TemporalDecay(input_size=self.n_features, output_size=self.n_features, diag=True)
+ self.classifier = nn.Linear(self.rnn_hidden_size, self.n_classes)
+
+ def classify(self, inputs):
+ values = inputs['X']
+ masks = inputs['missing_mask']
+ deltas = inputs['deltas']
+ empirical_mean = inputs['empirical_mean']
+ X_filledLOCF = inputs['X_filledLOCF']
+
+ hidden_state = torch.zeros((values.size()[0], self.rnn_hidden_size), device=self.device)
+
+ for t in range(self.seq_len):
+ # for data, [batch, time, features]
+ x = values[:, t, :] # values
+ m = masks[:, t, :] # mask
+ d = deltas[:, t, :] # delta, time gap
+ x_filledLOCF = X_filledLOCF[:, t, :]
+
+ gamma_h = self.temp_decay_h(d)
+ gamma_x = self.temp_decay_x(d)
+ hidden_state = hidden_state * gamma_h
+
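+            # GRU-D input decay: a missing value is replaced by a combination of its
+            # last observation and the feature's empirical mean, weighted by the
+            # learned decay factor gamma_x; observed values are kept via the mask.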
+ x_h = gamma_x * x_filledLOCF + (1 - gamma_x) * empirical_mean
+ x_replaced = m * x + (1 - m) * x_h
+            grud_cell_input = torch.cat([x_replaced, hidden_state, m], dim=1)
+            hidden_state = self.rnn_cell(grud_cell_input, hidden_state)
+
+ logits = self.classifier(hidden_state)
+ prediction = torch.softmax(logits, dim=1)
+ return prediction
+
+ def forward(self, inputs):
+ """ Forward processing of GRU-D.
+
+ Parameters
+ ----------
+ inputs : dict,
+ The input data.
+
+ Returns
+ -------
+        dict,
+            A dictionary including all results.
+ """
+ prediction = self.classify(inputs)
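+        # F.nll_loss expects log-probabilities, hence the explicit log on the softmax output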
+ classification_loss = F.nll_loss(torch.log(prediction), inputs['label'])
+ results = {
+ 'prediction': prediction,
+ 'loss': classification_loss
+ }
+ return results
+
+
+class GRUD(BaseNNClassifier):
+ """ GRU-D implementation of BaseClassifier.
+
+ Attributes
+ ----------
+ model : object,
+ The underlying GRU-D model.
+ optimizer : object,
+ The optimizer for model training.
+ data_loader : object,
+ The data loader for dataset loading.
+
+ Parameters
+ ----------
+ rnn_hidden_size : int,
+ The size of the RNN hidden state.
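+    n_classes : int,
+        The number of classes for the classification task.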
+    learning_rate : float in (0,1),
+        The learning rate parameter for the optimizer.
+ weight_decay : float in (0,1),
+ The weight decay parameter for the optimizer.
+ epochs : int,
+ The number of training epochs.
+ patience : int,
+ The number of epochs with loss non-decreasing before early stopping the training.
+ batch_size : int,
+ The batch size of the training input.
+    device : optional,
+        The device on which the model will run.
+ """
+
+ def __init__(self,
+ rnn_hidden_size,
+ n_classes,
+ learning_rate=1e-3,
+ epochs=100,
+ patience=10,
+ batch_size=32,
+ weight_decay=1e-5,
+ device=None):
+ super(GRUD, self).__init__(n_classes, learning_rate, epochs, patience, batch_size, weight_decay, device)
+
+ self.rnn_hidden_size = rnn_hidden_size
+
+ def fit(self, train_X, train_y, val_X=None, val_y=None):
+ """ Fit the model on the given training data.
+
+ Parameters
+ ----------
+ train_X : array, shape [n_samples, sequence length (time steps), n_features],
+ Time-series vectors.
+ train_y : array,
+ Classification labels.
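+        val_X : array, optional, shape [n_samples, sequence length (time steps), n_features],
+            Time-series vectors for validation, default None.
+        val_y : array, optional,
+            Classification labels for validation, default None.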
+
+ Returns
+ -------
+ self : object,
+ Trained model.
+ """
+        assert len(train_X.shape) == 3, f'train_X should have 3 dimensions [n_samples, seq_len, n_features], ' \
+                                        f'while train_X.shape={train_X.shape}'
+        if val_X is not None:
+            assert len(val_X.shape) == 3, f'val_X should have 3 dimensions [n_samples, seq_len, n_features], ' \
+                                          f'while val_X.shape={val_X.shape}'
+
+ _, seq_len, n_features = train_X.shape
+ self.model = _GRUD(seq_len, n_features, self.rnn_hidden_size, self.n_classes, self.device)
+ self.model = self.model.to(self.device)
+ training_set = DatasetForGRUD(train_X, train_y)
+ training_loader = DataLoader(training_set, batch_size=self.batch_size, shuffle=True)
+
+ if val_X is None:
+ self._train_model(training_loader)
+ else:
+ val_set = DatasetForGRUD(val_X, val_y)
+ val_loader = DataLoader(val_set, batch_size=self.batch_size, shuffle=False)
+ self._train_model(training_loader, val_loader)
+
+ self.model.load_state_dict(self.best_model_dict)
+ self.model.eval() # set the model as eval status to freeze it.
+ return self
+
+ def input_data_processing(self, data):
+ # fetch data
+ indices, X, X_filledLOCF, missing_mask, deltas, empirical_mean, label = \
+ map(lambda x: x.to(self.device), data)
+ # assemble input data
+ inputs = {
+ 'indices': indices,
+ 'X': X,
+ 'X_filledLOCF': X_filledLOCF,
+ 'missing_mask': missing_mask,
+ 'deltas': deltas,
+ 'empirical_mean': empirical_mean,
+ 'label': label,
+ }
+ return inputs
+
+ def classify(self, X):
+ self.model.eval() # set the model as eval status to freeze it.
+ test_set = DatasetForGRUD(X)
+ test_loader = DataLoader(test_set, batch_size=self.batch_size, shuffle=False)
+ prediction_collector = []
+
+ with torch.no_grad():
+ for idx, data in enumerate(test_loader):
+                # input_data_processing() can't be reused here because the test data carry no labels
+ indices, X, X_filledLOCF, missing_mask, deltas, empirical_mean = \
+ map(lambda x: x.to(self.device), data)
+ # assemble input data
+ inputs = {
+ 'indices': indices,
+ 'X': X,
+ 'X_filledLOCF': X_filledLOCF,
+ 'missing_mask': missing_mask,
+ 'deltas': deltas,
+ 'empirical_mean': empirical_mean,
+ }
+
+ prediction = self.model.classify(inputs)
+ prediction_collector.append(prediction)
+
+ predictions = torch.cat(prediction_collector)
+ return predictions.cpu().detach().numpy()
diff --git a/pypots/data/DatasetForBRITS.py b/pypots/data/DatasetForBRITS.py
index e44ecbe5..8376a88a 100644
--- a/pypots/data/DatasetForBRITS.py
+++ b/pypots/data/DatasetForBRITS.py
@@ -11,6 +11,38 @@
from pypots.data.base import BaseDataset
+def parse_delta(missing_mask):
+ """ Generate time-gap (delta) matrix from missing masks.
+
+ Parameters
+ ----------
+    missing_mask : array, shape of [n_samples, seq_len, n_features]
+ Binary masks indicate missing values.
+
+ Returns
+ -------
+    delta : array,
+        Delta matrix indicating the time gaps of missing values.
+        For its mathematical definition, please refer to :cite:`che2018MissingData`.
+ """
+
+    assert len(missing_mask.shape) == 3, f'missing_mask should have 3 dimensions, ' \
+ f'shape like [n_samples, seq_len, n_features], ' \
+ f'while the input is {missing_mask.shape}'
+ n_samples, seq_len, n_features = missing_mask.shape
+ delta_collector = []
+ for m_mask in missing_mask:
+ delta = []
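+        # the time gap accumulates across consecutive missing steps and resets to one
+        # at observed steps, following the delta definition in the GRU-D paper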
+ for step in range(seq_len):
+ if step == 0:
+ delta.append(np.zeros(n_features))
+ else:
+ delta.append(np.ones(n_features) + (1 - m_mask[step]) * delta[-1])
+ delta = np.asarray(delta)
+ delta_collector.append(delta)
+ return np.asarray(delta_collector)
+
+
class DatasetForBRITS(BaseDataset):
""" Dataset class for BRITS.
@@ -29,10 +61,10 @@ def __init__(self, X, y=None):
# Training will take too much time if we put delta calculation in __getitem__().
forward_missing_mask = (~np.isnan(X)).astype(np.float32)
forward_X = np.nan_to_num(X)
- forward_delta = self.parse_delta(forward_missing_mask)
+ forward_delta = parse_delta(forward_missing_mask)
backward_X = np.flip(forward_X, axis=1).copy()
backward_missing_mask = np.flip(forward_missing_mask, axis=1).copy()
- backward_delta = self.parse_delta(backward_missing_mask)
+ backward_delta = parse_delta(backward_missing_mask)
self.data = {
'forward': {
@@ -47,42 +79,6 @@ def __init__(self, X, y=None):
},
}
- @staticmethod
- def parse_delta(missing_mask):
- """ Generate time-gap (delta) matrix from missing masks.
-
- Parameters
- ----------
- missing_mask : array, shape of [seq_len, n_features]
- Binary masks indicate missing values.
-
- Returns
- -------
- delta, array,
- Delta matrix indicates time gaps of missing values.
- Its math definition please refer to :cite:`che2018MissingData`.
- """
-
- assert len(missing_mask.shape) == 3, f'missing_mask should has 3 dimensions, ' \
- f'shape like [n_samples, seq_len, n_features], ' \
- f'while the input is {missing_mask.shape}'
- n_samples, seq_len, n_features = missing_mask.shape
- delta_collector = []
- for m_mask in missing_mask:
- delta = []
- for step in range(seq_len):
- if step == 0:
- delta.append(np.zeros(n_features))
- else:
- delta.append(np.ones(n_features) + (1 - m_mask[step]) * delta[-1])
- delta = np.asarray(delta)
- delta_collector.append(delta)
- return np.asarray(delta_collector)
-
- # TODO: preprocess the dataset and cache it, mainly for saving the time of calculating deltas
- def preprocess_and_cache(self):
- pass
-
def __getitem__(self, idx):
""" Fetch data according to index.
diff --git a/pypots/data/DatasetForGRUD.py b/pypots/data/DatasetForGRUD.py
new file mode 100644
index 00000000..f2da2e79
--- /dev/null
+++ b/pypots/data/DatasetForGRUD.py
@@ -0,0 +1,75 @@
+"""
+Dataset class for model GRUD.
+"""
+
+# Created by Wenjie Du
+# License: GPL-v3
+
+import numpy as np
+import torch
+
+from pypots.data.DatasetForBRITS import parse_delta
+from pypots.data.base import BaseDataset
+from pypots.imputation import LOCF
+
+
+class DatasetForGRUD(BaseDataset):
+ """ Dataset class for model GRUD.
+
+ Parameters
+ ----------
+ X : array-like, shape of [n_samples, seq_len, n_features]
+ Time-series feature vector.
+ y : array-like, shape of [n_samples], optional, default=None,
+        Classification labels of the corresponding time-series samples.
+ """
+
+ def __init__(self, X, y=None):
+ super(DatasetForGRUD, self).__init__(X, y)
+
+ self.locf = LOCF()
+ self.missing_mask = (~np.isnan(X)).astype(np.float32)
+ self.X = np.nan_to_num(X)
+ self.deltas = parse_delta(self.missing_mask)
+ self.X_filledLOCF = self.locf.impute(X)
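+        # feature-wise mean over all observed values (across samples and time steps),
+        # used by GRU-D as the target toward which missing inputs decay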
+ self.empirical_mean = \
+ np.sum(self.missing_mask * self.X, axis=(0, 1)) / np.sum(self.missing_mask, axis=(0, 1))
+
+ def __getitem__(self, idx):
+ """ Fetch data according to index.
+
+ Parameters
+ ----------
+ idx : int,
+ The index to fetch the specified sample.
+
+ Returns
+ -------
+        list,
+            A list containing
+            index : int tensor,
+                The index of the sample.
+            X : tensor,
+                The feature vector for model input.
+            X_filledLOCF : tensor,
+                The feature vector filled with the last observations.
+            missing_mask : tensor,
+                The mask indicating all missing values in X.
+            delta : tensor,
+                The delta matrix containing the time gaps of missing values.
+            empirical_mean : tensor,
+                Mean values of features.
+            label (optional) : tensor,
+                The classification label, appended only when ``y`` is given.
+ """
+ sample = [
+ torch.tensor(idx),
+ self.X[idx].astype('float32'),
+ self.X_filledLOCF[idx].astype('float32'),
+ self.missing_mask[idx].astype('float32'),
+ self.deltas[idx].astype('float32'),
+ self.empirical_mean.astype('float32'),
+ ]
+
+ if self.y is not None:
+ sample.append(torch.tensor(self.y[idx], dtype=torch.long))
+
+ return sample
diff --git a/pypots/data/__init__.py b/pypots/data/__init__.py
index 2eca3396..32a310c7 100644
--- a/pypots/data/__init__.py
+++ b/pypots/data/__init__.py
@@ -6,6 +6,7 @@
# License: GPL-v3
from .DatasetForBRITS import DatasetForBRITS
+from .DatasetForGRUD import DatasetForGRUD
from .DatasetForMIT import DatasetForMIT
from .generating import generate_random_walk, generate_random_walk_for_classification
from .integration import (
diff --git a/pypots/tests/test_classification.py b/pypots/tests/test_classification.py
index a484a05b..9ca172df 100644
--- a/pypots/tests/test_classification.py
+++ b/pypots/tests/test_classification.py
@@ -9,7 +9,7 @@
import numpy as np
-from pypots.classification import BRITS
+from pypots.classification import BRITS, GRUD
from pypots.data import generate_random_walk_for_classification
@@ -40,5 +40,32 @@ def test_impute(self):
predictions = self.brits.classify(self.X)
+class TestGRUD(unittest.TestCase):
+ def setUp(self) -> None:
+ # generate time-series classification data
+ X, y = generate_random_walk_for_classification(n_classes=3, n_samples_each_class=10)
+ X[X < 0] = np.nan # create missing values
+ self.X = X
+ self.y = y
+ self.grud = GRUD(256, n_classes=3, epochs=1)
+ self.grud.fit(self.X, self.y)
+
+ def test_parameters(self):
+ assert (hasattr(self.grud, 'model')
+ and self.grud.model is not None)
+
+ assert (hasattr(self.grud, 'optimizer')
+ and self.grud.optimizer is not None)
+
+ assert hasattr(self.grud, 'best_loss')
+ self.assertNotEqual(self.grud.best_loss, float('inf'))
+
+ assert (hasattr(self.grud, 'best_model_dict')
+ and self.grud.best_model_dict is not None)
+
+    def test_classify(self):
+ predictions = self.grud.classify(self.X)
+
+
if __name__ == '__main__':
unittest.main()