Showing 63 changed files with 2,754 additions and 526 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,76 +1,34 @@ | ||
# Semi-supervised learning with data harmonisation for biomarker discovery from resting state fMRI | ||
# SHRED | ||
|
||
Code for SHRED-II is provided in this repository. | ||
This repo contains the code for various variants of the SHRED architecture. | ||
|
||
Models were developed in Python using PyTorch v1.8.0. The experiments were performed on an Nvidia P100 GPU. | ||
- Sensitivity to parameter changes is documented in `./figures/`. | ||
- Expected runtime: depending on the size of the chosen site, a full run of 10 seeds x 5 folds for a single site takes between 1 and 3 minutes. | ||
- Memory footprint: ~4GB of GPU memory | ||
|
||
Data is available for download at the following links: [ABIDE](http://preprocessed-connectomes-project.org/abide/), [ADHD](http://preprocessed-connectomes-project.org/adhd200/). | ||
A previous version of the repo, linked to our MICCAI 2022 submission, can be found on the branch `miccai_2022` [link](https://github.com/SCSE-Biomedical-Computing-Group/SHRED/tree/miccai_2022). | ||
|
||
Data is available for download at the following links: [SchizConnect](http://schizconnect.org), [UCLA](https://openneuro.org/datasets/ds000030/versions/00016). | ||
Some SchizConnect sites appear to have been unavailable for some time. | ||
|
||
## Environment Setup | ||
|
||
1. Create and activate new Anaconda environment | ||
1. Create and activate new conda environment | ||
|
||
conda create -n <env_name> python=3.8 | ||
conda activate <env_name> | ||
|
||
2. Run ``setup.sh`` | ||
2. Run `setup.sh` | ||
|
||
chmod u+x ./setup.sh | ||
./setup.sh | ||
## Setup for a new dataset | ||
|
||
## Dataset Preparation | ||
|
||
1. Process the raw fMRI data to obtain a functional connectivity matrix for each subject scan. Save the functional connectivity matrices as ``.npy`` files ([example](dataset/ABIDE/processed_corr_mat/)). The files can be saved anywhere. | ||
|
||
2. Prepare a CSV file ([example](dataset/ABIDE/meta.csv)) which contains the required columns below: | ||
|
||
- ``SUBJECT``: A unique identifier for each subject | ||
- ``AGE``: The age of the subject when the fMRI scan was acquired | ||
- ``GENDER``: The gender of the subject | ||
- ``DISEASE_LABEL``: The label for disease classification (0 represents cognitively normal subjects, 1 represents diseased subjects) | ||
- ``SITE``: The site at which the subject scan was acquired | ||
- ``FC_MATRIX_PATH``: The path to the ``.npy`` file that stores the subject's functional connectivity matrix. This can be either an absolute local path or a path relative to the directory of the CSV file. | ||
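
A quick way to catch a malformed metadata file before training is to compare its header against the column list above. The sketch below uses only the standard library. The helper name `missing_meta_columns` and the sample row are made up for illustration and are not part of this repository.

```python
import csv
import io

# Column names taken from the list above.
REQUIRED_COLUMNS = (
    "SUBJECT", "AGE", "GENDER", "DISEASE_LABEL", "SITE", "FC_MATRIX_PATH",
)


def missing_meta_columns(csv_file):
    """Return the required columns absent from the CSV header (hypothetical helper)."""
    reader = csv.DictReader(csv_file)
    header = reader.fieldnames or []  # fieldnames is None for an empty file
    return [col for col in REQUIRED_COLUMNS if col not in header]


# In-memory sample; a real call would pass open("meta.csv", newline="").
sample = io.StringIO("SUBJECT,AGE,GENDER,SITE\nsub-01,23,1,site_a\n")
print(missing_meta_columns(sample))  # ['DISEASE_LABEL', 'FC_MATRIX_PATH']
```

An empty list means every documented column is present. The check does not validate the cell values themselves.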
|
||
## Run the Code | ||
|
||
1. Modify the ``config.yml`` file ([example](src/config.yml)) or create a new ``config.yml`` file as the input to the ``main.py`` script. The ``config.yml`` file contains the necessary arguments required to run the main script. | ||
|
||
- ``output_directory``: The directory in which the results should be stored | ||
- ``model_name``: The name to be assigned to the model | ||
- ``model_params``: The parameters for initializing the EDC-VAE model, including | ||
  - ``hidden_size``: The number of hidden nodes for encoder and decoder. | ||
  - ``emb_size``: The dimension of the latent space representation output by the encoder. | ||
  - ``clf_hidden_1``: The number of hidden nodes in the first hidden layer of classifier. | ||
  - ``clf_hidden_2``: The number of hidden nodes in the second hidden layer of classifier. | ||
  - ``dropout``: The amount of dropout during training. | ||
- ``optim_params``: The parameters for initializing Adam optimizer | ||
  - ``lr``: The learning rate | ||
  - ``l2_reg``: The L2 regularization strength | ||
- ``hyperparameters``: Additional hyperparameters when training EDC-VAE model | ||
  - ``ll_loss``: The weight of the VAE likelihood loss | ||
  - ``kl_loss``: The weight of the KL divergence | ||
- ``dataset_path``: Path to the CSV prepared during dataset preparation stage. This path can be an absolute path, or a relative path from the directory of ``main.py`` | ||
- ``dataset_name``: The name to be assigned to the dataset | ||
- ``seeds``: A list of seeds to iterate through | ||
- ``num_fold``: The number of folds for cross validation | ||
- ``ssl``: A boolean, indicating whether unlabeled data should be used to train the EDC-VAE model | ||
- ``harmonize``: A boolean, indicating whether ComBat harmonization should be performed prior to model training. | ||
- ``labeled_sites``: The site used as labeled data; when ``null``, all sites are used as labeled data. | ||
- ``device``: The GPU to be used; ``-1`` means CPU only | ||
- ``verbose``: A boolean, indicating whether to display training progress messages | ||
- ``max_epoch``: The maximum number of epochs | ||
- ``patience``: The number of epochs without improvement for early stopping | ||
- ``save_model``: A boolean, indicating whether to save EDC-VAE's state_dict. | ||
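
Misspelled keys in a config file are easy to miss, so it can help to compare a parsed config against the documented keys before launching a run. The sketch below assumes the config has already been loaded into a plain mapping and that the seventeen names listed above are its top-level keys. `check_config` is a hypothetical helper, not something `main.py` is known to provide.

```python
# Top-level keys documented in the list above (nesting is an assumption).
DOCUMENTED_KEYS = {
    "output_directory", "model_name", "model_params", "optim_params",
    "hyperparameters", "dataset_path", "dataset_name", "seeds", "num_fold",
    "ssl", "harmonize", "labeled_sites", "device", "verbose", "max_epoch",
    "patience", "save_model",
}


def check_config(config):
    """Return (missing, unknown) top-level keys of a parsed config mapping."""
    keys = set(config)
    return sorted(DOCUMENTED_KEYS - keys), sorted(keys - DOCUMENTED_KEYS)


# Placeholder values only; "seed" is a deliberate typo for "seeds".
example = {"output_directory": "results", "model_name": "demo", "seed": [0, 1]}
missing, unknown = check_config(example)
print(unknown)  # ['seed']
```

Whether every key is mandatory is not stated in the README, so a non-empty `missing` list is a prompt to review the file rather than a hard error.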
1. Prepare dataset | ||
|
||
2. Run the main script | ||
- Create a new folder under `./src` with the dataset name (see `./Schiz` for reference) and modify the setup and config files. | ||
- Edit `__init__.py` to specify how to retrieve site, age and gender, and the labelling standards if applicable. | ||
- Add the dataset to the `DataloaderBase` class (including `_get_indices()`) and the `Dataset` class in `./src/data.py` | ||
|
||
cd src | ||
python main.py --config <PATH_TO_YML_FILE> | ||
2. Create `.yml` files in `config_templates` to define the model hyperparameters and training settings. More details about the YAML files can be found in the `miccai_2022` branch. | ||
|
||
3. Load a saved model and perform inference | ||
3. Train the model (and any other models specified in the `.yml` file) using `single_stage_framework.py`. | ||
|
||
python evaluate.py | ||
python single_stage_framework.py --config config_templates/individual/SHRED-III.yml |
This file was deleted.
Binary file not shown.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Schiz | ||
|
||
Data collected from SchizConnect (NMorphCH, COBRE), UCLA and a private dataset. | ||
|
||
## Setup Dataset Guide | ||
|
||
1. Make sure that the paths in ``setup.py`` are correct | ||
|
||
main_dir = <Path to folder containing data> | ||
corr_mat_dir = <Path to folder containing correlation matrix> | ||
phenotypics_path = <Path to folder containing metadata> | ||
|
||
2. Run ``setup.py`` | ||
|
||
python setup.py |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import os | ||
import sys | ||
import numpy as np | ||
import pandas as pd | ||
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__))) | ||
from schiz_config import * | ||
|
||
|
||
def load_data_fmri(harmonized=False, time_series=False): | ||
    """ | ||
    Inputs | ||
    harmonized: bool | ||
        - whether or not to return ComBat-harmonized X | ||
    time_series: bool | ||
        - time series loading is not implemented; True raises NotImplementedError | ||
    Returns | ||
    X: np.ndarray with N subject samples, each sample consists of a ROI x ROI correlation matrix | ||
    Y: np.ndarray with N subject samples, each sample has a one-hot encoded label (0: Normal, 1: Diseased) | ||
    """ | ||
    if harmonized: | ||
        X = np.load(HARMONIZED_X_PATH) | ||
    else: | ||
        X = np.load(X_PATH) | ||
    Y = np.load(Y_PATH) | ||
    if not time_series: | ||
        return X, Y | ||
    raise NotImplementedError | ||
|
||
|
||
def get_splits(site_id=None, test=False): | ||
    if site_id is None: | ||
        path = SPLIT_TEST_PATH if test else SPLIT_CV_PATH | ||
    else: | ||
        path = "{}_test.npy".format(site_id) if test else "{}_cv.npy".format(site_id) | ||
        path = os.path.join(SSL_SPLITS_DIR, path) | ||
    splits = np.load(path, allow_pickle=True) | ||
    return splits | ||
|
||
|
||
def get_ages_and_genders(): | ||
    """ | ||
    ages: np.array of float representing the age of subject when the scan is obtained | ||
    gender: np.array of int representing the subject's gender | ||
        - 0: Female | ||
        - 1: Male | ||
    """ | ||
    meta_df = pd.read_csv(META_CSV_PATH) | ||
    ages = np.array(meta_df["age"]) | ||
    genders = np.array(meta_df["sex"]) | ||
    return ages, genders | ||
|
||
|
||
def get_sites(): | ||
    meta_df = pd.read_csv(META_CSV_PATH) | ||
    sites = np.array(meta_df["study"]) | ||
    return sites | ||
|
||
|
||
def load_meta_df(): | ||
    df = pd.read_csv(META_CSV_PATH) | ||
    df["FILE_PATH"] = df["FILE_PATH"].apply(eval) | ||
    return df | ||
|
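
`load_meta_df` passes each `FILE_PATH` cell to `eval`, which executes whatever expression the CSV happens to contain. If the cells hold only Python literals (an assumption, since the format of `FILE_PATH` is not shown in this diff), `ast.literal_eval` parses them without running code. A minimal sketch, with a hypothetical helper name and made-up cell contents:

```python
import ast


def parse_literal_cell(cell):
    """Parse a CSV cell holding a Python literal, refusing anything else."""
    try:
        return ast.literal_eval(cell)
    except (ValueError, SyntaxError) as exc:
        raise ValueError(f"not a plain literal: {cell!r}") from exc


# Made-up cell contents for illustration.
print(parse_literal_cell("['scan_a.npy', 'scan_b.npy']"))  # ['scan_a.npy', 'scan_b.npy']
```

Under that assumption, the `eval` call could become `df["FILE_PATH"].apply(parse_literal_cell)`. A cell containing a function call would then raise `ValueError` instead of being executed.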
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import os.path as osp | ||
|
||
|
||
__dir__ = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))) | ||
|
||
|
||
MAIN_DIR = osp.abspath(osp.join(__dir__, "data", "Schiz")) | ||
META_CSV_PATH = osp.join(MAIN_DIR, "meta.csv") | ||
X_PATH = osp.join(MAIN_DIR, "X.npy") | ||
X_TS_PATH = osp.join(MAIN_DIR, "X_ts.npy") | ||
Y_PATH = osp.join(MAIN_DIR, "Y.npy") | ||
SPLIT_TEST_PATH = osp.join(MAIN_DIR, "split_test.npy") | ||
SPLIT_CV_PATH = osp.join(MAIN_DIR, "split_cv.npy") | ||
SSL_SPLITS_DIR = osp.join(MAIN_DIR, "ssl_splits") | ||
HARMONIZED_X_PATH = osp.join(MAIN_DIR, "harmonized_X.npy") |
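
The constants above imply a fixed layout under `data/Schiz`. Before calling the loaders, it may be convenient to list which of those entries do not exist yet. The sketch below is an illustration only: `missing_entries` is a hypothetical helper. Whether every entry is required is an assumption: `X_ts.npy` is not read by `load_data_fmri`, and `harmonized_X.npy` is needed only when `harmonized=True`.

```python
from pathlib import Path
import tempfile

# File and directory names taken from schiz_config.py above.
EXPECTED_ENTRIES = (
    "meta.csv", "X.npy", "X_ts.npy", "Y.npy",
    "split_test.npy", "split_cv.npy", "ssl_splits", "harmonized_X.npy",
)


def missing_entries(main_dir):
    """List expected files or folders not yet present under main_dir."""
    root = Path(main_dir)
    return [name for name in EXPECTED_ENTRIES if not (root / name).exists()]


# Demonstration on a throwaway directory holding only two of the files.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "meta.csv").touch()
    (Path(tmp) / "X.npy").touch()
    print(missing_entries(tmp))
```

In practice the argument would be the `MAIN_DIR` value from `schiz_config.py`.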