Skip to content

Commit

Permalink
CSAI Pipeline (#534)
Browse files Browse the repository at this point in the history
* Add CSAI implementation to pypots

---------

Co-authored-by: Joseph Arul Raj Patterson Kulandai Raj <[email protected]>
Co-authored-by: LINGLONGQIAN <[email protected]>
Co-authored-by: Joseph Arul Raj Patterson Kulandai Raj <[email protected]>
  • Loading branch information
4 people authored Oct 8, 2024
1 parent 83644ab commit 6c5777e
Show file tree
Hide file tree
Showing 13 changed files with 1,984 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pypots/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
# License: BSD-3-Clause

from .brits import BRITS
from .csai import CSAI
from .grud import GRUD
from .raindrop import Raindrop

__all__ = [
"CSAI",
"BRITS",
"GRUD",
"Raindrop",
Expand Down
20 changes: 20 additions & 0 deletions pypots/classification/csai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
The package including the modules of CSAI.
Refer to the paper
`Linglong Qian, Zina Ibrahim, Hugh Logan Ellis, Ao Zhang, Yuezhou Zhang, Tao Wang, Richard Dobson.
Knowledge Enhanced Conditional Imputation for Healthcare Time-series.
In Arxiv, 2024.
<https://arxiv.org/abs/2312.16713>`_
Notes
-----
This implementation is inspired by the official one the official implementation https://github.com/LinglongQian/CSAI.
"""

from .model import CSAI

__all__ = [
"CSAI",
]
123 changes: 123 additions & 0 deletions pypots/classification/csai/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
"""

# Created by Linglong Qian, Joseph Arul Raj <[email protected], [email protected]>
# License: BSD-3-Clause

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...nn.modules.csai import BackboneBCSAI

# class DiceBCELoss(nn.Module):
# def __init__(self, weight=None, size_average=True):
# super(DiceBCELoss, self).__init__()
# self.bcelogits = nn.BCEWithLogitsLoss()

# def forward(self, y_score, y_out, targets, smooth=1):

# #comment out if your model contains a sigmoid or equivalent activation layer
# # inputs = F.sigmoid(inputs)

# #flatten label and prediction tensors
# BCE = self.bcelogits(y_out, targets)

# y_score = y_score.view(-1)
# targets = targets.view(-1)
# intersection = (y_score * targets).sum()
# dice_loss = 1 - (2.*intersection + smooth)/(y_score.sum() + targets.sum() + smooth)

# Dice_BCE = BCE + dice_loss

# return BCE, Dice_BCE


class _BCSAI(nn.Module):
def __init__(
self,
n_steps: int,
n_features: int,
rnn_hidden_size: int,
imputation_weight: float,
consistency_weight: float,
classification_weight: float,
n_classes: int,
step_channels: int,
dropout: float = 0.5,
intervals=None,
):
super().__init__()
self.n_steps = n_steps
self.n_features = n_features
self.rnn_hidden_size = rnn_hidden_size
self.imputation_weight = imputation_weight
self.consistency_weight = consistency_weight
self.classification_weight = classification_weight
self.n_classes = n_classes
self.step_channels = step_channels
self.intervals = intervals

# create models
self.model = BackboneBCSAI(n_steps, n_features, rnn_hidden_size, step_channels, intervals)
self.f_classifier = nn.Linear(self.rnn_hidden_size, n_classes)
self.b_classifier = nn.Linear(self.rnn_hidden_size, n_classes)
self.imputer = nn.Linear(self.rnn_hidden_size, n_features)
self.dropout = nn.Dropout(dropout)

def forward(self, inputs: dict, training: bool = True) -> dict:

(
imputed_data,
f_reconstruction,
b_reconstruction,
f_hidden_states,
b_hidden_states,
consistency_loss,
reconstruction_loss,
) = self.model(inputs)

results = {
"imputed_data": imputed_data,
}

f_logits = self.f_classifier(self.dropout(f_hidden_states))
b_logits = self.b_classifier(self.dropout(b_hidden_states))

# f_prediction = torch.sigmoid(f_logits)
# b_prediction = torch.sigmoid(b_logits)

f_prediction = torch.softmax(f_logits, dim=1)
b_prediction = torch.softmax(b_logits, dim=1)
classification_pred = (f_prediction + b_prediction) / 2

results = {
"imputed_data": imputed_data,
"classification_pred": classification_pred,
}

# if in training mode, return results with losses
if training:
# criterion = DiceBCELoss().to(imputed_data.device)
results["consistency_loss"] = consistency_loss
results["reconstruction_loss"] = reconstruction_loss
# print(inputs["labels"].unsqueeze(1))
f_classification_loss = F.nll_loss(torch.log(f_prediction), inputs["labels"])
b_classification_loss = F.nll_loss(torch.log(b_prediction), inputs["labels"])
# f_classification_loss, _ = criterion(f_prediction, f_logits, inputs["labels"].unsqueeze(1).float())
# b_classification_loss, _ = criterion(b_prediction, b_logits, inputs["labels"].unsqueeze(1).float())
classification_loss = (f_classification_loss + b_classification_loss)

loss = (
self.consistency_weight * consistency_loss +
self.imputation_weight * reconstruction_loss +
self.classification_weight * classification_loss
)

results["loss"] = loss
results["classification_loss"] = classification_loss
results["f_reconstruction"] = f_reconstruction
results["b_reconstruction"] = b_reconstruction

return results
39 changes: 39 additions & 0 deletions pypots/classification/csai/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
"""

# Created by Joseph Arul Raj <[email protected]>
# License: BSD-3-Clause

from typing import Union
from ...imputation.csai.data import DatasetForCSAI as DatasetForCSAI_Imputation



class DatasetForCSAI(DatasetForCSAI_Imputation):
def __init__(self,
data: Union[dict, str],
file_type: str = "hdf5",
return_y: bool = True,
removal_percent: float = 0.0,
increase_factor: float = 0.1,
compute_intervals: bool = False,
replacement_probabilities = None,
normalise_mean : list = [],
normalise_std: list = [],
training: bool = True
):
super().__init__(
data=data,
return_X_ori=False,
return_y=return_y,
file_type=file_type,
removal_percent=removal_percent,
increase_factor=increase_factor,
compute_intervals=compute_intervals,
replacement_probabilities=replacement_probabilities,
normalise_mean=normalise_mean,
normalise_std=normalise_std,
training=training
)

Loading

0 comments on commit 6c5777e

Please sign in to comment.