# Semi-supervised learning with data harmonisation for biomarker discovery from resting state fMRI

Code for SHRED-II is provided in this repository.

Models were developed in Python using PyTorch v1.8.0. The experiments were performed on an Nvidia P100 GPU.

- Information on sensitivity to parameter changes can be found in `./figures/`.
- Expected runtime: depending on the size of the chosen site, a full run of 10 seeds x 5 folds for a single site takes between 1 and 3 minutes.
- Memory footprint: ~4 GB of GPU memory.

Data is available for download at the following links: [ABIDE](http://preprocessed-connectomes-project.org/abide/), [ADHD](http://preprocessed-connectomes-project.org/adhd200/).
## Environment Setup

1. Create and activate a new Anaconda environment:

       conda create -n <env_name> python=3.8
       conda activate <env_name>

2. Run `setup.sh`:

       chmod u+x ./setup.sh
       ./setup.sh
## Dataset Preparation

1. Process the raw fMRI data to obtain a functional connectivity matrix for each subject scan. The functional connectivity matrices should be saved as `.npy` files ([example](dataset/ABIDE/processed_corr_mat/)); there is no restriction on where the files are saved. A minimal sketch of this step follows.
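   The sketch below assumes ROI time series have already been extracted from the preprocessed scan; the input file name and the 264-ROI Power-atlas dimension are assumptions based on the `_power` suffix in the example file names:

       import numpy as np

       # Hypothetical input: a (num_timepoints, num_rois) array of ROI time
       # series for one subject, e.g. extracted with a 264-ROI Power atlas.
       timeseries = np.load("NYU_0000001_timeseries.npy")

       # Pearson correlation between every pair of ROI time series yields the
       # (num_rois, num_rois) functional connectivity matrix.
       corr_mat = np.corrcoef(timeseries.T)
       np.save("processed_corr_mat/NYU_0000001_power.npy", corr_mat)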
2. Prepare a CSV file ([example](dataset/ABIDE/meta.csv)) which contains the required columns below; a validation sketch follows the list.

   - `SUBJECT`: a unique identifier for each subject
   - `AGE`: the subject's age when the fMRI scan was acquired
   - `GENDER`: the subject's gender
   - `DISEASE_LABEL`: the label for disease classification (0 for cognitively normal subjects, 1 for diseased subjects)
   - `SITE`: the site at which the subject's scan was acquired
   - `FC_MATRIX_PATH`: the path to the `.npy` file that stores the subject's functional connectivity matrix, given either as an absolute local path or as a path relative to the directory of the CSV file
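   The sketch below checks a prepared CSV. It is not part of the repository; path resolution mirrors the dataset loader, which resolves relative paths against the CSV's directory:

       import os
       import pandas as pd

       REQUIRED_COLUMNS = [
           "SUBJECT", "AGE", "GENDER", "DISEASE_LABEL", "SITE", "FC_MATRIX_PATH",
       ]

       meta_path = "dataset/ABIDE/meta.csv"
       meta_df = pd.read_csv(meta_path, dtype={"SUBJECT": str})

       # Every required column must be present.
       missing = [c for c in REQUIRED_COLUMNS if c not in meta_df.columns]
       assert not missing, f"meta.csv is missing columns: {missing}"

       # Every functional connectivity matrix path must resolve.
       data_folder = os.path.dirname(os.path.abspath(meta_path))
       for path in meta_df["FC_MATRIX_PATH"]:
           assert os.path.exists(os.path.join(data_folder, path)), path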
## Run the Code

1. Modify the `config.yml` file ([example](src/config.yml)) or create a new `config.yml` file as the input to the `main.py` script. The `config.yml` file contains the arguments required to run the main script; a sketch of loading such a file follows the list.

   - `output_directory`: the directory in which results are stored
   - `model_name`: the name assigned to the model
   - `model_params`: the parameters for initializing the EDC-VAE model, including
     - `hidden_size`: the number of hidden nodes for the encoder and decoder
     - `emb_size`: the dimension of the latent representation output by the encoder
     - `clf_hidden_1`: the number of hidden nodes in the first hidden layer of the classifier
     - `clf_hidden_2`: the number of hidden nodes in the second hidden layer of the classifier
     - `dropout`: the amount of dropout applied during training
   - `optim_params`: the parameters for initializing the Adam optimizer
     - `lr`: the learning rate
     - `l2_reg`: the L2 regularization weight
   - `hyperparameters`: additional hyperparameters for training the EDC-VAE model
     - `ll_loss`: the weight of the VAE likelihood loss
     - `kl_loss`: the weight of the KL divergence
   - `dataset_path`: path to the CSV prepared during the dataset preparation stage, either as an absolute path or as a path relative to the directory of `main.py`
   - `dataset_name`: the name assigned to the dataset
   - `seeds`: a list of seeds to iterate through
   - `num_fold`: the number of folds for cross-validation
   - `ssl`: a boolean indicating whether unlabeled data should be used to train the EDC-VAE model
   - `harmonize`: a boolean indicating whether ComBat harmonization should be performed prior to model training
   - `labeled_sites`: the site used as labeled data; when `null`, all sites are used as labeled data
   - `device`: the index of the GPU to use; `-1` means CPU only
   - `verbose`: a boolean indicating whether to display training progress messages
   - `max_epoch`: the maximum number of epochs
   - `patience`: the number of epochs without improvement before early stopping
   - `save_model`: a boolean indicating whether to save the EDC-VAE model's `state_dict`
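   The sketch below shows how such a file can be loaded; the `--config` flag matches the run command in the next step, but the loading code itself is an assumption, and PyYAML is not listed in `requirements.txt`:

       import argparse

       import yaml  # PyYAML, assumed to be installed

       parser = argparse.ArgumentParser()
       parser.add_argument("--config", type=str, required=True)
       args = parser.parse_args()

       with open(args.config) as f:
           config = yaml.safe_load(f)

       # Nested sections become plain dicts, e.g. for the example config:
       model_params = config["model_params"]  # {"hidden_size": 32, "emb_size": 32, ...}
       optim_params = config["optim_params"]  # {"lr": 0.002, "l2_reg": 0.001}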
2. Run the main script:

       cd src
       python main.py --config <PATH_TO_YML_FILE>

3. Load a saved model and perform inference:

       python evaluate.py
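   A minimal sketch of the kind of loading `evaluate.py` performs; the module path, class signature, feature size, and checkpoint path below are illustrative assumptions, and `evaluate.py` is the authoritative script:

       import torch

       from models import EDC_VAE  # assumed import path

       model = EDC_VAE(
           input_size=34716,  # 264 * 263 / 2 features for a 264-ROI matrix (assumed)
           hidden_size=32,
           emb_size=32,
           clf_hidden_1=0,
           clf_hidden_2=0,
           dropout=0.2,
       )
       model.load_state_dict(torch.load("../results/EDC_VAE.pt", map_location="cpu"))
       model.eval()

       features = torch.randn(1, 34716)  # one flattened upper-triangular FC matrix
       with torch.no_grad():
           prediction = model(features)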
Example `dataset/ABIDE/meta.csv`:

SUBJECT,AGE,GENDER,DISEASE_LABEL,SITE,FC_MATRIX_PATH
00001,1,1,1,NYU,processed_corr_mat/NYU_0000001_power.npy
00002,2,1,0,NYU,processed_corr_mat/NYU_0000002_power.npy
00003,3,0,1,NYU,processed_corr_mat/NYU_0000003_power.npy
00004,4,0,0,NYU,processed_corr_mat/NYU_0000004_power.npy
00005,5,1,1,NYU,processed_corr_mat/NYU_0000005_power.npy
00006,6,0,1,NYU,processed_corr_mat/NYU_0000006_power.npy
00007,7,1,0,NYU,processed_corr_mat/NYU_0000007_power.npy
00008,8,0,1,NYU,processed_corr_mat/NYU_0000008_power.npy
00009,9,0,0,NYU,processed_corr_mat/NYU_0000009_power.npy
00010,10,1,0,NYU,processed_corr_mat/NYU_0000010_power.npy
00011,1,1,1,USM,processed_corr_mat/USM_0000011_power.npy
00012,2,0,1,USM,processed_corr_mat/USM_0000012_power.npy
00013,3,1,0,USM,processed_corr_mat/USM_0000013_power.npy
00014,4,0,0,USM,processed_corr_mat/USM_0000014_power.npy
00015,5,1,1,USM,processed_corr_mat/USM_0000015_power.npy
00016,6,1,0,USM,processed_corr_mat/USM_0000016_power.npy
00017,7,0,1,USM,processed_corr_mat/USM_0000017_power.npy
00018,8,0,0,USM,processed_corr_mat/USM_0000018_power.npy
00019,9,1,0,USM,processed_corr_mat/USM_0000019_power.npy
00020,10,0,1,USM,processed_corr_mat/USM_0000020_power.npy
(20 binary files not shown: the example `.npy` functional connectivity matrices under `dataset/ABIDE/processed_corr_mat/`.)
Note accompanying the example dataset (`dataset/ABIDE/`):

Data provided in this folder is randomly generated.
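Since the bundled matrices are random, similar placeholder data can be regenerated with a short script; a sketch, in which the 264-ROI dimension is an assumption based on the `_power` file names:

    import os

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    num_rois = 264  # assumed Power-atlas dimension

    os.makedirs("processed_corr_mat", exist_ok=True)
    meta_df = pd.read_csv("meta.csv", dtype={"SUBJECT": str})
    for _, row in meta_df.iterrows():
        mat = rng.uniform(-1.0, 1.0, size=(num_rois, num_rois))
        mat = (mat + mat.T) / 2      # symmetric, like a correlation matrix
        np.fill_diagonal(mat, 1.0)   # unit self-correlation
        np.save(row["FC_MATRIX_PATH"], mat)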
Note accompanying the figures:

For `hyperparameters__ch_loss_accuracy_boxplot.png`, 1.0 was used instead, as the number of samples for 0.001 was too low (4).
`requirements.txt`:

numpy
pandas
pyarrow
psutil
scipy
seaborn
matplotlib
jupyter
pytorch==1.8.0
torchvision==0.9.0
torchaudio==0.8.0
cudatoolkit==10.1.243
pytorch-geometric
captum
nilearn
`setup.sh`:

conda install --file requirements.txt -c rusty1s -c conda-forge -c pytorch
pip install neuroCombat --no-cache
Example `src/config.yml`:

output_directory: ../results
model_name: EDC_VAE
model_params:
  hidden_size: 32
  emb_size: 32
  clf_hidden_1: 0
  clf_hidden_2: 0
  dropout: 0.2
optim_params:
  lr: 0.002
  l2_reg: 0.001
hyperparameters:
  ll_loss: 0.0001
  kl_loss: 0.001
dataset_path: ../dataset/ABIDE/meta.csv
dataset_name: ABIDE
seeds:
  - 0
  - 1
  - 2
  - 3
  - 4
  - 5
  - 6
  - 7
  - 8
  - 9
num_fold: 5
ssl: true
harmonize: true
labeled_sites: NYU
device: 2
verbose: true
max_epoch: 1000
patience: 1000
save_model: true
The dataset module (Python source), which loads the meta CSV, optionally applies ComBat harmonization, and generates subject-level cross-validation splits:

import os
import pandas as pd
import numpy as np
from neuroCombat import neuroCombat
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from typing import Dict, Generator, Optional, Sequence, Union

import torch
from torch_geometric.data import Data


def corr_mx_flatten(X: np.ndarray) -> np.ndarray:
    """
    Returns the flattened upper-triangular part of each sample in X.

    Option 1:
        X.shape == (num_sample, num_feature, num_feature)
        X_flattened.shape == (num_sample, num_feature * (num_feature - 1) / 2)
    Option 2:
        X.shape == (num_feature, num_feature)
        X_flattened.shape == (num_feature * (num_feature - 1) / 2,)
    """
    upper_triangular_idx = np.triu_indices(X.shape[1], 1)
    if len(X.shape) == 3:
        X = X[:, upper_triangular_idx[0], upper_triangular_idx[1]]
    else:
        X = X[upper_triangular_idx[0], upper_triangular_idx[1]]
    return X
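
# Example: for a (num_sample, 264, 264) stack of correlation matrices,
# corr_mx_flatten returns a (num_sample, 34716) array, since
# 264 * 263 / 2 == 34716 upper-triangular entries.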


def combat_harmonization(X: np.ndarray, meta_df: pd.DataFrame) -> np.ndarray:
    covars = meta_df[["SITE", "AGE", "GENDER"]]
    categorical_cols = ["GENDER"]
    continuous_cols = ["AGE"]
    batch_col = "SITE"
    combat = neuroCombat(
        dat=X.T,
        covars=covars,
        batch_col=batch_col,
        categorical_cols=categorical_cols,
        continuous_cols=continuous_cols,
    )
    harmonized_X = combat["data"].T
    harmonized_X = np.clip(harmonized_X, -1, 1)
    return harmonized_X


def split_kfoldcv_sbj(
    y: np.ndarray, subjects: np.ndarray, num_fold: int, seed: int
):
    unique_subjects, first_subject_index = np.unique(
        subjects, return_index=True
    )
    subject_y = y[first_subject_index]
    subject_X = np.zeros_like(subject_y)
    skfold = StratifiedKFold(
        n_splits=num_fold, random_state=seed, shuffle=True
    )

    result = []
    for train_subject_index, test_subject_index in skfold.split(
        subject_X, subject_y
    ):
        train_subjects = unique_subjects[train_subject_index]
        test_subjects = unique_subjects[test_subject_index]
        train_index = np.argwhere(np.isin(subjects, train_subjects)).flatten()
        test_index = np.argwhere(np.isin(subjects, test_subjects)).flatten()
        assert len(np.intersect1d(train_index, test_index)) == 0
        assert (
            len(np.intersect1d(subjects[train_index], subjects[test_index]))
            == 0
        )
        result.append((train_index, test_index))
    return result


def make_dataset(
    x: np.ndarray,
    y: np.ndarray,
    d: np.ndarray,
    age: np.ndarray,
    gender: np.ndarray,
) -> Data:
    graph = Data()
    graph.x = torch.tensor(x).float()
    graph.y = torch.tensor(y)
    graph.d = torch.tensor(d)
    graph.age = torch.tensor(age).float()
    graph.gender = torch.tensor(gender).float()
    return graph


class Dataset:
    """
    Required columns in meta.csv:
    - SUBJECT
    - AGE
    - GENDER
    - SITE
    - DISEASE_LABEL
    - FC_MATRIX_PATH
    """

    def __init__(self, data_csv_path: str, name: str, harmonize: bool = False):
        self.data_csv_path = os.path.abspath(data_csv_path)
        self.data_folder = os.path.dirname(self.data_csv_path)
        self.name = name
        self.harmonize = harmonize
        self._init_properties_()

    def _init_properties_(self):
        meta_df = pd.read_csv(self.data_csv_path)
        X = np.array(
            [
                np.load(os.path.join(self.data_folder, path))
                for path in meta_df["FC_MATRIX_PATH"]
            ]
        )
        X = corr_mx_flatten(np.nan_to_num(X))
        if self.harmonize:
            self.X = combat_harmonization(X, meta_df)
        else:
            self.X = X

        self.subjects = meta_df["SUBJECT"].values
        self.ages = meta_df["AGE"].values
        self.genders = meta_df["GENDER"].values
        self.sites = meta_df["SITE"].values
        self.y = meta_df["DISEASE_LABEL"].values

    def _get_indices(
        self,
        seed: int = 0,
        num_fold: int = 5,
        ssl: bool = False,
        labeled_sites: Optional[Union[str, Sequence[str]]] = None,
    ) -> Generator[Dict[str, np.ndarray], None, None]:
        if labeled_sites is None:
            is_labeled = np.ones(len(self.sites), dtype=bool)
        else:
            if isinstance(labeled_sites, str):
                labeled_sites = [labeled_sites]
            is_labeled = np.isin(self.sites, labeled_sites)

        unlabeled_indices = np.argwhere(~is_labeled).flatten()
        labeled_indices = np.argwhere(is_labeled).flatten()
        for train, test in split_kfoldcv_sbj(
            self.y[is_labeled], self.subjects[is_labeled], num_fold, seed
        ):
            result = {
                "labeled_train": labeled_indices[train],
                "test": labeled_indices[test],
            }
            if ssl:
                result["unlabeled_train"] = unlabeled_indices
            yield result

    def load_split_data(
        self,
        seed: int = 0,
        num_fold: int = 5,
        ssl: bool = False,
        labeled_sites: Optional[Union[str, Sequence[str]]] = None,
    ) -> Generator[Dict[str, Union[int, Data]], None, None]:
        for indices in self._get_indices(seed, num_fold, ssl, labeled_sites):
            if ssl:
                all_train_indices = np.concatenate(
                    (indices["labeled_train"], indices["unlabeled_train"])
                )
            else:
                all_train_indices = indices["labeled_train"]
            le = LabelEncoder()
            le.fit(self.sites[all_train_indices])

            all_data: Dict[str, Data] = dict()
            for name, idx in indices.items():
                all_data[name] = make_dataset(
                    x=self.X[idx],
                    y=self.y[idx],
                    d=le.transform(self.sites[idx]),
                    age=self.ages[idx],
                    gender=self.genders[idx],
                )

            all_data["input_size"] = int(self.X.shape[1])
            all_data["num_sites"] = int(len(le.classes_))

            empty = Data(x=torch.tensor([]))
            all_data["num_labeled_train"] = all_data.get(
                "labeled_train", empty
            ).x.size(0)
            all_data["num_unlabeled_train"] = all_data.get(
                "unlabeled_train", empty
            ).x.size(0)
            all_data["num_test"] = all_data.get("test", empty).x.size(0)
            yield all_data

    def load_all_data(
        self, sites: Optional[Union[str, Sequence[str]]] = None
    ) -> Dict[str, Union[int, Data]]:
        if isinstance(sites, str):
            sites = [sites]
        all_indices = np.arange(len(self.X))
        if sites is not None:
            all_indices = all_indices[np.isin(self.sites, sites)]

        le = LabelEncoder()
        le.fit(self.sites[all_indices])

        return {
            "data": make_dataset(
                x=self.X[all_indices],
                y=self.y[all_indices],
                d=le.transform(self.sites[all_indices]),
                age=self.ages[all_indices],
                gender=self.genders[all_indices],
            ),
            "input_size": int(self.X.shape[1]),
            "num_sites": int(len(le.classes_)),
        }
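
A short usage sketch of this module, mirroring the example `config.yml`; the module's file name, and hence the import path, is an assumption:

    from dataset import Dataset  # assumed module name

    ds = Dataset("../dataset/ABIDE/meta.csv", name="ABIDE", harmonize=True)

    # Iterate over 5 folds for seed 0, with NYU as the labeled site and the
    # remaining sites as unlabeled training data (ssl=True).
    for fold in ds.load_split_data(seed=0, num_fold=5, ssl=True, labeled_sites="NYU"):
        print(
            fold["input_size"],
            fold["num_sites"],
            fold["num_labeled_train"],
            fold["num_unlabeled_train"],
            fold["num_test"],
        )
        labeled_train = fold["labeled_train"]  # a torch_geometric Data object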