Commit 3fde806: initial upload
yihao001 authored Jun 18, 2022
Showing 44 changed files with 1,396 additions and 0 deletions.
76 changes: 76 additions & 0 deletions README.md
@@ -0,0 +1,76 @@
# Semi-supervised learning with data harmonisation for biomarker discovery from resting state fMRI

Code for SHRED-II is provided in this repository.

Models were developed in Python using PyTorch v1.8.0. The experiments were performed on an NVIDIA P100 GPU.
- Information on sensitivity to parameter changes can be found in `./figures/`.
- Expected runtime: A full run of 10 seeds x 5 folds for a single site takes between 1 and 3 minutes, depending on the size of the chosen site.
- Memory footprint: ~4GB of GPU memory

Data is available for download at the following links: [ABIDE](http://preprocessed-connectomes-project.org/abide/), [ADHD](http://preprocessed-connectomes-project.org/adhd200/).


## Environment Setup

1. Create and activate a new Anaconda environment

conda create -n <env_name> python=3.8
conda activate <env_name>

2. Run ``setup.sh``

chmod u+x ./setup.sh
./setup.sh

## Dataset Preparation

1. Process the raw fMRI data to obtain a functional connectivity matrix for each subject scan. The functional connectivity matrices should be saved as ``.npy`` files ([example](dataset/ABIDE/processed_corr_mat/)). There is no specific requirement on where these files are saved (see the sketch after this list).

2. Prepare a CSV file ([example](dataset/ABIDE/meta.csv)) containing the following required columns:

   - ``SUBJECT``: A unique identifier for each subject
   - ``AGE``: The subject's age when the fMRI scan was acquired
   - ``GENDER``: The subject's gender
   - ``DISEASE_LABEL``: The label for disease classification (0 represents cognitively normal subjects, 1 represents diseased subjects)
   - ``SITE``: The site at which the scan was acquired
   - ``FC_MATRIX_PATH``: The path to the ``.npy`` file storing the subject's functional connectivity matrix. This can be either an absolute local path or a path relative to the directory of the CSV file.
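
The minimal sketch below illustrates both steps: it computes a correlation-based functional connectivity matrix with nilearn (installed by ``setup.sh``) and appends the matching row to the CSV file. The placeholder time series, file names, and phenotypic values are illustrative assumptions, not requirements of this codebase.

    import numpy as np
    import pandas as pd
    from nilearn.connectome import ConnectivityMeasure

    # Placeholder ROI time series (timepoints x ROIs); in practice, extract
    # this from a subject's fMRI scan, e.g. with nilearn's NiftiLabelsMasker.
    time_series = np.random.randn(200, 264)

    # Correlation-based functional connectivity matrix (ROIs x ROIs).
    conn = ConnectivityMeasure(kind="correlation")
    fc_matrix = conn.fit_transform([time_series])[0]

    # Save the matrix where meta.csv can reference it.
    np.save("dataset/ABIDE/processed_corr_mat/NYU_0000001_power.npy", fc_matrix)

    # Append the subject's row to meta.csv (all values are placeholders).
    row = {
        "SUBJECT": "00001", "AGE": 10, "GENDER": 1, "DISEASE_LABEL": 0,
        "SITE": "NYU",
        "FC_MATRIX_PATH": "processed_corr_mat/NYU_0000001_power.npy",
    }
    pd.DataFrame([row]).to_csv(
        "dataset/ABIDE/meta.csv", mode="a", header=False, index=False
    )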

## Run the Code

1. Modify the ``config.yml`` file ([example](src/config.yml)), or create a new one, to serve as the input to the ``main.py`` script. The ``config.yml`` file contains the arguments required to run the main script.

   - ``output_directory``: The directory in which the results are stored
   - ``model_name``: The name assigned to the model
   - ``model_params``: The parameters for initializing the EDC-VAE model, including
     - ``hidden_size``: The number of hidden nodes in the encoder and decoder
     - ``emb_size``: The dimension of the latent space representation output by the encoder
     - ``clf_hidden_1``: The number of hidden nodes in the first hidden layer of the classifier
     - ``clf_hidden_2``: The number of hidden nodes in the second hidden layer of the classifier
     - ``dropout``: The dropout rate used during training
   - ``optim_params``: The parameters for initializing the Adam optimizer
     - ``lr``: The learning rate
     - ``l2_reg``: The L2 regularization coefficient
   - ``hyperparameters``: Additional hyperparameters for training the EDC-VAE model
     - ``ll_loss``: The weight of the VAE likelihood loss
     - ``kl_loss``: The weight of the KL divergence loss
   - ``dataset_path``: Path to the CSV file prepared during the dataset preparation stage. This can be either an absolute path or a path relative to the directory of ``main.py``
   - ``dataset_name``: The name assigned to the dataset
   - ``seeds``: A list of random seeds to iterate through
   - ``num_fold``: The number of folds for cross-validation
   - ``ssl``: A boolean indicating whether unlabeled data should be used to train the EDC-VAE model
   - ``harmonize``: A boolean indicating whether ComBat harmonization should be performed prior to model training
   - ``labeled_sites``: The site(s) used as labeled data; when ``null``, all sites are used as labeled data
   - ``device``: The index of the GPU to be used; ``-1`` means CPU only
   - ``verbose``: A boolean indicating whether to display training progress messages
   - ``max_epoch``: The maximum number of training epochs
   - ``patience``: The number of epochs without improvement before early stopping
   - ``save_model``: A boolean indicating whether to save the EDC-VAE's state_dict

2. Run the main script

cd src
python main.py --config <PATH_TO_YML_FILE>

3. Load a saved model and perform inference

python evaluate.py
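
To inspect a checkpoint directly, the minimal sketch below loads the ``.pt`` file shipped in this commit and lists its parameters. It assumes the file holds a plain ``state_dict`` (as the ``save_model`` option suggests); adapt it to your own model-loading code as needed.

    import torch

    # Load the saved EDC-VAE weights on CPU.
    state_dict = torch.load(
        "saved_model/ABIDE_VAE-FFN_0_0_1645419832.pt", map_location="cpu"
    )

    # Inspect parameter names and shapes before wiring up the model class.
    for name, tensor in state_dict.items():
        print(name, tuple(tensor.shape))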
21 changes: 21 additions & 0 deletions dataset/ABIDE/meta.csv
@@ -0,0 +1,21 @@
SUBJECT,AGE,GENDER,DISEASE_LABEL,SITE,FC_MATRIX_PATH
00001,1,1,1,NYU,processed_corr_mat/NYU_0000001_power.npy
00002,2,1,0,NYU,processed_corr_mat/NYU_0000002_power.npy
00003,3,0,1,NYU,processed_corr_mat/NYU_0000003_power.npy
00004,4,0,0,NYU,processed_corr_mat/NYU_0000004_power.npy
00005,5,1,1,NYU,processed_corr_mat/NYU_0000005_power.npy
00006,6,0,1,NYU,processed_corr_mat/NYU_0000006_power.npy
00007,7,1,0,NYU,processed_corr_mat/NYU_0000007_power.npy
00008,8,0,1,NYU,processed_corr_mat/NYU_0000008_power.npy
00009,9,0,0,NYU,processed_corr_mat/NYU_0000009_power.npy
00010,10,1,0,NYU,processed_corr_mat/NYU_0000010_power.npy
00011,1,1,1,USM,processed_corr_mat/USM_0000011_power.npy
00012,2,0,1,USM,processed_corr_mat/USM_0000012_power.npy
00013,3,1,0,USM,processed_corr_mat/USM_0000013_power.npy
00014,4,0,0,USM,processed_corr_mat/USM_0000014_power.npy
00015,5,1,1,USM,processed_corr_mat/USM_0000015_power.npy
00016,6,1,0,USM,processed_corr_mat/USM_0000016_power.npy
00017,7,0,1,USM,processed_corr_mat/USM_0000017_power.npy
00018,8,0,0,USM,processed_corr_mat/USM_0000018_power.npy
00019,9,1,0,USM,processed_corr_mat/USM_0000019_power.npy
00020,10,0,1,USM,processed_corr_mat/USM_0000020_power.npy
20 binary files added (the .npy functional connectivity matrices under dataset/ABIDE/processed_corr_mat/ referenced in meta.csv); not shown.
1 change: 1 addition & 0 deletions dataset/README.md
@@ -0,0 +1 @@
Data provided in this folder is randomly generated.
1 change: 1 addition & 0 deletions figures/README.md
@@ -0,0 +1 @@
For `hyperparameters__ch_loss_accuracy_boxplot.png`, 1.0 was used instead, as the number of samples for 0.001 was too low (4).
6 binary figure files added under figures/ (previews not available).
Binary file added figures/optim_params__lr_accuracy_boxplot.png
15 changes: 15 additions & 0 deletions requirements.txt
@@ -0,0 +1,15 @@
numpy
pandas
pyarrow
psutil
scipy
seaborn
matplotlib
jupyter
pytorch==1.8.0
torchvision==0.9.0
torchaudio==0.8.0
cudatoolkit==10.1.243
pytorch-geometric
captum
nilearn
Binary file added saved_model/ABIDE_VAE-FFN_0_0_1645419832.pt
2 changes: 2 additions & 0 deletions setup.sh
@@ -0,0 +1,2 @@
conda install --file requirements.txt -c rusty1s -c conda-forge -c pytorch
pip install neuroCombat --no-cache
36 changes: 36 additions & 0 deletions src/config.yml
@@ -0,0 +1,36 @@
output_directory: ../results
model_name: EDC_VAE
model_params:
hidden_size: 32
emb_size: 32
clf_hidden_1: 0
clf_hidden_2: 0
dropout: 0.2
optim_params:
lr: 0.002
l2_reg: 0.001
hyperparameters:
ll_loss: 0.0001
kl_loss: 0.001
dataset_path: ../dataset/ABIDE/meta.csv
dataset_name: ABIDE
seeds:
- 0
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
num_fold: 5
ssl: true
harmonize: true
labeled_sites: NYU
device: 2
verbose: true
max_epoch: 1000
patience: 1000
save_model: true
221 changes: 221 additions & 0 deletions src/data.py
@@ -0,0 +1,221 @@
import os
import pandas as pd
import numpy as np
from neuroCombat import neuroCombat
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from typing import Dict, Generator, Optional, Sequence, Union

import torch
from torch_geometric.data import Data


def corr_mx_flatten(X: np.ndarray) -> np.ndarray:
"""
returns upper triangluar matrix of each sample in X
option 1:
X.shape == (num_sample, num_feature, num_feature)
X_flattened.shape == (num_sample, num_feature * (num_feature - 1) / 2)
option 2:
X.shape == (num_feature, num_feature)
X_flattend.shape == (num_feature * (num_feature - 1) / 2,)
"""
upper_triangular_idx = np.triu_indices(X.shape[1], 1)
if len(X.shape) == 3:
X = X[:, upper_triangular_idx[0], upper_triangular_idx[1]]
else:
X = X[upper_triangular_idx[0], upper_triangular_idx[1]]
return X


def combat_harmonization(X: np.ndarray, meta_df: pd.DataFrame) -> np.ndarray:
covars = meta_df[["SITE", "AGE", "GENDER"]]
categorical_cols = ["GENDER"]
continuous_cols = ["AGE"]
batch_col = "SITE"
combat = neuroCombat(
dat=X.T,
covars=covars,
batch_col=batch_col,
categorical_cols=categorical_cols,
continuous_cols=continuous_cols,
)
harmonized_X = combat["data"].T
harmonized_X = np.clip(harmonized_X, -1, 1)
return harmonized_X


def split_kfoldcv_sbj(
y: np.ndarray, subjects: np.ndarray, num_fold: int, seed: int
):
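    # Split at the subject level so that scans from the same subject
    # never appear in both the train and test folds.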
unique_subjects, first_subject_index = np.unique(
subjects, return_index=True
)
subject_y = y[first_subject_index]
subject_X = np.zeros_like(subject_y)
skfold = StratifiedKFold(n_splits=num_fold, random_state=seed, shuffle=True)

result = []
for train_subject_index, test_subject_index in skfold.split(
subject_X, subject_y
):
train_subjects = unique_subjects[train_subject_index]
test_subjects = unique_subjects[test_subject_index]
train_index = np.argwhere(np.isin(subjects, train_subjects)).flatten()
test_index = np.argwhere(np.isin(subjects, test_subjects)).flatten()
assert len(np.intersect1d(train_index, test_index)) == 0
assert (
len(np.intersect1d(subjects[train_index], subjects[test_index]))
== 0
)
result.append((train_index, test_index))
return result


def make_dataset(
x: np.ndarray,
y: np.ndarray,
d: np.ndarray,
age: np.ndarray,
gender: np.ndarray,
) -> Data:
graph = Data()
graph.x = torch.tensor(x).float()
graph.y = torch.tensor(y)
graph.d = torch.tensor(d)
graph.age = torch.tensor(age).float()
graph.gender = torch.tensor(gender).float()
return graph


class Dataset:
"""
required columns in meta.csv
- SUBJECT
- AGE
- GENDER
- SITE
- DISEASE_LABEL
- FC_MATRIX_PATH
"""

def __init__(self, data_csv_path: str, name: str, harmonize: bool = False):
self.data_csv_path = os.path.abspath(data_csv_path)
self.data_folder = os.path.dirname(self.data_csv_path)
self.name = name
self.harmonize = harmonize
self._init_properties_()

def _init_properties_(self):
meta_df = pd.read_csv(self.data_csv_path)
X = np.array(
[
np.load(os.path.join(self.data_folder, path))
for path in meta_df["FC_MATRIX_PATH"]
]
)
X = corr_mx_flatten(np.nan_to_num(X))
if self.harmonize:
self.X = combat_harmonization(X, meta_df)
else:
self.X = X

self.subjects = meta_df["SUBJECT"].values
self.ages = meta_df["AGE"].values
self.genders = meta_df["GENDER"].values
self.sites = meta_df["SITE"].values
self.y = meta_df["DISEASE_LABEL"].values

def _get_indices(
self,
seed: int = 0,
num_fold: int = 5,
ssl: bool = False,
labeled_sites: Optional[Union[str, Sequence[str]]] = None,
) -> Generator[Dict[str, np.ndarray], None, None]:

if labeled_sites is None:
is_labeled = np.ones(len(self.sites), dtype=bool)
else:
if isinstance(labeled_sites, str):
labeled_sites = [labeled_sites]
is_labeled = np.isin(self.sites, labeled_sites)

unlabeled_indices = np.argwhere(~is_labeled).flatten()
labeled_indices = np.argwhere(is_labeled).flatten()
for train, test in split_kfoldcv_sbj(
self.y[is_labeled], self.subjects[is_labeled], num_fold, seed
):
result = {
"labeled_train": labeled_indices[train],
"test": labeled_indices[test],
}
if ssl:
result["unlabeled_train"] = unlabeled_indices
yield result

def load_split_data(
self,
seed: int = 0,
num_fold: int = 5,
ssl: bool = False,
labeled_sites: Optional[Union[str, Sequence[str]]] = None,
) -> Generator[Dict[str, Union[int, Data]], None, None]:
for indices in self._get_indices(seed, num_fold, ssl, labeled_sites):
if ssl:
all_train_indices = np.concatenate(
(indices["labeled_train"], indices["unlabeled_train"])
)
else:
all_train_indices = indices["labeled_train"]
le = LabelEncoder()
le.fit(self.sites[all_train_indices])

all_data: Dict[str, Data] = dict()
for name, idx in indices.items():
all_data[name] = make_dataset(
x=self.X[idx],
y=self.y[idx],
d=le.transform(self.sites[idx]),
age=self.ages[idx],
gender=self.genders[idx],
)

all_data["input_size"] = int(self.X.shape[1])
all_data["num_sites"] = int(len(le.classes_))

empty = Data(x=torch.tensor([]))
all_data["num_labeled_train"] = all_data.get(
"labeled_train", empty
).x.size(0)
all_data["num_unlabeled_train"] = all_data.get(
"unlabeled_train", empty
).x.size(0)
all_data["num_test"] = all_data.get("test", empty).x.size(0)
yield all_data

def load_all_data(
self, sites: Optional[Union[str, Sequence[str]]] = None,
) -> Dict[str, Union[int, Data]]:
if isinstance(sites, str):
sites = [sites]
all_indices = np.arange(len(self.X))
if sites is not None:
all_indices = all_indices[np.isin(self.sites, sites)]

le = LabelEncoder()
le.fit(self.sites[all_indices])

return {
"data": make_dataset(
x=self.X[all_indices],
y=self.y[all_indices],
d=le.transform(self.sites[all_indices]),
age=self.ages[all_indices],
gender=self.genders[all_indices],
),
"input_size": int(self.X.shape[1]),
"num_sites": int(len(le.classes_)),
}