Skip to content

Commit

Permalink
refactor: normalize some code;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Oct 25, 2024
1 parent bf045b4 commit 0979d44
Show file tree
Hide file tree
Showing 12 changed files with 128 additions and 130 deletions.
22 changes: 0 additions & 22 deletions pypots/classification/csai/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,6 @@

from ...nn.modules.csai import BackboneBCSAI

# class DiceBCELoss(nn.Module):
# def __init__(self, weight=None, size_average=True):
# super(DiceBCELoss, self).__init__()
# self.bcelogits = nn.BCEWithLogitsLoss()

# def forward(self, y_score, y_out, targets, smooth=1):

# #comment out if your model contains a sigmoid or equivalent activation layer
# # inputs = F.sigmoid(inputs)

# #flatten label and prediction tensors
# BCE = self.bcelogits(y_out, targets)

# y_score = y_score.view(-1)
# targets = targets.view(-1)
# intersection = (y_score * targets).sum()
# dice_loss = 1 - (2.*intersection + smooth)/(y_score.sum() + targets.sum() + smooth)

# Dice_BCE = BCE + dice_loss

# return BCE, Dice_BCE


class _BCSAI(nn.Module):
def __init__(
Expand Down
57 changes: 36 additions & 21 deletions pypots/classification/csai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# License: BSD-3-Clause

from typing import Optional, Union
import numpy as np

import torch
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -60,31 +60,43 @@ class CSAI(BaseNNClassifier):
The batch size for training and evaluating the model.
epochs :
The number of epochs for training the model.
dropout :
The dropout rate for the model to prevent overfitting. Default is 0.5.
The number of epochs for training the model.
patience :
The patience for the early-stopping mechanism. Given a positive integer, the training process will be stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping.
The patience for the early-stopping mechanism. Given a positive integer, the training process will be
stopped when the model does not perform better after that number of epochs.
Leaving it default as None will disable the early-stopping.
optimizer :
The optimizer for model training. If not given, will use a default Adam optimizer.
The optimizer for model training.
If not given, will use a default Adam optimizer.
num_workers :
The number of subprocesses to use for data loading. 0 means data loading will be in the main process, i.e. there won't be subprocesses.
The number of subprocesses to use for data loading.
`0` means data loading will be in the main process, i.e. there won't be subprocesses.
device :
The device for the model to run on. It can be a string, a :class:torch.device object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')], the model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the
model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices).
Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
saving_path :
The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during training into a tensorboard file). Will not save if not given.
The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
training into a tensorboard file). Will not save if not given.
model_saving_strategy :
The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. No model will be saved when it is set as None. The "best" strategy will only automatically save the best model after the training finished. The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. The "all" strategy will save every model after each epoch training.
The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
No model will be saved when it is set as None.
The "best" strategy will only automatically save the best model after the training finished.
The "better" strategy will automatically save the model during training whenever the model performs
better than in previous epochs.
The "all" strategy will save every model after each epoch training.
verbose :
Whether to print out the training logs during the training process.
Whether to print out the training logs during the training process.
"""

Expand Down Expand Up @@ -136,6 +148,9 @@ def __init__(
self.compute_intervals = compute_intervals
self.dropout = dropout
self.intervals = None
self.replacement_probabilities = None
self.mean_set = None
self.std_set = None

# Initialise empty model
self.model = _BCSAI(
Expand Down Expand Up @@ -230,7 +245,7 @@ def fit(
file_type: str = "hdf5",
) -> None:
# Create dataset
self.training_set = DatasetForCSAI(
training_set = DatasetForCSAI(
data=train_set,
file_type=file_type,
return_y=True,
Expand All @@ -239,13 +254,13 @@ def fit(
compute_intervals=self.compute_intervals,
)

self.intervals = self.training_set.intervals
self.replacement_probabilities = self.training_set.replacement_probabilities
self.mean_set = self.training_set.mean_set
self.std_set = self.training_set.std_set
self.intervals = training_set.intervals
self.replacement_probabilities = training_set.replacement_probabilities
self.mean_set = training_set.mean_set
self.std_set = training_set.std_set

train_loader = DataLoader(
self.training_set,
training_set,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
Expand Down Expand Up @@ -321,15 +336,15 @@ def predict(
num_workers=self.num_workers,
)

classificaion_results = []
classification_results = []

with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
classificaion_results.append(results["classification_pred"])
classification_results.append(results["classification_pred"])

classification = torch.cat(classificaion_results).cpu().detach().numpy()
classification = torch.cat(classification_results).cpu().detach().numpy()
result_dict = {
"classification": classification,
}
Expand Down
5 changes: 4 additions & 1 deletion pypots/imputation/csai/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ class _BCSAI(nn.Module):
Notes
-----
BCSAI is a bidirectional imputation model that uses forward and backward GRU cells to handle time-series data. It computes consistency and reconstruction losses to improve imputation accuracy. During training, the forward and backward reconstructions are combined, and losses are used to update the model. In evaluation mode, the model also outputs original data and indicating masks for further analysis.
CSAI is a bidirectional imputation model that uses forward and backward GRU cells to handle time-series data.
It computes consistency and reconstruction losses to improve imputation accuracy.
During training, the forward and backward reconstructions are combined, and losses are used to update the model.
In evaluation mode, the model also outputs original data and indicating masks for further analysis.
"""

