diff --git a/pypots/classification/__init__.py b/pypots/classification/__init__.py index 73f6524d..4344ad92 100644 --- a/pypots/classification/__init__.py +++ b/pypots/classification/__init__.py @@ -6,10 +6,12 @@ # License: BSD-3-Clause from .brits import BRITS +from .csai import CSAI from .grud import GRUD from .raindrop import Raindrop __all__ = [ + "CSAI", "BRITS", "GRUD", "Raindrop", diff --git a/pypots/classification/csai/__init__.py b/pypots/classification/csai/__init__.py new file mode 100644 index 00000000..5ea14ae3 --- /dev/null +++ b/pypots/classification/csai/__init__.py @@ -0,0 +1,20 @@ +""" +The package including the modules of CSAI. + +Refer to the paper +`Linglong Qian, Zina Ibrahim, Hugh Logan Ellis, Ao Zhang, Yuezhou Zhang, Tao Wang, Richard Dobson. +Knowledge Enhanced Conditional Imputation for Healthcare Time-series. +In Arxiv, 2024. +`_ + +Notes +----- +This implementation is inspired by the official one the official implementation https://github.com/LinglongQian/CSAI. + +""" + +from .model import CSAI + +__all__ = [ + "CSAI", +] \ No newline at end of file diff --git a/pypots/classification/csai/core.py b/pypots/classification/csai/core.py new file mode 100644 index 00000000..30f052f6 --- /dev/null +++ b/pypots/classification/csai/core.py @@ -0,0 +1,123 @@ +""" + +""" + +# Created by Linglong Qian, Joseph Arul Raj +# License: BSD-3-Clause + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...nn.modules.csai import BackboneBCSAI + +# class DiceBCELoss(nn.Module): +# def __init__(self, weight=None, size_average=True): +# super(DiceBCELoss, self).__init__() +# self.bcelogits = nn.BCEWithLogitsLoss() + +# def forward(self, y_score, y_out, targets, smooth=1): + +# #comment out if your model contains a sigmoid or equivalent activation layer +# # inputs = F.sigmoid(inputs) + +# #flatten label and prediction tensors +# BCE = self.bcelogits(y_out, targets) + +# y_score = y_score.view(-1) +# targets = targets.view(-1) +# intersection = (y_score * targets).sum() +# dice_loss = 1 - (2.*intersection + smooth)/(y_score.sum() + targets.sum() + smooth) + +# Dice_BCE = BCE + dice_loss + +# return BCE, Dice_BCE + + +class _BCSAI(nn.Module): + def __init__( + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + imputation_weight: float, + consistency_weight: float, + classification_weight: float, + n_classes: int, + step_channels: int, + dropout: float = 0.5, + intervals=None, + ): + super().__init__() + self.n_steps = n_steps + self.n_features = n_features + self.rnn_hidden_size = rnn_hidden_size + self.imputation_weight = imputation_weight + self.consistency_weight = consistency_weight + self.classification_weight = classification_weight + self.n_classes = n_classes + self.step_channels = step_channels + self.intervals = intervals + + # create models + self.model = BackboneBCSAI(n_steps, n_features, rnn_hidden_size, step_channels, intervals) + self.f_classifier = nn.Linear(self.rnn_hidden_size, n_classes) + self.b_classifier = nn.Linear(self.rnn_hidden_size, n_classes) + self.imputer = nn.Linear(self.rnn_hidden_size, n_features) + self.dropout = nn.Dropout(dropout) + + def forward(self, inputs: dict, training: bool = True) -> dict: + + ( + imputed_data, + f_reconstruction, + b_reconstruction, + f_hidden_states, + b_hidden_states, + consistency_loss, + reconstruction_loss, + ) = self.model(inputs) + + results = { + "imputed_data": imputed_data, + } + + f_logits = self.f_classifier(self.dropout(f_hidden_states)) + b_logits = self.b_classifier(self.dropout(b_hidden_states)) + + # f_prediction = torch.sigmoid(f_logits) + # b_prediction = torch.sigmoid(b_logits) + + f_prediction = torch.softmax(f_logits, dim=1) + b_prediction = torch.softmax(b_logits, dim=1) + classification_pred = (f_prediction + b_prediction) / 2 + + results = { + "imputed_data": imputed_data, + "classification_pred": classification_pred, + } + + # if in training mode, return results with losses + if training: + # criterion = DiceBCELoss().to(imputed_data.device) + results["consistency_loss"] = consistency_loss + results["reconstruction_loss"] = reconstruction_loss + # print(inputs["labels"].unsqueeze(1)) + f_classification_loss = F.nll_loss(torch.log(f_prediction), inputs["labels"]) + b_classification_loss = F.nll_loss(torch.log(b_prediction), inputs["labels"]) + # f_classification_loss, _ = criterion(f_prediction, f_logits, inputs["labels"].unsqueeze(1).float()) + # b_classification_loss, _ = criterion(b_prediction, b_logits, inputs["labels"].unsqueeze(1).float()) + classification_loss = (f_classification_loss + b_classification_loss) + + loss = ( + self.consistency_weight * consistency_loss + + self.imputation_weight * reconstruction_loss + + self.classification_weight * classification_loss + ) + + results["loss"] = loss + results["classification_loss"] = classification_loss + results["f_reconstruction"] = f_reconstruction + results["b_reconstruction"] = b_reconstruction + + return results \ No newline at end of file diff --git a/pypots/classification/csai/data.py b/pypots/classification/csai/data.py new file mode 100644 index 00000000..caeb5005 --- /dev/null +++ b/pypots/classification/csai/data.py @@ -0,0 +1,39 @@ +""" + +""" + +# Created by Joseph Arul Raj +# License: BSD-3-Clause + +from typing import Union +from ...imputation.csai.data import DatasetForCSAI as DatasetForCSAI_Imputation + + + +class DatasetForCSAI(DatasetForCSAI_Imputation): + def __init__(self, + data: Union[dict, str], + file_type: str = "hdf5", + return_y: bool = True, + removal_percent: float = 0.0, + increase_factor: float = 0.1, + compute_intervals: bool = False, + replacement_probabilities = None, + normalise_mean : list = [], + normalise_std: list = [], + training: bool = True + ): + super().__init__( + data=data, + return_X_ori=False, + return_y=return_y, + file_type=file_type, + removal_percent=removal_percent, + increase_factor=increase_factor, + compute_intervals=compute_intervals, + replacement_probabilities=replacement_probabilities, + normalise_mean=normalise_mean, + normalise_std=normalise_std, + training=training + ) + \ No newline at end of file diff --git a/pypots/classification/csai/model.py b/pypots/classification/csai/model.py new file mode 100644 index 00000000..fb9bd5b5 --- /dev/null +++ b/pypots/classification/csai/model.py @@ -0,0 +1,358 @@ + +""" + +""" + +# Created by Linglong Qian, Joseph Arul Raj +# License: BSD-3-Clause + +from typing import Optional, Union +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .core import _BCSAI +from .data import DatasetForCSAI +from ..base import BaseNNClassifier +from ...optim.adam import Adam +from ...optim.base import Optimizer + + +class CSAI(BaseNNClassifier): + + """ + The PyTorch implementation of the CSAI model. + + Parameters + + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + rnn_hidden_size : + The size of the RNN hidden state. + + imputation_weight : + The loss weight for the imputation task. + + consistency_weight : + The loss weight for the consistency task. + + classification_weight : + The loss weight for the classification task. + + n_classes : + The number of classes in the classification task. + + removal_percent : + The percentage of data to be removed during training for simulating missingness. + + increase_factor : + The factor to increase the frequency of missing value occurrences. + + compute_intervals : + Whether to compute time intervals between observations during data processing. + + step_channels : + The number of step channels for the model. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + dropout : + The dropout rate for the model to prevent overfitting. Default is 0.5. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. 0 means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:torch.device object, or a list of them. If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')], the model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. No model will be saved when it is set as None. The "best" strategy will only automatically save the best model after the training finished. The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. The "all" strategy will save every model after each epoch training. + + verbose : + Whether to print out the training logs during the training process. + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + imputation_weight: float, + consistency_weight: float, + classification_weight: float, + n_classes: int, + removal_percent: int, + increase_factor: float, + compute_intervals: bool, + step_channels:int, + batch_size: int, + epochs: int, + dropout: float = 0.5, + patience: Union[int, None] = None, + optimizer: Optimizer = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Union[str, None] = "best", + verbose: bool = True + ): + super().__init__( + n_classes, + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + verbose, + ) + + self.n_steps = n_steps + self.n_features = n_features + self.rnn_hidden_size = rnn_hidden_size + self.imputation_weight = imputation_weight + self.consistency_weight = consistency_weight + self.classification_weight = classification_weight + self.removal_percent = removal_percent + self.increase_factor = increase_factor + self.step_channels = step_channels + self.compute_intervals = compute_intervals + self.dropout = dropout + self.intervals = None + + # Initialise empty model + self.model = _BCSAI( + n_steps=self.n_steps, + n_features=self.n_features, + rnn_hidden_size=self.rnn_hidden_size, + imputation_weight=self.imputation_weight, + consistency_weight=self.consistency_weight, + classification_weight=self.classification_weight, + n_classes=self.n_classes, + step_channels=self.step_channels, + dropout=self.dropout, + intervals=self.intervals, + ) + + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + + def _assemble_input_for_training(self, data: list, training=True) -> dict: + # extract data + sample = data['sample'] + ( + indices, + X, + missing_mask, + deltas, + last_obs, + back_X, + back_missing_mask, + back_deltas, + back_last_obs, + labels + ) = self._send_data_to_given_device(sample) + + inputs = { + "indices": indices, + "labels": labels, + "forward": { + "X": X, + "missing_mask": missing_mask, + "deltas": deltas, + "last_obs": last_obs, + }, + "backward": { + "X": back_X, + "missing_mask": back_missing_mask, + "deltas": back_deltas, + "last_obs": back_last_obs, + }, + } + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + # extract data + sample = data['sample'] + ( + indices, + X, + missing_mask, + deltas, + last_obs, + back_X, + back_missing_mask, + back_deltas, + back_last_obs, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(sample) + + # assemble input data + inputs = { + "indices": indices, + "forward": { + "X": X, + "missing_mask": missing_mask, + "deltas": deltas, + "last_obs": last_obs, + }, + "backward": { + "X": back_X, + "missing_mask": back_missing_mask, + "deltas": back_deltas, + "last_obs": back_last_obs, + }, + # "X_ori": X_ori, + # "indicating_mask": indicating_mask, + } + + return inputs + + def fit( + self, + train_set, + val_set= None, + file_type: str = "hdf5", + )-> None: + # Create dataset + self.training_set = DatasetForCSAI( + data=train_set, + file_type=file_type, + return_y=True, + removal_percent=self.removal_percent, + increase_factor=self.increase_factor, + compute_intervals=self.compute_intervals, + ) + + self.intervals = self.training_set.intervals + self.replacement_probabilities = self.training_set.replacement_probabilities + self.mean_set = self.training_set.mean_set + self.std_set = self.training_set.std_set + + train_loader = DataLoader( + self.training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + val_set = DatasetForCSAI( + data=val_set, + file_type=file_type, + return_y=True, + removal_percent=self.removal_percent, + increase_factor=self.increase_factor, + compute_intervals=self.compute_intervals, + replacement_probabilities=self.replacement_probabilities, + normalise_mean=self.mean_set, + normalise_std=self.std_set, + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + # Create model + self.model = _BCSAI( + n_steps=self.n_steps, + n_features=self.n_features, + rnn_hidden_size=self.rnn_hidden_size, + imputation_weight=self.imputation_weight, + consistency_weight=self.consistency_weight, + classification_weight=self.classification_weight, + n_classes=self.n_classes, + step_channels=self.step_channels, + dropout=self.dropout, + intervals=self.intervals, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + # train the model + self._train_model(train_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() + + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") + + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> dict: + + self.model.eval() + test_set = DatasetForCSAI( + data=test_set, + file_type=file_type, + return_y=False, + removal_percent=self.removal_percent, + increase_factor=self.increase_factor, + compute_intervals=self.compute_intervals, + replacement_probabilities=self.replacement_probabilities, + normalise_mean=self.mean_set, + normalise_std=self.std_set, + training=False, + ) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + classificaion_results = [] + + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + classificaion_results.append(results['classification_pred']) + + + classification = torch.cat(classificaion_results).cpu().detach().numpy() + result_dict = { + "classification": classification, + } + return result_dict + + def classify( + self, + test_set, + file_type: str = "hdf5", + ): + + result_dict = self.predict(test_set, file_type) + return result_dict['classification'] \ No newline at end of file diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index 19a7e2c6..6600dcfd 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -6,6 +6,7 @@ # License: BSD-3-Clause from .brits import BRITS +from .csai import CSAI from .csdi import CSDI from .gpvae import GPVAE from .mrnn import MRNN @@ -85,4 +86,5 @@ "Median", "Lerp", "TEFN", + "CSAI", ] diff --git a/pypots/imputation/csai/__init__.py b/pypots/imputation/csai/__init__.py new file mode 100644 index 00000000..529cb0ce --- /dev/null +++ b/pypots/imputation/csai/__init__.py @@ -0,0 +1,23 @@ +""" +The package including the modules of CSAI. + +Refer to the paper +`Linglong Qian, Zina Ibrahim, Hugh Logan Ellis, Ao Zhang, Yuezhou Zhang, Tao Wang, Richard Dobson. +Knowledge Enhanced Conditional Imputation for Healthcare Time-series. +In Arxiv, 2024. +`_ + +Notes +----- +This implementation is inspired by the official one the official implementation https://github.com/LinglongQian/CSAI. + +""" + +# Created by Linglong Qian, Joseph Arul Raj +# License: BSD-3-Clause + +from .model import CSAI + +__all__ = [ + "CSAI", +] \ No newline at end of file diff --git a/pypots/imputation/csai/core.py b/pypots/imputation/csai/core.py new file mode 100644 index 00000000..fc9ea6f9 --- /dev/null +++ b/pypots/imputation/csai/core.py @@ -0,0 +1,118 @@ +""" + +""" + +# Created by Linglong Qian, Joseph Arul Raj +# License: BSD-3-Clause + +import torch.nn as nn +from ...nn.modules.csai.backbone import BackboneBCSAI + + +class _BCSAI(nn.Module): + """ + Attributes + ---------- + n_steps : + sequence length (number of time steps) + + n_features : + number of features (input dimensions) + + rnn_hidden_size : + the hidden size of the GRU cell + + step_channels : + number of channels for each step in the sequence + + intervals : + time intervals between the observations, used for handling irregular time-series + + consistency_weight : + weight assigned to the consistency loss during training + + imputation_weight : + weight assigned to the reconstruction loss during training + + model : + the underlying BackboneBCSAI model that handles forward and backward pass imputation + + Parameters + ---------- + n_steps : + sequence length (number of time steps) + + n_features : + number of features (input dimensions) + + rnn_hidden_size : + the hidden size of the GRU cell + + step_channels : + number of channels for each step in the sequence + + intervals : + time intervals between observations + + consistency_weight : + weight assigned to the consistency loss + + imputation_weight : + weight assigned to the reconstruction loss + + Notes + ----- + BCSAI is a bidirectional imputation model that uses forward and backward GRU cells to handle time-series data. It computes consistency and reconstruction losses to improve imputation accuracy. During training, the forward and backward reconstructions are combined, and losses are used to update the model. In evaluation mode, the model also outputs original data and indicating masks for further analysis. + + """ + def __init__(self, + n_steps, + n_features, + rnn_hidden_size, + step_channels, + consistency_weight, + imputation_weight, + intervals=None, + ): + super().__init__() + self.n_steps = n_steps + self.n_features = n_features + self.rnn_hidden_size = rnn_hidden_size + self.step_channels = step_channels + self.intervals = intervals + self.consistency_weight = consistency_weight + self.imputation_weight = imputation_weight + + self.model = BackboneBCSAI(n_steps, n_features, rnn_hidden_size, step_channels, intervals) + + def forward(self, inputs:dict, training:bool = True) -> dict: + ( + imputed_data, + f_reconstruction, + b_reconstruction, + f_hidden_states, + b_hidden_states, + consistency_loss, + reconstruction_loss, + ) = self.model(inputs) + + results = { + "imputed_data": imputed_data, + } + + # if in training mode, return results with losses + if training: + results["consistency_loss"] = consistency_loss + results["reconstruction_loss"] = reconstruction_loss + loss = self.consistency_weight * consistency_loss + self.imputation_weight * reconstruction_loss + + # `loss` is always the item for backward propagating to update the model + results["loss"] = loss + # results["reconstruction"] = (f_reconstruction + b_reconstruction) / 2 + results["f_reconstruction"] = f_reconstruction + results["b_reconstruction"] = b_reconstruction + if not training: + results["X_ori"] = inputs["X_ori"] + results["indicating_mask"] = inputs["indicating_mask"] + + return results diff --git a/pypots/imputation/csai/data.py b/pypots/imputation/csai/data.py new file mode 100644 index 00000000..6e4e481a --- /dev/null +++ b/pypots/imputation/csai/data.py @@ -0,0 +1,519 @@ +""" + +""" + +# Created by Linglong Qian, Joseph Arul Raj +# License: BSD-3-Clause + +from typing import Iterable +from ...data.dataset import BaseDataset +import numpy as np +import torch +from typing import Union +import copy +from ...data.utils import parse_delta +from sklearn.preprocessing import StandardScaler + +def normalize_csai( + data, + mean: list = None, + std: list = None, + compute_intervals: bool = False, +): + """ + Normalize the data based on the given mean and standard deviation, and optionally compute time intervals between observations. + + Parameters + ---------- + data : np.ndarray + The input time-series data of shape [n_patients, n_hours, n_variables], which may contain missing values (NaNs). + + mean : list of float, optional + The mean values for each variable, used for normalization. If empty, means will be computed from the data. + + std : list of float, optional + The standard deviation values for each variable, used for normalization. If empty, std values will be computed from the data. + + compute_intervals : bool, optional, default=False + Whether to compute the time intervals between observations for each variable. + + Returns + ------- + data : torch.Tensor + The normalized time-series data with the same shape as the input data, moved to the specified device. + + mean_set : np.ndarray + The mean values for each variable after normalization, either computed from the data or passed as input. + + std_set : np.ndarray + The standard deviation values for each variable after normalization, either computed from the data or passed as input. + + intervals_list : dict of int to float, optional + If `compute_intervals` is True, this will return the median time intervals between observations for each variable. + """ + + # Convert data to numpy array if it is a torch tensor + if isinstance(data, torch.Tensor): + data = data.cpu().numpy() + + n_patients, n_hours, n_variables = data.shape + + # Flatten data for easier computation of statistics + reshaped_data = data.reshape(-1, n_variables) + + # Use StandardScaler for normalization + scaler = StandardScaler() + + # Update condition to handle empty list as well + if mean is None or std is None or len(mean) == 0 or len(std) == 0: + # Fit the scaler on the data (ignores NaNs during the fitting process) + scaler.fit(reshaped_data) + mean_set = scaler.mean_ + std_set = scaler.scale_ + else: + # Use provided mean and std by directly setting them in the scaler + scaler.mean_ = np.array(mean) + scaler.scale_ = np.array(std) + mean_set = np.array(mean) + std_set = np.array(std) + + # Transform data using scaler, which ignores NaNs + scaled_data = scaler.transform(reshaped_data) + + # Reshape back to original shape [n_patients, n_hours, n_variables] + normalized_data = scaled_data.reshape(n_patients, n_hours, n_variables) + + # Optimized interval calculation considering NaNs in each patient + if compute_intervals: + intervals_list = {} + + for v in range(n_variables): + all_intervals = [] + # Loop over each patient + for p in range(n_patients): + # Get non-NaN observation indices for the current patient and variable + valid_time_points = np.where(~np.isnan(data[p, :, v]))[0] + + # If the patient has more than one valid observation, compute time intervals + if len(valid_time_points) > 1: + # Calculate time differences between consecutive observations + intervals = np.diff(valid_time_points) + all_intervals.extend(intervals) + + # Compute the median interval for the current variable + intervals_list[v] = np.median(all_intervals) if all_intervals else np.nan + else: + intervals_list = None + + return normalized_data, mean_set, std_set, intervals_list + + +def compute_last_obs(data, masks): + """ + Compute the last observed values for each time step. + + Parameters: + - data (np.array): Original data array of shape [T, D]. + - masks (np.array): Binary masks indicating where data is not NaN, of shape [T, D]. + + Returns: + - last_obs (np.array): Array of the last observed values, of shape [T, D]. + """ + T, D = masks.shape + last_obs = np.full((T, D), np.nan) # Initialize last observed values with NaNs + last_obs_val = np.full(D, np.nan) # Initialize last observed values for first time step with NaNs + + for t in range(1, T): # Start from t=1, keeping first row as NaN + mask = masks[t - 1] + # Update last observed values based on previous time step + last_obs_val[mask] = data[t - 1, mask] + # Assign last observed values to the current time step + last_obs[t] = last_obs_val + + return last_obs + +def adjust_probability_vectorized( + obs_count: Union[int, float], + avg_count: Union[int, float], + base_prob: float, + increase_factor: float = 0.5 +) -> float: + """ + Adjusts the base probability based on observed and average counts using a scaling factor. + + Parameters + ---------- + obs_count : int or float + The observed count of an event or observation in the dataset. + + avg_count : int or float + The average count of the event or observation across the dataset. + + base_prob : float + The base probability of the event or observation occurring. + + increase_factor : float, optional, default=0.5 + A scaling factor applied to adjust the probability when `obs_count` is below `avg_count`. + This factor influences how much to increase or decrease the probability. + + Returns + ------- + float + The adjusted probability, scaled based on the ratio between the observed count and the average count. + The adjusted probability will be within the range [0, 1]. + + Notes + ----- + This function adjusts a base probability based on the observed count (`obs_count`) compared to the average count + (`avg_count`). If the observed count is lower than the average, the probability is increased proportionally, + but capped at a maximum of 1.0. Conversely, if the observed count exceeds the average, the probability is reduced, + but not below 0. The `increase_factor` controls the sensitivity of the probability adjustment when the observed + count is less than the average count. + """ + if obs_count < avg_count: + # Increase probability when observed count is lower than average count + return min(base_prob * (avg_count / obs_count) * increase_factor, 1.0) + else: + # Decrease probability when observed count exceeds average count + return max(base_prob * (obs_count / avg_count) / increase_factor, 0.0) + +def non_uniform_sample( + data, + removal_percent, + pre_replacement_probabilities=None, + increase_factor=0.5 + ): + + """ + Process time-series data by randomly removing a certain percentage of observed values based on pre-defined + replacement probabilities, and compute the necessary features such as forward and backward deltas, masks, + and last observed values. + + This function generates records for each time series and returns them as PyTorch tensors for further usage. + + Parameters + ---------- + data : np.ndarray + The input data with shape [N, T, D], where N is the number of samples, T is the number of time steps, + and D is the number of features. Missing values should be indicated with NaNs. + + removal_percent : float + The percentage of observed values to be removed randomly from the dataset. + + pre_replacement_probabilities : np.ndarray, optional + Pre-defined replacement probabilities for each feature. If provided, this will be used to determine + which values to remove. + + increase_factor : float, default=0.5 + A factor to adjust replacement probabilities based on the observation count for each feature. + + Returns + ------- + tensor_dict : dict of torch.Tensors + A dictionary of PyTorch tensors including 'values', 'last_obs_f', 'last_obs_b', 'masks', 'deltas_f', + 'deltas_b', 'evals', and 'eval_masks'. + + replacement_probabilities : np.ndarray + The computed or provided replacement probabilities for each feature. + """ + # Get dimensionality + [N, T, D] = data.shape + + # Compute replacement probabilities if not provided + if pre_replacement_probabilities is None: + observations_per_feature = np.sum(~np.isnan(data), axis=(0, 1)) + average_observations = np.mean(observations_per_feature) + replacement_probabilities = np.full(D, removal_percent / 100) + + if increase_factor > 0: + for feature_idx in range(D): + replacement_probabilities[feature_idx] = adjust_probability_vectorized( + observations_per_feature[feature_idx], + average_observations, + replacement_probabilities[feature_idx], + increase_factor=increase_factor + ) + + total_observations = np.sum(observations_per_feature) + total_replacement_target = total_observations * removal_percent / 100 + + for _ in range(1000): # Limit iterations to prevent infinite loop + total_replacement = np.sum(replacement_probabilities * observations_per_feature) + if np.isclose(total_replacement, total_replacement_target, rtol=1e-3): + break + adjustment_factor = total_replacement_target / total_replacement + replacement_probabilities *= adjustment_factor + else: + replacement_probabilities = pre_replacement_probabilities + + # Prepare data structures + recs = [] + values = copy.deepcopy(data) + + # Randomly remove data points based on replacement probabilities + random_matrix = np.random.rand(N, T, D) + values[(~np.isnan(values)) & (random_matrix < replacement_probabilities)] = np.nan + + # Generate records and features for each sample + for i in range(N): + masks = ~np.isnan(values[i, :, :]) + eval_masks = (~np.isnan(values[i, :, :])) ^ (~np.isnan(data[i, :, :])) + evals = data[i, :, :] + + # Compute forward and backward deltas + deltas_f = parse_delta(masks) + deltas_b = parse_delta(masks[::-1, :]) + + # Compute last observations for forward and backward directions + last_obs_f = compute_last_obs(values[i, :, :], masks) + last_obs_b = compute_last_obs(values[i, ::-1, :], masks[::-1, :]) + + # Append the record for this sample + recs.append({ + 'values': np.nan_to_num(values[i, :, :]), + 'last_obs_f': np.nan_to_num(last_obs_f), + 'last_obs_b': np.nan_to_num(last_obs_b), + 'masks': masks.astype('int32'), + 'evals': np.nan_to_num(evals), + 'eval_masks': eval_masks.astype('int32'), + 'deltas_f': deltas_f, + 'deltas_b': deltas_b + }) + + # Convert records to PyTorch tensors + tensor_dict = { + 'values': torch.FloatTensor(np.array([r['values'] for r in recs])), + 'last_obs_f': torch.FloatTensor(np.array([r['last_obs_f'] for r in recs])), + 'last_obs_b': torch.FloatTensor(np.array([r['last_obs_b'] for r in recs])), + 'masks': torch.FloatTensor(np.array([r['masks'] for r in recs])), + 'deltas_f': torch.FloatTensor(np.array([r['deltas_f'] for r in recs])), + 'deltas_b': torch.FloatTensor(np.array([r['deltas_b'] for r in recs])), + 'evals': torch.FloatTensor(np.array([r['evals'] for r in recs])), + 'eval_masks': torch.FloatTensor(np.array([r['eval_masks'] for r in recs])) + } + + return tensor_dict, replacement_probabilities + + +class DatasetForCSAI(BaseDataset): + """" + Parameters + ---------- + data : + The dataset for model input, which can be either a dictionary or a path string to a data file. If it's a dictionary, `X` should be an array-like structure with shape [n_samples, sequence length (n_steps), n_features], containing the time-series data, and it can have missing values. Optionally, the dictionary can include `y`, an array-like structure with shape [n_samples], representing the labels of `X`. If `data` is a path string, it should point to a data file (e.g., h5 file) that contains key-value pairs like a dictionary, including keys for `X` and possibly `y`. + + return_X_ori : + Whether to return the original time-series data (`X_ori`) when fetching data samples, useful for evaluation purposes. + + return_y : + Whether to return classification labels in the `__getitem__()` method if they exist in the dataset. If `True`, labels will be included in the returned data samples, which is useful for training classification models. If `False`, the labels won't be returned, suitable for testing or validation stages. + + file_type : + The type of the data file if `data` is a path string, such as "hdf5". + + removal_percent : + The percentage of data to be removed for simulating missing values during training. + + increase_factor : + A scaling factor to increase the probability of missing data during training. + + compute_intervals : + Whether to compute time intervals between observations for handling irregular time-series data. + + replacement_probabilities : + Optional precomputed probabilities for sampling missing values. If not provided, they will be calculated during the initialization of the dataset. + + normalise_mean : + A list of mean values for normalizing the input features. If not provided, they will be computed during initialization. + + normalise_std : + A list of standard deviation values for normalizing the input features. If not provided, they will be computed during initialization. + + training : + Whether the dataset is used for training. If `False`, it will adjust how data is processed, particularly for evaluation and testing phases. + + Notes + ----- + The DatasetForCSAI class is designed for bidirectional imputation of time-series data, handling both forward and backward directions to improve imputation accuracy. It supports on-the-fly data normalization and missing value simulation, making it suitable for training and evaluating deep learning models like CSAI. The class can work with large datasets stored on disk, leveraging lazy-loading to minimize memory usage, and supports both training and testing scenarios, adjusting data handling as needed. + + """ + def __init__(self, + data: Union[dict, str], + return_X_ori: bool, + return_y: bool, + file_type: str = "hdf5", + removal_percent: float = 0.0, + increase_factor: float = 0.1, + compute_intervals: bool = False, + replacement_probabilities = None, + normalise_mean : list = [], + normalise_std: list = [], + training: bool = True + ): + super().__init__(data = data, + return_X_ori = return_X_ori, + return_X_pred = False, + return_y = return_y, + file_type = file_type) + + self.removal_percent = removal_percent + self.increase_factor = increase_factor + self.compute_intervals = compute_intervals + self.replacement_probabilities = replacement_probabilities + self.normalise_mean = normalise_mean + self.normalise_std = normalise_std + self.training = training + + if not isinstance(self.data, str): + self.normalized_data, self.mean_set, self.std_set, self.intervals = normalize_csai( + self.data['X'], + self.normalise_mean, + self.normalise_std, + compute_intervals, + ) + + self.processed_data, self.replacement_probabilities = non_uniform_sample( + self.normalized_data, + removal_percent, + replacement_probabilities, + increase_factor, + ) + self.forward_X = self.processed_data['values'] + self.forward_missing_mask = self.processed_data['masks'] + self.backward_X = torch.flip(self.forward_X, dims=[1]) + self.backward_missing_mask = torch.flip(self.forward_missing_mask, dims=[1]) + + self.X_ori = self.processed_data['evals'] + self.indicating_mask = self.processed_data['eval_masks'] + + + def _fetch_data_from_array(self, idx: int) -> Iterable: + """Fetch data from self.X if it is given. + + Parameters + ---------- + idx : + The index of the sample to be return. + + Returns + ------- + sample : + A list contains + + index : int tensor, + The index of the sample. + + X : tensor, + The feature vector for model input. + + missing_mask : tensor, + The mask indicates all missing values in X. + + delta : tensor, + The delta matrix contains time gaps of missing values. + + label (optional) : tensor, + The target label of the time-series sample. + """ + + + sample = [ + torch.tensor(idx), + # for forward + self.forward_X[idx], + self.forward_missing_mask[idx], + self.processed_data["deltas_f"][idx], + self.processed_data["last_obs_f"][idx], + # for backward + self.backward_X[idx], + self.backward_missing_mask[idx], + self.processed_data["deltas_b"][idx], + self.processed_data["last_obs_b"][idx], + ] + + if not self.training: + sample.extend([self.X_ori[idx], self.indicating_mask[idx]]) + + if self.return_y: + sample.append(self.y[idx].to(torch.long)) + + return { + 'sample': sample, + 'replacement_probabilities': self.replacement_probabilities, + 'mean_set': self.mean_set, + 'std_set': self.std_set, + 'intervals': self.intervals + } + + def _fetch_data_from_file(self, idx: int) -> Iterable: + """Fetch data with the lazy-loading strategy, i.e. only loading data from the file while requesting for samples. + Here the opened file handle doesn't load the entire dataset into RAM but only load the currently accessed slice. + + Parameters + ---------- + idx : + The index of the sample to be return. + + Returns + ------- + sample : + The collated data sample, a list including all necessary sample info. + """ + + if self.file_handle is None: + self.file_handle = self._open_file_handle() + + X = torch.from_numpy(self.file_handle["X"][idx]) + normalized_data, mean_set, std_set, intervals = normalize_csai( + X, + self.normalise_mean, + self.normalise_std, + self.compute_intervals, + ) + + processed_data, replacement_probabilities = non_uniform_sample( + normalized_data, + self.removal_percent, + self.replacement_probabilities, + self.increase_factor, + ) + forward_X = processed_data['values'] + forward_missing_mask = processed_data['masks'] + backward_X = torch.flip(forward_X, dims=[1]) + backward_missing_mask = torch.flip(forward_missing_mask, dims=[1]) + + X_ori = self.processed_data['evals'] + indicating_mask = self.processed_data['eval_masks'] + + if self.return_y: + y = self.processed_data['labels'] + + sample = [ + torch.tensor(idx), + # for forward + forward_X, + forward_missing_mask, + processed_data["deltas_f"], + processed_data["last_obs_f"], + # for backward + backward_X, + backward_missing_mask, + processed_data["deltas_b"], + processed_data["last_obs_b"] + ] + + if self.return_X_ori: + sample.extend([X_ori, indicating_mask]) + + # if the dataset has labels and is for training, then fetch it from the file + if self.return_y: + sample.append(y) + + return { + 'sample': sample, + 'replacement_probabilities': replacement_probabilities, + 'mean_set': mean_set, + 'std_set': std_set, + 'intervals': intervals + } + diff --git a/pypots/imputation/csai/model.py b/pypots/imputation/csai/model.py new file mode 100644 index 00000000..b61286c5 --- /dev/null +++ b/pypots/imputation/csai/model.py @@ -0,0 +1,369 @@ +""" +The implementation of CSAI +""" + +# Created by Linglong Qian, Joseph Arul Raj +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + + +from .core import _BCSAI +from .data import DatasetForCSAI +from ..base import BaseNNImputer +from ...data.checking import key_in_data_set +from ...optim.adam import Adam +from ...optim.base import Optimizer + + +class CSAI(BaseNNImputer): + """ + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + rnn_hidden_size : + The size of the GRU hidden state, also the number of hidden units in the GRU cell. + + imputation_weight : + The weight assigned to the reconstruction loss during training. + + consistency_weight : + The weight assigned to the consistency loss during training. + + removal_percent : + The percentage of data to be removed during training for imputation tasks. + + increase_factor : + A scaling factor used to adjust the amount of missing data during training. + + compute_intervals : + Whether to compute time intervals between observations for handling irregular time-series. + + step_channels : + The number of channels for each step in the sequence. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, training will stop when no improvement is observed after the specified number of epochs. If set to None, early-stopping is disabled. + + optimizer : + The optimizer used for model training. Defaults to the Adam optimizer if not specified. + + num_workers : + The number of subprocesses used for data loading. Setting this to `0` means that data loading is performed in the main process without using subprocesses. + + device : + The device for the model to run on, which can be a string, a :class:`torch.device` object, or a list of devices. If not provided, the model will attempt to use available CUDA devices first, then default to CPUs. + + saving_path : + The path for saving model checkpoints and tensorboard files during training. If not provided, models will not be saved automatically. + + model_saving_strategy : + The strategy for saving model checkpoints. Can be one of [None, "best", "better", "all"]. "best" saves the best model after training, "better" saves any model that improves during training, and "all" saves models after each epoch. If set to None, no models will be saved. + + verbose : + Whether to print training logs during the training process. + + Notes + ----- + CSAI (Consistent Sequential Imputation) is a bidirectional model designed for time-series imputation. It employs a forward and backward GRU network to handle missing data, using consistency and reconstruction losses to improve accuracy. The model supports various training configurations, such as interval computations, early-stopping, and multiple devices for training. Results can be saved based on the specified saving strategy, and tensorboard files are generated for tracking the model's performance over time. + + """ + + def __init__(self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + imputation_weight: float, + consistency_weight: float, + removal_percent: int, + increase_factor: float, + compute_intervals: bool, + step_channels:int, + batch_size: int, + epochs: int, + patience: Union[int, None ]= None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Union[str, torch.device, list, None ]= None, + saving_path: str = None, + model_saving_strategy: Union[str, None] = "best", + verbose: bool = True, + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + verbose, + ) + + self.n_steps = n_steps + self.n_features = n_features + self.rnn_hidden_size = rnn_hidden_size + self.imputation_weight = imputation_weight + self.consistency_weight = consistency_weight + self.removal_percent = removal_percent + self.increase_factor = increase_factor + self.step_channels = step_channels + self.compute_intervals = compute_intervals + self.intervals = None + + # Initialise model + self.model = _BCSAI( + self.n_steps, + self.n_features, + self.rnn_hidden_size, + self.step_channels, + self.consistency_weight, + self.imputation_weight, + self.intervals, + ) + + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + + def _assemble_input_for_training(self, data: list, training=True) -> dict: + # extract data + sample = data['sample'] + + ( + indices, + X, + missing_mask, + deltas, + last_obs, + back_X, + back_missing_mask, + back_deltas, + back_last_obs + ) = self._send_data_to_given_device(sample) + + # assemble input data + inputs = { + "indices": indices, + "forward": { + "X": X, + "missing_mask": missing_mask, + "deltas": deltas, + "last_obs": last_obs, + }, + "backward": { + "X": back_X, + "missing_mask": back_missing_mask, + "deltas": back_deltas, + "last_obs": back_last_obs, + }, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + # extract data + sample = data['sample'] + ( + indices, + X, + missing_mask, + deltas, + last_obs, + back_X, + back_missing_mask, + back_deltas, + back_last_obs, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(sample) + + # assemble input data + inputs = { + "indices": indices, + "forward": { + "X": X, + "missing_mask": missing_mask, + "deltas": deltas, + "last_obs": last_obs, + }, + "backward": { + "X": back_X, + "missing_mask": back_missing_mask, + "deltas": back_deltas, + "last_obs": back_last_obs, + }, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + return inputs + + def _assemble_input_for_testing(self, data: list) -> dict: + return self._assemble_input_for_validating(data) + + def fit( + self, + train_set, + val_set=None, + file_type: str = "hdf5", + )-> None: + + self.training_set = DatasetForCSAI( + train_set, + False, + False, + file_type, + self.removal_percent, + self.increase_factor, + self.compute_intervals + ) + self.intervals = self.training_set.intervals + self.replacement_probabilities = self.training_set.replacement_probabilities + self.mean_set = self.training_set.mean_set + self.std_set = self.training_set.std_set + + training_loader = DataLoader( + self.training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + # collate_fn=collate_fn_bidirectional + ) + if val_set is not None: + val_set = DatasetForCSAI( + val_set, + True, + False, + file_type, + self.removal_percent, + self.increase_factor, + self.compute_intervals, + self.replacement_probabilities, + self.mean_set, + self.std_set, + False, + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + # collate_fn=collate_fn_bidirectional + ) + + # Reset the model + self.model = _BCSAI( + self.n_steps, + self.n_features, + self.rnn_hidden_size, + self.step_channels, + self.consistency_weight, + self.imputation_weight, + self.intervals, + ) + + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + # train the model + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> dict: + + self.model.eval() + test_set = DatasetForCSAI( + test_set, + True, + False, + file_type, + self.removal_percent, + self.increase_factor, + self.compute_intervals, + self.replacement_probabilities, + self.mean_set, + self.std_set, + False, + ) + + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + # collate_fn=collate_fn_bidirectional + ) + + imputation_collector = [] + x_ori_collector = [] + indicating_mask_collector = [] + + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputed_data = results["imputed_data"] + imputation_collector.append(imputed_data) + x_ori_collector.append(inputs["X_ori"]) + indicating_mask_collector.append(inputs["indicating_mask"]) + + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + "X_ori": torch.cat(x_ori_collector).cpu().detach().numpy(), + "indicating_mask": torch.cat(indicating_mask_collector).cpu().detach().numpy(), + } + return result_dict + + def impute( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Parameters + ---------- + test_set : + The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (n_steps), n_features], + Imputed data. + """ + + result_dict = self.predict(test_set, file_type=file_type) + return result_dict["imputation"] diff --git a/pypots/nn/modules/csai/__init__.py b/pypots/nn/modules/csai/__init__.py new file mode 100644 index 00000000..64c57392 --- /dev/null +++ b/pypots/nn/modules/csai/__init__.py @@ -0,0 +1,31 @@ +""" +The package including the modules of CSAI. + +Refer to the paper +`Linglong Qian, Zina Ibrahim, Hugh Logan Ellis, Ao Zhang, Yuezhou Zhang, Tao Wang, Richard Dobson. +Knowledge Enhanced Conditional Imputation for Healthcare Time-series. +In Arxiv, 2024. +`_ + +Notes +----- +This implementation is inspired by the official one the official implementation https://github.com/LinglongQian/CSAI. + +""" + +# Created by Joseph Arul Raj +# License: BSD-3-Clause + +from .backbone import BackboneCSAI, BackboneBCSAI +from .layers import FeatureRegression, Decay, Decay_obs, PositionalEncoding, Conv1dWithInit, TorchTransformerEncoder + +__all__ = [ + "BackboneCSAI", + "BackboneBCSAI", + "FeatureRegression", + "Decay", + "Decay_obs", + "PositionalEncoding", + "Conv1dWithInit", + "TorchTransformerEncoder" +] diff --git a/pypots/nn/modules/csai/backbone.py b/pypots/nn/modules/csai/backbone.py new file mode 100644 index 00000000..57600db9 --- /dev/null +++ b/pypots/nn/modules/csai/backbone.py @@ -0,0 +1,245 @@ +""" + +""" + +# Created by Linglong Qian, Joseph Arul Raj +# License: BSD-3-Clause + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from .layers import FeatureRegression, Decay, Decay_obs, PositionalEncoding, Conv1dWithInit, TorchTransformerEncoder +from ....utils.metrics import calc_mae + +class BackboneCSAI(nn.Module): + """ + Attributes + ---------- + n_steps : + sequence length (number of time steps) + + n_features : + number of features (input dimensions) + + rnn_hidden_size : + the hidden size of the GRU cell + + step_channels : + number of channels for each step in the sequence + + medians_tensor : + tensor of median values for features, used to adjust decayed observations + + temp_decay_h : + the temporal decay module to decay the hidden state of the GRU + + temp_decay_x : + the temporal decay module to decay data in the raw feature space + + hist : + the temporal-regression module that projects the GRU hidden state into the raw feature space + + feat_reg_v : + the feature-regression module used for feature-based estimation + + weight_combine : + the module that generates the weight to combine history regression and feature regression + + weighted_obs : + the decay module that computes weighted decay based on observed data and deltas + + gru : + the GRU cell that models temporal data for imputation + + pos_encoder : + the positional encoding module that adds temporal information to the sequence data + + input_projection : + the convolutional module used to project input features into a higher-dimensional space + + output_projection1 : + the convolutional module used to project the output from the Transformer layer + + output_projection2 : + the final convolutional module used to generate the hidden state from the time-layer's output + + time_layer : + the Transformer encoder layer used to model complex temporal dependencies within the sequence + + device : + the device (CPU/GPU) used for model computations + + Parameters + ---------- + n_steps : + sequence length (number of time steps) + + n_features : + number of features (input dimensions) + + rnn_hidden_size : + the hidden size of the GRU cell + + step_channels : + number of channels for each step in the sequence + + medians_df : + dataframe of median values for each feature, optional + + """ + + def __init__(self, n_steps, n_features, rnn_hidden_size, step_channels, medians_df=None): + super(BackboneCSAI, self).__init__() + + if medians_df is not None: + self.medians_tensor = torch.tensor(list(medians_df.values())).float() + else: + self.medians_tensor = torch.zeros(n_features).float() + + self.n_steps = n_steps + self.step_channels = step_channels + self.input_size = n_features + self.hidden_size = rnn_hidden_size + self.temp_decay_h = Decay(input_size=self.input_size, output_size=self.hidden_size, diag = False) + self.temp_decay_x = Decay(input_size=self.input_size, output_size=self.input_size, diag = True) + self.hist = nn.Linear(self.hidden_size, self.input_size) + self.feat_reg_v = FeatureRegression(self.input_size) + self.weight_combine = nn.Linear(self.input_size * 2, self.input_size) + self.weighted_obs = Decay_obs(self.input_size, self.input_size) + self.gru = nn.GRUCell(self.input_size * 2, self.hidden_size) + + self.pos_encoder = PositionalEncoding(self.step_channels) + self.input_projection = Conv1dWithInit(self.input_size, self.step_channels, 1) + self.output_projection1 = Conv1dWithInit(self.step_channels, self.hidden_size, 1) + self.output_projection2 = Conv1dWithInit(self.n_steps*2, 1, 1) + self.time_layer = TorchTransformerEncoder(channels=self.step_channels) + + self.reset_parameters() + + def reset_parameters(self): + for weight in self.parameters(): + if len(weight.size()) == 1: + continue + stv = 1. / math.sqrt(weight.size(1)) + nn.init.uniform_(weight, -stv, stv) + + def forward(self, x, mask, deltas, last_obs, h=None): + + # Get dimensionality + [B, _, _] = x.shape + + medians = self.medians_tensor.unsqueeze(0).repeat(B, 1).to(x.device) + + decay_factor = self.weighted_obs(deltas - medians.unsqueeze(1)) + + if h == None: + data_last_obs = self.input_projection(last_obs.permute(0, 2, 1)).permute(0, 2, 1) + data_decay_factor = self.input_projection(decay_factor.permute(0, 2, 1)).permute(0, 2, 1) + + data_last_obs = self.pos_encoder(data_last_obs.permute(1, 0, 2)).permute(1, 0, 2) + data_decay_factor = self.pos_encoder(data_decay_factor.permute(1, 0, 2)).permute(1, 0, 2) + + data = torch.cat([data_last_obs, data_decay_factor], dim=1) + + data = self.time_layer(data) + data = self.output_projection1(data.permute(0, 2, 1)).permute(0, 2, 1) + h = self.output_projection2(data).squeeze() + + x_loss = 0 + x_imp = x.clone() + Hiddens = [] + reconstruction = [] + for t in range(self.n_steps): + x_t = x[:, t, :] + d_t = deltas[:, t, :] + m_t = mask[:, t, :] + + # Decayed Hidden States + gamma_h = self.temp_decay_h(d_t) + h = h * gamma_h + + # history based estimation + x_h = self.hist(h) + + x_r_t = (m_t * x_t) + ((1 - m_t) * x_h) + + # feature based estimation + xu = self.feat_reg_v(x_r_t) + gamma_x = self.temp_decay_x(d_t) + + beta = self.weight_combine(torch.cat([gamma_x, m_t], dim=1)) + x_comb_t = beta * xu + (1 - beta) * x_h + + # x_loss += torch.sum(torch.abs(x_t - x_comb_t) * m_t) / (torch.sum(m_t) + 1e-5) + x_loss += calc_mae(x_comb_t, x_t, m_t) + + # Final Imputation Estimates + x_imp[:, t, :] = (m_t * x_t) + ((1 - m_t) * x_comb_t) + + # Set input the RNN + input_t = torch.cat([x_imp[:, t, :], m_t], dim=1) + + h = self.gru(input_t, h) + Hiddens.append(h.unsqueeze(dim=1)) + reconstruction.append(x_comb_t.unsqueeze(dim=1)) + Hiddens = torch.cat(Hiddens, dim=1) + + return x_imp, reconstruction, h, x_loss + + +class BackboneBCSAI(nn.Module): + def __init__(self, n_steps, n_features, rnn_hidden_size, step_channels, medians_df=None): + super(BackboneBCSAI, self).__init__() + + self.model_f = BackboneCSAI(n_steps, n_features, rnn_hidden_size, step_channels, medians_df) + self.model_b = BackboneCSAI(n_steps, n_features, rnn_hidden_size, step_channels, medians_df) + + def forward(self, xdata): + + # Fetching forward data from xdata + x = xdata['forward']['X'] + m = xdata['forward']['missing_mask'] + d_f = xdata['forward']['deltas'] + last_obs_f = xdata['forward']['last_obs'] + + # Fetching backward data from xdata + x_b = xdata['backward']['X'] + m_b = xdata['backward']['missing_mask'] + d_b = xdata['backward']['deltas'] + last_obs_b = xdata['backward']['last_obs'] + + # Call forward model + ( + f_imputed_data, + f_reconstruction, + f_hidden_states, + f_reconstruction_loss, + ) = self.model_f(x, m, d_f, last_obs_f) + + # Call backward model + ( + b_imputed_data, + b_reconstruction, + b_hidden_states, + b_reconstruction_loss, + ) = self.model_b(x_b, m_b, d_b, last_obs_b) + + # Averaging the imputations and prediction + x_imp = (f_imputed_data + b_imputed_data.flip(dims=[1])) / 2 + imputed_data = (x * m)+ ((1-m) * x_imp) + + # average consistency loss + consistency_loss = torch.abs(f_imputed_data - b_imputed_data.flip(dims=[1])).mean() * 1e-1 + + # Merge the regression loss + reconstruction_loss = f_reconstruction_loss + b_reconstruction_loss + return ( + imputed_data, + f_reconstruction, + b_reconstruction, + f_hidden_states, + b_hidden_states, + consistency_loss, + reconstruction_loss, + ) diff --git a/pypots/nn/modules/csai/layers.py b/pypots/nn/modules/csai/layers.py new file mode 100644 index 00000000..d603eef1 --- /dev/null +++ b/pypots/nn/modules/csai/layers.py @@ -0,0 +1,135 @@ +""" + +""" + +# Created by Joseph Arul Raj +# License: BSD-3-Clause + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from torch.nn.parameter import Parameter +import math +import numpy as np +import os +import copy +import pandas as pd +from torch.nn.modules import TransformerEncoderLayer + + +class FeatureRegression(nn.Module): + def __init__(self, input_size): + super(FeatureRegression, self).__init__() + self.build(input_size) + + def build(self, input_size): + self.W = Parameter(torch.Tensor(input_size, input_size)) + self.b = Parameter(torch.Tensor(input_size)) + m = torch.ones(input_size, input_size) - torch.eye(input_size, input_size) + self.register_buffer('m', m) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1. / math.sqrt(self.W.size(0)) + self.W.data.uniform_(-stdv, stdv) + if self.b is not None: + self.b.data.uniform_(-stdv, stdv) + + def forward(self, x): + z_h = F.linear(x, self.W * Variable(self.m), self.b) + return z_h + +class Decay(nn.Module): + def __init__(self, input_size, output_size, diag=False): + super(Decay, self).__init__() + self.diag = diag + self.build(input_size, output_size) + + def build(self, input_size, output_size): + self.W = Parameter(torch.Tensor(output_size, input_size)) + self.b = Parameter(torch.Tensor(output_size)) + + if self.diag == True: + assert(input_size == output_size) + m = torch.eye(input_size, input_size) + self.register_buffer('m', m) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1. / math.sqrt(self.W.size(0)) + self.W.data.uniform_(-stdv, stdv) + if self.b is not None: + self.b.data.uniform_(-stdv, stdv) + + def forward(self, d): + if self.diag == True: + gamma = F.relu(F.linear(d, self.W * Variable(self.m), self.b)) + else: + gamma = F.relu(F.linear(d, self.W, self.b)) + gamma = torch.exp(-gamma) + return gamma + +class Decay_obs(nn.Module): + def __init__(self, input_size, output_size): + super(Decay_obs, self).__init__() + self.linear = nn.Linear(input_size, output_size) + + def forward(self, delta_diff): + # When delta_diff is negative, weight tends to 1. + # When delta_diff is positive, weight tends to 0. + sign = torch.sign(delta_diff) + weight_diff = self.linear(delta_diff) + # weight_diff can be either positive or negative for each delta_diff + positive_part = F.relu(weight_diff) + negative_part = F.relu(-weight_diff) + weight_diff = positive_part + negative_part + weight_diff = sign * weight_diff + # Using a tanh activation to squeeze values between -1 and 1 + weight_diff = torch.tanh(weight_diff) + # This will move the weight values towards 1 if delta_diff is negative + # and towards 0 if delta_diff is positive + weight = 0.5 * (1 - weight_diff) + + return weight + +class TorchTransformerEncoder(nn.Module): + def __init__(self, heads=8, layers=1, channels=64): + super(TorchTransformerEncoder, self).__init__() + self.encoder_layer = nn.TransformerEncoderLayer( + d_model=channels, nhead=heads, dim_feedforward=64, activation="gelu" + ) + self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=layers) + + def forward(self, x): + return self.transformer_encoder(x) + +class Conv1dWithInit(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size): + super(Conv1dWithInit, self).__init__() + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size) + nn.init.kaiming_normal_(self.conv.weight) + + def forward(self, x): + return self.conv(x) + +class PositionalEncoding(nn.Module): + + def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(max_len, 1, d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, x): + """ + Arguments: + x: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` + """ + x = x + self.pe[:x.size(0)] + return self.dropout(x) \ No newline at end of file