Skip to content

Commit

Permalink
updates for journal
Browse files Browse the repository at this point in the history
  • Loading branch information
manzaigit committed Aug 5, 2023
1 parent 3fde806 commit 9bdd01b
Show file tree
Hide file tree
Showing 63 changed files with 2,754 additions and 526 deletions.
74 changes: 16 additions & 58 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,76 +1,34 @@
# Semi-supervised learning with data harmonisation for biomarker discovery from resting state fMRI
# SHRED

Code for SHRED-II is provided in this repository.
This repo contains the code for various variants of the SHRED architecture.

Models were developed in Python using Pytorch v1.8.0. The experiments were performed on an Nvidia P100 GPU.
- Information on sensitivity regarding parameter changes can be found in `./figures/`.
- Expected runtime: Depending on the size of the chosen site, a full run of 10 seeds x 5 folds for a single site can take between 1 to 3 minutes.
- Memory footprint: ~4GB of GPU memory

Data is available for download at the following links: [ABIDE](http://preprocessed-connectomes-project.org/abide/), [ADHD](http://preprocessed-connectomes-project.org/adhd200/).
A previous version of the repo, linked to our MICCAI 2022 submission, can be found on the branch `miccai_2022` [link](https://github.com/SCSE-Biomedical-Computing-Group/SHRED/tree/miccai_2022).

Data is available for download at the following links: [SchizConnect](http://schizconnect.org), [UCLA](https://openneuro.org/datasets/ds000030/versions/00016).
Some sites in SchizConnect seem to be down for some time.

## Environment Setup

1. Create and activate new Anaconda environment
1. Create and activate new conda environment

conda create -n <env_name> python=3.8
conda activate <env_name>

2. Run ``setup.sh``
2. Run `setup.sh`

chmod u+x ./setup.sh
./setup.sh
## Setup for a new dataset

## Dataset Preparation

1. Process the raw fMRI data to obtain a functional connectivity matrix for each subject scan. The functional connectivity matrices should be saved as ``.npy`` files ([example](dataset/ABIDE/processed_corr_mat/)). There are no specific requirements on where the files should be saved at.

2. Prepare a CSV file ([example](dataset/ABIDE/meta.csv)) which contains the required columns below:

- ``SUBJECT``: A unique identifier for different subjects
- ``AGE``: The age of subjects when the fMRI scan is acquired
- ``GENDER``: The gender of subjects
- ``DISEASE_LABEL``: The label for disease classification (0 represents cognitive normal subjects, 1 represents diseased subjects)
- ``SITE``: The site in which the subject scan is acquired
- ``FC_MATRIX_PATH``: The paths to the ``.npy`` files that store the functional connectivity matrices of subjects. This path can either be absolute local path or relative path from the directory of the CSV file.

## Run the Code

1. Modify the ``config.yml`` file ([example](src/config.yml)) or create a new ``config.yml`` file as the input to the ``main.py`` script. The ``config.yml`` file contains the necessary arguments required to run the main script.

- ``output_directory``: The directory in which the results should be stored at
- ``model_name``: The name to be assigned to the model
- ``model_params``: The parameters for initializing the EDC-VAE model, including
- ``hidden_size``: The number of hidden nodes for encoder and decoder.
- ``emb_size``: The dimension of the latent space representation output by the encoder.
- ``clf_hidden_1``: The number of hidden nodes in the first hidden layer of classifier.
- ``clf_hidden_2``: The number of hidden nodes in the second hidden layer of classifier.
- ``dropout``: The amount of dropout during training.
- ``optim_params``: The parameters for initializing Adam optimizer
- ``lr``: The learning rate
- ``l2_reg``: The L2 regularization
- ``hyperparameters``: Additional hyperparameters when training EDC-VAE model
- ``ll_loss``: The weightage of VAE likelihood loss
- ``kl_loss``: The weightage of KL divergence
- ``dataset_path``: Path to the CSV prepared during dataset preparation stage. This path can be an absolute path, or a relative path from the directory of ``main.py``
- ``dataset_name``: The name to be assigned to the dataset
- ``seeds``: A list of seeds to iterate through
- ``num_fold``: The number of folds for cross validation
- ``ssl``: A boolean, indicating whether unlabeled data should be used to train the EDC-VAE model
- ``harmonize``: A boolean, indicating whether ComBat harmonization should be performed prior to model training.
- ``labeled_sites``: The site used as labeled data, when ``null``, all sites are used as labeled data.
- ``device``: The GPU to be used, ``-1`` means to use CPU only
- ``verbose``: A boolean, indicating Whether to display training progress messages
- ``max_epoch``: The maximum number of epochs
- ``patience``: The number of epochs without improvement for early stopping
- ``save_model``: A boolean, indicating whether to save EDC-VAE's state_dict.
1. Prepare dataset

2. Run the main script
- Create a new folder under `./src` with the dataset name (see `./Schiz` for reference) and modify the setup and config files.
- Edit `__init__.py` to specify how to retrieve site, age and gender. Labelling standards too, if applicable.
- Add dataset to `DataloaderBase` class (`_get_indices()` too) and `Dataset` class in `./src/data.py`

cd src
python main.py --config <PATH_TO_YML_FILE>
2. Create `.yml` files in `config_template` to define model hyperparameters used and training settings. More details about the YAML files can be found in the `miccai_2022` branch.

3. Load a saved model and perform inference
3. Train the model (and any other models - specify in the `.yml` file) using `single_stage_framework.py`.

python evaluate.py
python single_stage_framework.py --config config_templates/individual/SHRED-III.yml
21 changes: 0 additions & 21 deletions dataset/ABIDE/meta.csv

This file was deleted.

Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1 change: 0 additions & 1 deletion dataset/README.md

This file was deleted.

1 change: 0 additions & 1 deletion figures/README.md

This file was deleted.

Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed figures/model_params__dropout_accuracy_boxplot.png
Binary file not shown.
Binary file removed figures/model_params__emb_size_accuracy_boxplot.png
Binary file not shown.
Binary file not shown.
Binary file removed figures/optim_params__lr_accuracy_boxplot.png
Binary file not shown.
Binary file removed saved_model/ABIDE_VAE-FFN_0_0_1645419832.pt
Binary file not shown.
15 changes: 15 additions & 0 deletions src/Schiz/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Schiz

Data collected from SchizConnect (NMorphCH, COBRE), UCLA and a private dataset.

## Setup Dataset Guide

1. Make sure that the path in ``setup.py`` is correct

main_dir = <Path to folder containing data>
corr_mat_dir = <Path to folder containing correlation matrix>
phenotypics_path = <Path to folder containing metadata>

2. Run ``setup.py``

python setup.py
66 changes: 66 additions & 0 deletions src/Schiz/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import os
import sys
import numpy as np
import pandas as pd

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from schiz_config import *


def load_data_fmri(harmonized=False, time_series=False):
"""
Inputs
site_id: str or None
- if specified, splits will only contains index of subjects from the specified site
harmonized: bool
- whether or not to return combat harmonized X
Returns
X: np.ndarray with N subject samples, each sample consists of a ROI x ROI correlation matrix
Y: np.ndarray with N subject samples, each sample has a one-hot encoded label (0: Normal, 1: Diseased)
"""
if harmonized:
X = np.load(HARMONIZED_X_PATH)
else:
X = np.load(X_PATH)
Y = np.load(Y_PATH)
if not time_series:
return X, Y
raise NotImplementedError


def get_splits(site_id=None, test=False):
if site_id is None:
path = SPLIT_TEST_PATH if test else SPLIT_CV_PATH
else:
path = "{}_test.npy".format(site_id) if test else "{}_cv.npy".format(site_id)
path = os.path.join(SSL_SPLITS_DIR, path)
splits = np.load(path, allow_pickle=True)
return splits


def get_ages_and_genders():
"""
ages: np.array of float representing the age of subject when the scan is obtained
gender: np.array of int representing the subject's gender
- 0: Female
- 1: Male
"""
meta_df = pd.read_csv(META_CSV_PATH)
ages = np.array(meta_df["age"])
genders = np.array(meta_df["sex"])
return ages, genders


def get_sites():
meta_df = pd.read_csv(META_CSV_PATH)
sites = np.array(meta_df["study"])
return sites


def load_meta_df():
df = pd.read_csv(META_CSV_PATH)
df["FILE_PATH"] = df["FILE_PATH"].apply(eval)
return df

15 changes: 15 additions & 0 deletions src/Schiz/schiz_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os.path as osp


__dir__ = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))


MAIN_DIR = osp.abspath(osp.join(__dir__, "data", "Schiz"))
META_CSV_PATH = osp.join(MAIN_DIR, "meta.csv")
X_PATH = osp.join(MAIN_DIR, "X.npy")
X_TS_PATH = osp.join(MAIN_DIR, "X_ts.npy")
Y_PATH = osp.join(MAIN_DIR, "Y.npy")
SPLIT_TEST_PATH = osp.join(MAIN_DIR, "split_test.npy")
SPLIT_CV_PATH = osp.join(MAIN_DIR, "split_cv.npy")
SSL_SPLITS_DIR = osp.join(MAIN_DIR, "ssl_splits")
HARMONIZED_X_PATH = osp.join(MAIN_DIR, "harmonized_X.npy")
Loading

0 comments on commit 9bdd01b

Please sign in to comment.