Expand Down
57 changes: 41 additions & 16 deletions pypots/imputation/csai/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
# Created by Linglong Qian, Joseph Arul Raj <[email protected], [email protected]>
# License: BSD-3-Clause

import copy
from typing import Iterable
from ...data.dataset import BaseDataset
from typing import Union

import numpy as np
import torch
from typing import Union
import copy
from ...data.utils import parse_delta
from sklearn.preprocessing import StandardScaler

from ...data.dataset import BaseDataset
from ...data.utils import parse_delta


def normalize_csai(
data,
Expand All @@ -22,7 +24,8 @@ def normalize_csai(
compute_intervals: bool = False,
):
"""
Normalize the data based on the given mean and standard deviation, and optionally compute time intervals between observations.
Normalize the data based on the given mean and standard deviation,
and optionally compute time intervals between observations.
Parameters
----------
Expand All @@ -33,7 +36,8 @@ def normalize_csai(
The mean values for each variable, used for normalization. If empty, means will be computed from the data.
std : list of float, optional
The standard deviation values for each variable, used for normalization. If empty, std values will be computed from the data.
The standard deviation values for each variable, used for normalization.
If empty, std values will be computed from the data.
compute_intervals : bool, optional, default=False
Whether to compute the time intervals between observations for each variable.
Expand All @@ -47,10 +51,12 @@ def normalize_csai(
The mean values for each variable after normalization, either computed from the data or passed as input.
std_set : np.ndarray
The standard deviation values for each variable after normalization, either computed from the data or passed as input.
The standard deviation values for each variable after normalization,
either computed from the data or passed as input.
intervals_list : dict of int to float, optional
If `compute_intervals` is True, this will return the median time intervals between observations for each variable.
If `compute_intervals` is True, this will return the median time intervals between observations
for each variable.
"""

# Convert data to numpy array if it is a torch tensor
Expand Down Expand Up @@ -296,13 +302,23 @@ class DatasetForCSAI(BaseDataset):
Parameters
----------
data :
The dataset for model input, which can be either a dictionary or a path string to a data file. If it's a dictionary, `X` should be an array-like structure with shape [n_samples, sequence length (n_steps), n_features], containing the time-series data, and it can have missing values. Optionally, the dictionary can include `y`, an array-like structure with shape [n_samples], representing the labels of `X`. If `data` is a path string, it should point to a data file (e.g., h5 file) that contains key-value pairs like a dictionary, including keys for `X` and possibly `y`.
The dataset for model input, which can be either a dictionary or a path string to a data file.
If it's a dictionary, `X` should be an array-like structure
with shape [n_samples, sequence length (n_steps), n_features], containing the time-series data,
and it can have missing values. Optionally, the dictionary can include `y`,
an array-like structure with shape [n_samples], representing the labels of `X`.
If `data` is a path string, it should point to a data file (e.g., h5 file) that contains key-value pairs like
a dictionary, including keys for `X` and possibly `y`.
return_X_ori :
Whether to return the original time-series data (`X_ori`) when fetching data samples, useful for evaluation purposes.
Whether to return the original time-series data (`X_ori`) when fetching data samples,
useful for evaluation purposes.
return_y :
Whether to return classification labels in the `__getitem__()` method if they exist in the dataset. If `True`, labels will be included in the returned data samples, which is useful for training classification models. If `False`, the labels won't be returned, suitable for testing or validation stages.
Whether to return classification labels in the `__getitem__()` method if they exist in the dataset.
If `True`, labels will be included in the returned data samples,
which is useful for training classification models.
If `False`, the labels won't be returned, suitable for testing or validation stages.
file_type :
The type of the data file if `data` is a path string, such as "hdf5".
Expand All @@ -317,20 +333,29 @@ class DatasetForCSAI(BaseDataset):
Whether to compute time intervals between observations for handling irregular time-series data.
replacement_probabilities :
Optional precomputed probabilities for sampling missing values. If not provided, they will be calculated during the initialization of the dataset.
Optional precomputed probabilities for sampling missing values.
If not provided, they will be calculated during the initialization of the dataset.
normalise_mean :
A list of mean values for normalizing the input features. If not provided, they will be computed during initialization.
A list of mean values for normalizing the input features.
If not provided, they will be computed during initialization.
normalise_std :
A list of standard deviation values for normalizing the input features. If not provided, they will be computed during initialization.
A list of standard deviation values for normalizing the input features.
If not provided, they will be computed during initialization.
training :
Whether the dataset is used for training. If `False`, it will adjust how data is processed, particularly for evaluation and testing phases.
Whether the dataset is used for training.
If `False`, it will adjust how data is processed, particularly for evaluation and testing phases.
Notes
-----
The DatasetForCSAI class is designed for bidirectional imputation of time-series data, handling both forward and backward directions to improve imputation accuracy. It supports on-the-fly data normalization and missing value simulation, making it suitable for training and evaluating deep learning models like CSAI. The class can work with large datasets stored on disk, leveraging lazy-loading to minimize memory usage, and supports both training and testing scenarios, adjusting data handling as needed.
The DatasetForCSAI class is designed for bidirectional imputation of time-series data,
handling both forward and backward directions to improve imputation accuracy.
It supports on-the-fly data normalization and missing value simulation,
making it suitable for training and evaluating deep learning models like CSAI.
The class can work with large datasets stored on disk, leveraging lazy-loading to minimize memory usage,
and supports both training and testing scenarios, adjusting data handling as needed.
"""

Expand Down
17 changes: 9 additions & 8 deletions pypots/imputation/csai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
import torch
from torch.utils.data import DataLoader


from .core import _BCSAI
from .data import DatasetForCSAI
from ..base import BaseNNImputer
from ...data.checking import key_in_data_set
from ...optim.adam import Adam
from ...optim.base import Optimizer

Expand Down Expand Up @@ -146,6 +144,9 @@ def __init__(
self.step_channels = step_channels
self.compute_intervals = compute_intervals
self.intervals = None
self.replacement_probabilities = None
self.mean_set = None
self.std_set = None

# Initialise model
self.model = _BCSAI(
Expand Down Expand Up @@ -238,16 +239,16 @@ def fit(
file_type: str = "hdf5",
) -> None:

self.training_set = DatasetForCSAI(
training_set = DatasetForCSAI(
train_set, False, False, file_type, self.removal_percent, self.increase_factor, self.compute_intervals
)
self.intervals = self.training_set.intervals
self.replacement_probabilities = self.training_set.replacement_probabilities
self.mean_set = self.training_set.mean_set
self.std_set = self.training_set.std_set
self.intervals = training_set.intervals
self.replacement_probabilities = training_set.replacement_probabilities
self.mean_set = training_set.mean_set
self.std_set = training_set.std_set

training_loader = DataLoader(
self.training_set,
training_set,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
Expand Down
5 changes: 1 addition & 4 deletions pypots/imputation/segrnn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@

# Created by Shengsheng Lin

from typing import Optional

from typing import Callable
import torch.nn as nn

from ...nn.modules.segrnn import BackboneSegRNN
from ...nn.modules.saits import SaitsLoss
from ...nn.modules.segrnn import BackboneSegRNN


class _SegRNN(nn.Module):
Expand Down
11 changes: 6 additions & 5 deletions pypots/nn/modules/csai/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
# Created by Linglong Qian, Joseph Arul Raj <[email protected], [email protected]>
# License: BSD-3-Clause

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from .layers import FeatureRegression, Decay, Decay_obs, PositionalEncoding, Conv1dWithInit, TorchTransformerEncoder
from ....utils.metrics import calc_mae

Expand Down Expand Up @@ -91,7 +92,7 @@ class BackboneCSAI(nn.Module):
"""

def __init__(self, n_steps, n_features, rnn_hidden_size, step_channels, medians_df=None):
super(BackboneCSAI, self).__init__()
super().__init__()

if medians_df is not None:
self.medians_tensor = torch.tensor(list(medians_df.values())).float()
Expand Down Expand Up @@ -134,7 +135,7 @@ def forward(self, x, mask, deltas, last_obs, h=None):

decay_factor = self.weighted_obs(deltas - medians.unsqueeze(1))

if h == None:
if h is None:
data_last_obs = self.input_projection(last_obs.permute(0, 2, 1)).permute(0, 2, 1)
data_decay_factor = self.input_projection(decay_factor.permute(0, 2, 1)).permute(0, 2, 1)

Expand Down Expand Up @@ -191,7 +192,7 @@ def forward(self, x, mask, deltas, last_obs, h=None):

class BackboneBCSAI(nn.Module):
def __init__(self, n_steps, n_features, rnn_hidden_size, step_channels, medians_df=None):
super(BackboneBCSAI, self).__init__()
super().__init__()

self.model_f = BackboneCSAI(n_steps, n_features, rnn_hidden_size, step_channels, medians_df)
self.model_b = BackboneCSAI(n_steps, n_features, rnn_hidden_size, step_channels, medians_df)
Expand Down
Loading

0 comments on commit 0979d44

Please sign in to comment.