diff --git a/README.md b/README.md
index 18581d4..7c2b1df 100644
--- a/README.md
+++ b/README.md
@@ -1,76 +1,34 @@
-# Semi-supervised learning with data harmonisation for biomarker discovery from resting state fMRI
+# SHRED

-Code for SHRED-II is provided in this repository.
+This repository contains the code for the SHRED architecture and its variants.

-Models were developed in Python using Pytorch v1.8.0. The experiments were performed on an Nvidia P100 GPU.
-- Information on sensitivity regarding parameter changes can be found in `./figures/`.
-- Expected runtime: Depending on the size of the chosen site, a full run of 10 seeds x 5 folds for a single site can take between 1 to 3 minutes.
-- Memory footprint: ~4GB of GPU memory
-
-Data is available for download at the following links: [ABIDE](http://preprocessed-connectomes-project.org/abide/), [ADHD](http://preprocessed-connectomes-project.org/adhd200/).
+A previous version of this repository, corresponding to our MICCAI 2022 submission, can be found on the [`miccai_2022`](https://github.com/SCSE-Biomedical-Computing-Group/SHRED/tree/miccai_2022) branch.
+Data is available for download at the following links: [SchizConnect](http://schizconnect.org), [UCLA](https://openneuro.org/datasets/ds000030/versions/00016).
+Note that some SchizConnect sites appear to have been down for some time.

 ## Environment Setup

-1. Create and activate new Anaconda environment
+1. Create and activate a new conda environment

        conda create -n <env_name> python=3.8
        conda activate <env_name>

-2. Run ``setup.sh``
+2. Run `setup.sh`

        chmod u+x ./setup.sh
        ./setup.sh
+
+## Setup for a new dataset

-## Dataset Preparation
-
-1. Process the raw fMRI data to obtain a functional connectivity matrix for each subject scan. The functional connectivity matrices should be saved as ``.npy`` files ([example](dataset/ABIDE/processed_corr_mat/)). There are no specific requirements on where the files should be saved at.
-
-2. Prepare a CSV file ([example](dataset/ABIDE/meta.csv)) which contains the required columns below:
-
- - ``SUBJECT``: A unique identifier for different subjects
- - ``AGE``: The age of subjects when the fMRI scan is acquired
- - ``GENDER``: The gender of subjects
- - ``DISEASE_LABEL``: The label for disease classification (0 represents cognitive normal subjects, 1 represents diseased subjects)
- - ``SITE``: The site in which the subject scan is acquired
- - ``FC_MATRIX_PATH``: The paths to the ``.npy`` files that store the functional connectivity matrices of subjects. This path can either be absolute local path or relative path from the directory of the CSV file.
-
-## Run the Code
-
-1. Modify the ``config.yml`` file ([example](src/config.yml)) or create a new ``config.yml`` file as the input to the ``main.py`` script. The ``config.yml`` file contains the necessary arguments required to run the main script.
-
- - ``output_directory``: The directory in which the results should be stored at
- - ``model_name``: The name to be assigned to the model
- - ``model_params``: The parameters for initializing the EDC-VAE model, including
- - ``hidden_size``: The number of hidden nodes for encoder and decoder.
- - ``emb_size``: The dimension of the latent space representation output by the encoder.
- - ``clf_hidden_1``: The number of hidden nodes in the first hidden layer of classifier.
- - ``clf_hidden_2``: The number of hidden nodes in the second hidden layer of classifier.
- - ``dropout``: The amount of dropout during training.
- - ``optim_params``: The parameters for initializing Adam optimizer
- - ``lr``: The learning rate
- - ``l2_reg``: The L2 regularization
- - ``hyperparameters``: Additional hyperparameters when training EDC-VAE model
- - ``ll_loss``: The weightage of VAE likelihood loss
- - ``kl_loss``: The weightage of KL divergence
- - ``dataset_path``: Path to the CSV prepared during dataset preparation stage. This path can be an absolute path, or a relative path from the directory of ``main.py``
- - ``dataset_name``: The name to be assigned to the dataset
- - ``seeds``: A list of seeds to iterate through
- - ``num_fold``: The number of folds for cross validation
- - ``ssl``: A boolean, indicating whether unlabeled data should be used to train the EDC-VAE model
- - ``harmonize``: A boolean, indicating whether ComBat harmonization should be performed prior to model training.
- - ``labeled_sites``: The site used as labeled data, when ``null``, all sites are used as labeled data.
- - ``device``: The GPU to be used, ``-1`` means to use CPU only
- - ``verbose``: A boolean, indicating Whether to display training progress messages
- - ``max_epoch``: The maximum number of epochs
- - ``patience``: The number of epochs without improvement for early stopping
- - ``save_model``: A boolean, indicating whether to save EDC-VAE's state_dict.
+1. Prepare dataset
-2. Run the main script
+ - Create a new folder under `./src` with the dataset name (see `./Schiz` for reference) and modify the setup and config files.
+ - Edit `__init__.py` to specify how site, age and gender are retrieved, along with any labelling standards, if applicable.
+ - Add the dataset to the `DataloaderBase` class (including `_get_indices()`) and the `Dataset` enum in `./src/data.py`.
- cd src
- python main.py --config
+2. Create `.yml` files in `config_templates` to define the model hyperparameters and training settings. More details about the YAML files can be found in the `miccai_2022` branch.
-3. Load a saved model and perform inference
+3. Train the model (and any other models specified in the `.yml` file) using `single_stage_framework.py`, e.g.:
- python evaluate.py + python single_stage_framework.py --config config_templates/individual/SHRED-III.yml diff --git a/dataset/ABIDE/meta.csv b/dataset/ABIDE/meta.csv deleted file mode 100644 index b88083e..0000000 --- a/dataset/ABIDE/meta.csv +++ /dev/null @@ -1,21 +0,0 @@ -SUBJECT,AGE,GENDER,DISEASE_LABEL,SITE,FC_MATRIX_PATH -00001,1,1,1,NYU,processed_corr_mat/NYU_0000001_power.npy -00002,2,1,0,NYU,processed_corr_mat/NYU_0000002_power.npy -00003,3,0,1,NYU,processed_corr_mat/NYU_0000003_power.npy -00004,4,0,0,NYU,processed_corr_mat/NYU_0000004_power.npy -00005,5,1,1,NYU,processed_corr_mat/NYU_0000005_power.npy -00006,6,0,1,NYU,processed_corr_mat/NYU_0000006_power.npy -00007,7,1,0,NYU,processed_corr_mat/NYU_0000007_power.npy -00008,8,0,1,NYU,processed_corr_mat/NYU_0000008_power.npy -00009,9,0,0,NYU,processed_corr_mat/NYU_0000009_power.npy -00010,10,1,0,NYU,processed_corr_mat/NYU_0000010_power.npy -00011,1,1,1,USM,processed_corr_mat/USM_0000011_power.npy -00012,2,0,1,USM,processed_corr_mat/USM_0000012_power.npy -00013,3,1,0,USM,processed_corr_mat/USM_0000013_power.npy -00014,4,0,0,USM,processed_corr_mat/USM_0000014_power.npy -00015,5,1,1,USM,processed_corr_mat/USM_0000015_power.npy -00016,6,1,0,USM,processed_corr_mat/USM_0000016_power.npy -00017,7,0,1,USM,processed_corr_mat/USM_0000017_power.npy -00018,8,0,0,USM,processed_corr_mat/USM_0000018_power.npy -00019,9,1,0,USM,processed_corr_mat/USM_0000019_power.npy -00020,10,0,1,USM,processed_corr_mat/USM_0000020_power.npy \ No newline at end of file diff --git a/dataset/ABIDE/processed_corr_mat/NYU_0000001_power.npy b/dataset/ABIDE/processed_corr_mat/NYU_0000001_power.npy deleted file mode 100644 index 230b3b6..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/NYU_0000001_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/NYU_0000002_power.npy b/dataset/ABIDE/processed_corr_mat/NYU_0000002_power.npy deleted file mode 100644 index 0e2acee..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/NYU_0000002_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/NYU_0000003_power.npy b/dataset/ABIDE/processed_corr_mat/NYU_0000003_power.npy deleted file mode 100644 index a41af23..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/NYU_0000003_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/NYU_0000004_power.npy b/dataset/ABIDE/processed_corr_mat/NYU_0000004_power.npy deleted file mode 100644 index 402bc58..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/NYU_0000004_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/NYU_0000005_power.npy b/dataset/ABIDE/processed_corr_mat/NYU_0000005_power.npy deleted file mode 100644 index 9007271..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/NYU_0000005_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/NYU_0000006_power.npy b/dataset/ABIDE/processed_corr_mat/NYU_0000006_power.npy deleted file mode 100644 index 3fc703a..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/NYU_0000006_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/NYU_0000007_power.npy b/dataset/ABIDE/processed_corr_mat/NYU_0000007_power.npy deleted file mode 100644 index c7d4f57..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/NYU_0000007_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/NYU_0000008_power.npy b/dataset/ABIDE/processed_corr_mat/NYU_0000008_power.npy deleted file mode 100644 index d0aedf9..0000000 Binary files 
a/dataset/ABIDE/processed_corr_mat/NYU_0000008_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/NYU_0000009_power.npy b/dataset/ABIDE/processed_corr_mat/NYU_0000009_power.npy deleted file mode 100644 index 12703d4..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/NYU_0000009_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/NYU_0000010_power.npy b/dataset/ABIDE/processed_corr_mat/NYU_0000010_power.npy deleted file mode 100644 index f498bf8..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/NYU_0000010_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/USM_0000011_power.npy b/dataset/ABIDE/processed_corr_mat/USM_0000011_power.npy deleted file mode 100644 index 7653ded..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/USM_0000011_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/USM_0000012_power.npy b/dataset/ABIDE/processed_corr_mat/USM_0000012_power.npy deleted file mode 100644 index c75986b..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/USM_0000012_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/USM_0000013_power.npy b/dataset/ABIDE/processed_corr_mat/USM_0000013_power.npy deleted file mode 100644 index 7cfef09..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/USM_0000013_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/USM_0000014_power.npy b/dataset/ABIDE/processed_corr_mat/USM_0000014_power.npy deleted file mode 100644 index e695f46..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/USM_0000014_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/USM_0000015_power.npy b/dataset/ABIDE/processed_corr_mat/USM_0000015_power.npy deleted file mode 100644 index a517e2e..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/USM_0000015_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/USM_0000016_power.npy b/dataset/ABIDE/processed_corr_mat/USM_0000016_power.npy deleted file mode 100644 index 93d670f..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/USM_0000016_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/USM_0000017_power.npy b/dataset/ABIDE/processed_corr_mat/USM_0000017_power.npy deleted file mode 100644 index 03cc00c..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/USM_0000017_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/USM_0000018_power.npy b/dataset/ABIDE/processed_corr_mat/USM_0000018_power.npy deleted file mode 100644 index 34f7318..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/USM_0000018_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/USM_0000019_power.npy b/dataset/ABIDE/processed_corr_mat/USM_0000019_power.npy deleted file mode 100644 index d16c135..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/USM_0000019_power.npy and /dev/null differ diff --git a/dataset/ABIDE/processed_corr_mat/USM_0000020_power.npy b/dataset/ABIDE/processed_corr_mat/USM_0000020_power.npy deleted file mode 100644 index 7ddaf30..0000000 Binary files a/dataset/ABIDE/processed_corr_mat/USM_0000020_power.npy and /dev/null differ diff --git a/dataset/README.md b/dataset/README.md deleted file mode 100644 index c602c03..0000000 --- a/dataset/README.md +++ /dev/null @@ -1 +0,0 @@ -Data provided in this folder is randomly generated. 
\ No newline at end of file diff --git a/figures/README.md b/figures/README.md deleted file mode 100644 index 2bdb64d..0000000 --- a/figures/README.md +++ /dev/null @@ -1 +0,0 @@ -For `hyperparameters__ch_loss_accuracy_boxplot.png`, 1.0 was used instead as the number of samples for 0.001 was too low (4). \ No newline at end of file diff --git a/figures/hyperparameters__ch_loss_accuracy_boxplot.png b/figures/hyperparameters__ch_loss_accuracy_boxplot.png deleted file mode 100644 index 532956e..0000000 Binary files a/figures/hyperparameters__ch_loss_accuracy_boxplot.png and /dev/null differ diff --git a/figures/hyperparameters__kl_loss_accuracy_boxplot.png b/figures/hyperparameters__kl_loss_accuracy_boxplot.png deleted file mode 100644 index 27392bc..0000000 Binary files a/figures/hyperparameters__kl_loss_accuracy_boxplot.png and /dev/null differ diff --git a/figures/hyperparameters__rc_loss_accuracy_boxplot.png b/figures/hyperparameters__rc_loss_accuracy_boxplot.png deleted file mode 100644 index 0722ad6..0000000 Binary files a/figures/hyperparameters__rc_loss_accuracy_boxplot.png and /dev/null differ diff --git a/figures/model_params__dropout_accuracy_boxplot.png b/figures/model_params__dropout_accuracy_boxplot.png deleted file mode 100644 index 07d769a..0000000 Binary files a/figures/model_params__dropout_accuracy_boxplot.png and /dev/null differ diff --git a/figures/model_params__emb_size_accuracy_boxplot.png b/figures/model_params__emb_size_accuracy_boxplot.png deleted file mode 100644 index 071647f..0000000 Binary files a/figures/model_params__emb_size_accuracy_boxplot.png and /dev/null differ diff --git a/figures/model_params__hidden_size_accuracy_boxplot.png b/figures/model_params__hidden_size_accuracy_boxplot.png deleted file mode 100644 index 0b2af7a..0000000 Binary files a/figures/model_params__hidden_size_accuracy_boxplot.png and /dev/null differ diff --git a/figures/optim_params__lr_accuracy_boxplot.png b/figures/optim_params__lr_accuracy_boxplot.png deleted file mode 100644 index f76d4ba..0000000 Binary files a/figures/optim_params__lr_accuracy_boxplot.png and /dev/null differ diff --git a/saved_model/ABIDE_VAE-FFN_0_0_1645419832.pt b/saved_model/ABIDE_VAE-FFN_0_0_1645419832.pt deleted file mode 100644 index df49829..0000000 Binary files a/saved_model/ABIDE_VAE-FFN_0_0_1645419832.pt and /dev/null differ diff --git a/src/Schiz/README.md b/src/Schiz/README.md new file mode 100644 index 0000000..1ec3b5d --- /dev/null +++ b/src/Schiz/README.md @@ -0,0 +1,15 @@ +# Schiz + +Data collected from SchizConnect (NMorphCH, COBRE), UCLA and a private dataset. + +## Setup Dataset Guide + +1. Make sure that the path in ``setup.py`` is correct + + main_dir = + corr_mat_dir = + phenotypics_path = + +2. 
Run ``setup.py`` + + python setup.py \ No newline at end of file diff --git a/src/Schiz/__init__.py b/src/Schiz/__init__.py new file mode 100644 index 0000000..0161cfb --- /dev/null +++ b/src/Schiz/__init__.py @@ -0,0 +1,66 @@ +import os +import os +import sys +import numpy as np +import pandas as pd + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from schiz_config import * + + +def load_data_fmri(harmonized=False, time_series=False): + """ + Inputs + site_id: str or None + - if specified, splits will only contains index of subjects from the specified site + harmonized: bool + - whether or not to return combat harmonized X + + Returns + X: np.ndarray with N subject samples, each sample consists of a ROI x ROI correlation matrix + Y: np.ndarray with N subject samples, each sample has a one-hot encoded label (0: Normal, 1: Diseased) + """ + if harmonized: + X = np.load(HARMONIZED_X_PATH) + else: + X = np.load(X_PATH) + Y = np.load(Y_PATH) + if not time_series: + return X, Y + raise NotImplementedError + + +def get_splits(site_id=None, test=False): + if site_id is None: + path = SPLIT_TEST_PATH if test else SPLIT_CV_PATH + else: + path = "{}_test.npy".format(site_id) if test else "{}_cv.npy".format(site_id) + path = os.path.join(SSL_SPLITS_DIR, path) + splits = np.load(path, allow_pickle=True) + return splits + + +def get_ages_and_genders(): + """ + ages: np.array of float representing the age of subject when the scan is obtained + gender: np.array of int representing the subject's gender + - 0: Female + - 1: Male + """ + meta_df = pd.read_csv(META_CSV_PATH) + ages = np.array(meta_df["age"]) + genders = np.array(meta_df["sex"]) + return ages, genders + + +def get_sites(): + meta_df = pd.read_csv(META_CSV_PATH) + sites = np.array(meta_df["study"]) + return sites + + +def load_meta_df(): + df = pd.read_csv(META_CSV_PATH) + df["FILE_PATH"] = df["FILE_PATH"].apply(eval) + return df + diff --git a/src/Schiz/schiz_config.py b/src/Schiz/schiz_config.py new file mode 100644 index 0000000..87a5e6a --- /dev/null +++ b/src/Schiz/schiz_config.py @@ -0,0 +1,15 @@ +import os.path as osp + + +__dir__ = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))) + + +MAIN_DIR = osp.abspath(osp.join(__dir__, "data", "Schiz")) +META_CSV_PATH = osp.join(MAIN_DIR, "meta.csv") +X_PATH = osp.join(MAIN_DIR, "X.npy") +X_TS_PATH = osp.join(MAIN_DIR, "X_ts.npy") +Y_PATH = osp.join(MAIN_DIR, "Y.npy") +SPLIT_TEST_PATH = osp.join(MAIN_DIR, "split_test.npy") +SPLIT_CV_PATH = osp.join(MAIN_DIR, "split_cv.npy") +SSL_SPLITS_DIR = osp.join(MAIN_DIR, "ssl_splits") +HARMONIZED_X_PATH = osp.join(MAIN_DIR, "harmonized_X.npy") \ No newline at end of file diff --git a/src/Schiz/setup.py b/src/Schiz/setup.py new file mode 100644 index 0000000..f6031a7 --- /dev/null +++ b/src/Schiz/setup.py @@ -0,0 +1,220 @@ +import os +import time +import warnings +import numpy as np +import pandas as pd +from schiz_config import * +from contextlib import contextmanager +from neuroCombat import neuroCombat +from scipy.spatial.distance import squareform +from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit + + +@contextmanager +def log_time(description): + print("[{}] started".format(description)) + start = time.time() + yield + end = time.time() + print("[{}] completed in {:.3f} s".format(description, end - start)) + + +def get_processed_corr_mat_file_ids(corr_mat_dir): + file_ids = [] + for _, _, files in os.walk(corr_mat_dir): + for filename in files: + if filename.endswith(".npy"): + fname_no_ext = 
filename.split('.')[0] + fname_no_atlas = '_'.join(fname_no_ext.split('_')[:-1]) + file_ids.append(fname_no_atlas) + return file_ids + + +def extract_data(corr_mat_dir, meta_csv_path): + + meta_df = pd.read_csv(meta_csv_path) + file_ids = get_processed_corr_mat_file_ids(corr_mat_dir) + + meta_df["PROCESSED"] = meta_df["id"].apply(lambda x: x in file_ids) + processed_df = meta_df[meta_df["PROCESSED"]].sort_values("id") + processed_df = processed_df.drop("PROCESSED", axis=1) + + processed_df["FILE_PATH"] = processed_df["id"].apply( + lambda x: os.path.join(corr_mat_dir, "{}_power.npy".format(x)) + ) + processed_df["TIME_SERIES_PATH"] = processed_df["id"].apply( + lambda x: os.path.join(corr_mat_dir, "../processed_ts/{}_power_TS.npy".format(x)) + ) + processed_df.to_csv(META_CSV_PATH, header=True, index=False) + + X = np.array([np.load(fname) for fname in processed_df["FILE_PATH"]]) + X = np.nan_to_num(X) + X_ts = np.empty(X.shape[0], dtype=object) + for i, fname in enumerate(processed_df["TIME_SERIES_PATH"]): + X_ts[i] = np.nan_to_num(np.load(fname)) + Y = np.array(processed_df["dx"] == 'SZ') + Y_onehot = np.eye(2)[Y.astype(int)] + np.save(X_PATH, X) # (N, ROI, ROI) + np.save(X_TS_PATH, X_ts) + np.save(Y_PATH, Y_onehot) # (N, 2) + return processed_df, X, Y + + + +def split_traintest_sbj(Y, test_split_frac, seed): + X = np.zeros(Y.shape[0]) + np.random.seed(seed) + sss = StratifiedShuffleSplit( + n_splits=1, test_size=test_split_frac, random_state=seed + ) + train_index, test_index = next(sss.split(X, Y)) + return train_index, test_index + + +def split_kfoldcv_sbj(Y, n, seed): + X = np.zeros(Y.shape[0]) + np.random.seed(seed) + skf_group = StratifiedKFold( + n_splits=n, shuffle=True, random_state=seed + ) + result = [] + for train_index, test_index in skf_group.split(X, Y): + result.append((train_index, test_index)) + return result + + +def generate_splits(Y, test_split_frac=0.2, kfold_n_splits=5, test=True): + """ + splits: np.ndarray with dimension 100 x 5 x 2 + - test indices of seed n = splits[n][0] + - the train and val indices of seed n, fold k = splits[n][1][k][0] and splits[n][1][k][1] + """ + splits = [] + for seed in range(100): + if test: + tuning_idx, test_idx = split_traintest_sbj(Y, test_split_frac, seed) + else: + tuning_idx, test_idx = np.arange(Y.shape[0]), np.array([]) + Y_tuning = Y[tuning_idx] + folds = split_kfoldcv_sbj(Y_tuning, kfold_n_splits, seed) + train_val_idx = [] + for tuning_train_idx, tuning_val_idx in folds: + train_idx = tuning_idx[tuning_train_idx] + val_idx = tuning_idx[tuning_val_idx] + assert len(set(train_idx) & set(val_idx)) == 0 + assert len(set(train_idx) & set(test_idx)) == 0 + assert len(set(val_idx) & set(test_idx)) == 0 + train_val_idx.append(np.array([train_idx, val_idx], dtype=object)) + train_val_idx = np.array(train_val_idx) + split = np.empty(2, dtype=object) + split[0] = test_idx + split[1] = train_val_idx + splits.append(split) + splits = np.array(splits, dtype=object) + return splits + + +def generate_ssl_splits(Y, real_idx, test_split_frac=0.2, kfold_n_splits=5, test=True): + """ + splits: np.ndarray with dimension 100 x 5 x 2 + - test indices of seed n = splits[n][0] + - the train and val indices of seed n, fold k = splits[n][1][k][0] and splits[n][1][k][1] + """ + warnings.filterwarnings("ignore") + splits = [] + for seed in range(100): + if test: + tuning_idx, test_idx = split_traintest_sbj(Y[real_idx], test_split_frac, seed) + tuning_idx = real_idx[tuning_idx] + test_idx = real_idx[test_idx] + else: + tuning_idx = real_idx + test_idx 
= np.array([]) + Y_tuning = Y[tuning_idx] + folds = split_kfoldcv_sbj(Y_tuning, kfold_n_splits, seed) + train_val_idx = [] + for tuning_train_idx, tuning_val_idx in folds: + train_idx = tuning_idx[tuning_train_idx] + val_idx = tuning_idx[tuning_val_idx] + assert len(set(train_idx) & set(val_idx)) == 0 + assert len(set(train_idx) & set(test_idx)) == 0 + assert len(set(val_idx) & set(test_idx)) == 0 + train_val_idx.append(np.array([train_idx, val_idx], dtype=object)) + train_val_idx = np.array(train_val_idx) + split = np.empty(2, dtype=object) + split[0] = test_idx + split[1] = train_val_idx + splits.append(split) + splits = np.array(splits, dtype=object) + return splits + + +def split(Y): + splits = generate_splits(Y, test=True) + np.save(SPLIT_TEST_PATH, splits) # (100, 5, 2) + splits = generate_splits(Y, test=False) + np.save(SPLIT_CV_PATH, splits) # (100, 5, 2) + + +def ssl_split(meta_df, Y): + for site_id in np.unique(meta_df['study']): + with log_time("generate ssl split seed for site {}".format(site_id)) as lt: + try: + site_idx = np.argwhere(meta_df['study'].values == site_id).flatten() + splits = generate_ssl_splits(Y, site_idx, test=True) + np.save(os.path.join(SSL_SPLITS_DIR, "{}_test.npy".format(site_id)), splits) # (100, 5, 2) + splits = generate_ssl_splits(Y, site_idx, test=False) + np.save(os.path.join(SSL_SPLITS_DIR, "{}_cv.npy".format(site_id)), splits) # (100, 5, 2) + except Exception as e: + print("study: {}, ERROR: {}".format(site_id, e)) + # if site_id != "CMU": + raise e + # print("study: {}, ERROR: {}".format(site_id, e)) + + +def corr_mx_flatten(X): + """ + returns upper triangluar matrix of each sample in X + X.shape == (num_sample, num_feature, num_feature) + X_flattened.shape == (num_sample, num_feature * (num_feature - 1) / 2) + """ + upper_triangular_idx = np.triu_indices(X.shape[1], 1) + X_flattened = X[:, upper_triangular_idx[0], upper_triangular_idx[1]] + return X_flattened + + +def combat_harmonization(X, meta_df): + X = corr_mx_flatten(X) + covars = meta_df[["study", "age", "sex"]] + categorical_cols = ["sex"] + continuous_cols = ["age"] + batch_col = "study" + combat = neuroCombat( + dat=X.T, covars=covars, batch_col=batch_col, + categorical_cols=categorical_cols, + continuous_cols=continuous_cols, + ) + harmonized_X = combat["data"].T + harmonized_X = np.clip(harmonized_X, -1, 1) + harmonized_X = np.array([squareform(x) for x in harmonized_X]) + np.save(HARMONIZED_X_PATH, harmonized_X) + +if __name__ == "__main__": + + main_dir = '' # + corr_mat_dir = os.path.join(main_dir, "fmri", "processed_corr_mat") + meta_csv_path = os.path.join(main_dir, "meta", "meta.csv") + + if not os.path.exists(SSL_SPLITS_DIR): + os.makedirs(SSL_SPLITS_DIR) + + with log_time("extract metadata and correlation matrices") as lt: + meta_df, X, Y = extract_data(corr_mat_dir, meta_csv_path) + + with log_time("generate split seed for whole dataset") as lt: + split(Y) + + ssl_split(meta_df, Y) + + with log_time("neuroCombat") as lt: + combat_harmonization(X, meta_df) \ No newline at end of file diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..71859dc --- /dev/null +++ b/src/config.py @@ -0,0 +1,196 @@ +from __future__ import annotations +from collections import defaultdict +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, Optional, Sequence, Union + +import os.path as osp +from itertools import product + +__dir__ = osp.dirname(osp.dirname(osp.abspath(__file__))) +EXPERIMENT_DIR = osp.join(__dir__, "experiments") 
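+# A hedged usage sketch of the classes defined below; this mirrors what
+# `single_stage_framework.py` is assumed to do with a YAML template, and the
+# template path is illustrative only:
+#
+#     import yaml
+#     with open("config_templates/individual/SHRED-III.yml") as f:
+#         raw = yaml.full_load(f)
+#     parser = FrameworkConfigParser.parse(**raw)
+#     for run_config in parser.generate():
+#         ...  # one flat dict per model x data x experiment-setting combination,
+#              # with the seed/fold ranges and process settings merged in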
+ + +@dataclass(frozen=True) +class RangeGenerator: + min: int + max: int + + def generate(self): + return list(range(self.min, self.max)) + + @staticmethod + def parse(range_cfg: Dict[str, Any]) -> RangeGenerator: + return RangeGenerator(**range_cfg) + + +@dataclass(frozen=True) +class ModelConfig: + all_models: Sequence[single_model] + + @dataclass(frozen=True) + class single_model: + model_name: str + model_params: Dict[str, Any] + optim_params: Dict[str, Any] = field(default_factory=dict) + hyperparameters: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return dict( + model_name=self.model_name, + model_params=self.model_params, + optim_params=self.optim_params, + hyperparameters=self.hyperparameters, + ) + + def generate(self): + return [cfg.to_dict() for cfg in self.all_models] + + @staticmethod + def parse(model_configs: Sequence[Dict[str, Any]]) -> ModelConfig: + all_models = [ModelConfig.single_model(**cfg) for cfg in model_configs] + return ModelConfig(all_models) + + +@dataclass(frozen=True) +class DataConfig: + all_data: Sequence[single_data] + + @dataclass(frozen=True) + class single_data: + dataset: str + labeled_sites: Sequence[Optional[Union[str, Sequence[str]]]] + unlabeled_sites: Sequence[Optional[Union[str, Sequence[str]]]] = field( + default=(None,) + ) + num_unlabeled: Sequence[Optional[Union[str, Sequence[str]]]] = field( + default=(None,) + ) + output_directory: Optional[str] = field(default=None) + + def generate(self): + return [ + dict( + dataset=cfg.dataset, + labeled_sites=labeled_sites, + unlabeled_sites=unlabeled_sites, + num_unlabeled=num_unlabeled, + output_directory=cfg.output_directory, + ) + for cfg in self.all_data + for labeled_sites in cfg.labeled_sites + for unlabeled_sites in cfg.unlabeled_sites + for num_unlabeled in cfg.num_unlabeled + ] + + @staticmethod + def parse(data_configs: Sequence[Dict[str, Any]]) -> DataConfig: + return DataConfig( + [DataConfig.single_data(**cfg) for cfg in data_configs] + ) + + +@dataclass(frozen=True) +class ExperimentSettings: + all_settings: Sequence[single_setting] + + @dataclass(frozen=True) + class single_setting: + ssl: bool = field(default=False) + harmonize: bool = field(default=False) + validation: bool = field(default=False) + + def to_dict(self) -> Dict[str, bool]: + return dict( + ssl=self.ssl, + harmonize=self.harmonize, + validation=self.validation, + ) + + def generate(self): + return [cfg.to_dict() for cfg in self.all_settings] + + @staticmethod + def parse(exp_settings: Sequence[Dict[str, bool]]) -> ExperimentSettings: + return ExperimentSettings( + [ExperimentSettings.single_setting(**cfg) for cfg in exp_settings] + ) + + +@dataclass(frozen=True) +class ProcessConfig: + device: int = field(default=-1) + verbose: bool = field(default=0) + max_epoch: int = field(default=1000) + patience: int = field(default=1000) + dataloader_num_process: int = field(default=1) + save_model_condition: Sequence[Dict[str, Any]] = field(default_factory=list) + + def match_save_model_condition(self, config: Dict[str, Any]): + if not self.save_model_condition: + return True + for condition in self.save_model_condition: + matched = True + for key, value in condition.items(): + if key not in config: + matched = False + elif value != config[key]: + matched = False + if not matched: + break + if matched: + return True + return False + + def update(self, config: Dict[str, Any]): + config["device"] = self.device + config["verbose"] = self.verbose + config["max_epoch"] = self.max_epoch + 
config["patience"] = self.patience + config["dataloader_num_process"] = self.dataloader_num_process + config["save_model"] = self.match_save_model_condition(config) + return config + + +@dataclass(frozen=True) +class FrameworkConfigParser: + seed: RangeGenerator + fold: RangeGenerator + model: ModelConfig + data: DataConfig + experiment_settings: ExperimentSettings + process: ProcessConfig + + def generate(self): + for model, data, exp_setting in product( + self.model.generate(), + self.data.generate(), + self.experiment_settings.generate(), + ): + config = { + "seed": self.seed.generate(), + "fold": self.fold.generate(), + **model, + **data, + **exp_setting, + } + config = self.process.update(config) + yield config + + @staticmethod + def parse( + seed: Dict[str, int], + fold: Dict[str, Any], + model: Sequence[Dict[str, Any]], + data: Sequence[Dict[str, Any]], + experiment_settings: Sequence[Dict[str, bool]], + process: Dict[str, Any], + ) -> FrameworkConfigParser: + return FrameworkConfigParser( + seed=RangeGenerator.parse(seed), + fold=RangeGenerator.parse(fold), + model=ModelConfig.parse(model), + data=DataConfig.parse(data), + experiment_settings=ExperimentSettings.parse(experiment_settings), + process=ProcessConfig(**process), + ) diff --git a/src/config.yml b/src/config.yml deleted file mode 100644 index deaffa9..0000000 --- a/src/config.yml +++ /dev/null @@ -1,36 +0,0 @@ -output_directory: ../results -model_name: EDC_VAE -model_params: - hidden_size: 32 - emb_size: 32 - clf_hidden_1: 0 - clf_hidden_2: 0 - dropout: 0.2 -optim_params: - lr: 0.002 - l2_reg: 0.001 -hyperparameters: - ll_loss: 0.0001 - kl_loss: 0.001 -dataset_path: ../dataset/ABIDE/meta.csv -dataset_name: ABIDE -seeds: - - 0 - - 1 - - 2 - - 3 - - 4 - - 5 - - 6 - - 7 - - 8 - - 9 -num_fold: 5 -ssl: true -harmonize: true -labeled_sites: NYU -device: 2 -verbose: true -max_epoch: 1000 -patience: 1000 -save_model: true \ No newline at end of file diff --git a/src/config_templates/individual/FFN.yml b/src/config_templates/individual/FFN.yml new file mode 100644 index 0000000..cc3a6b0 --- /dev/null +++ b/src/config_templates/individual/FFN.yml @@ -0,0 +1,43 @@ +seed: + min: 0 + max: 10 +fold: + min: 0 + max: 5 +model: + - + model_name: FFN + model_params: + hidden_1: 75 + hidden_2: 50 + hidden_3: 30 + dropout: 0.5 + optim_params: + lr: 0.0001 + l2_reg: 0.001 +data: + - + output_directory: ../.archive/Schiz_INDIVIDUAL + dataset: Schiz + labeled_sites: + - h + - nmorph + - ucla + - cobre +experiment_settings: + - + ssl: false + harmonize: false + validation: false + - + ssl: false + harmonize: true + validation: false +process: + device: 3 + verbose: false + max_epoch: 1000 + patience: 1000 + save_model_condition: + - + dataset: Schiz \ No newline at end of file diff --git a/src/config_templates/individual/SHRED-I.yml b/src/config_templates/individual/SHRED-I.yml new file mode 100644 index 0000000..1e86d42 --- /dev/null +++ b/src/config_templates/individual/SHRED-I.yml @@ -0,0 +1,45 @@ +seed: + min: 10 + max: 11 +fold: + min: 0 + max: 5 +model: + - + model_name: SHRED-I + model_params: + hidden_size: 64 + emb_size: 16 + clf_hidden_1: 0 + clf_hidden_2: 0 + dropout: 0.1 + hyperparameters: + rc_loss: 0.00001 + kl_loss: 0.001 + ch_loss: 0.01 + alpha_loss: true + optim_params: + lr: 0.0005 + l2_reg: 0.001 +data: + - + output_directory: ../.archive/Schiz_INDIVIDUAL + dataset: Schiz + labeled_sites: + - h + - nmorph + - ucla + - cobre +experiment_settings: + - + ssl: true + harmonize: false + validation: false +process: + device: 0 + 
verbose: false + max_epoch: 1000 + patience: 1000 + save_model_condition: + - + dataset: Schiz diff --git a/src/config_templates/individual/SHRED-III.yml b/src/config_templates/individual/SHRED-III.yml new file mode 100644 index 0000000..8c3c683 --- /dev/null +++ b/src/config_templates/individual/SHRED-III.yml @@ -0,0 +1,45 @@ +seed: + min: 0 + max: 10 +fold: + min: 0 + max: 5 +model: + - + model_name: SHRED-III + model_params: + hidden_size: 32 + emb_size: 32 + clf_hidden_1: 0 + clf_hidden_2: 0 + dropout: 0.2 + hyperparameters: + rc_loss: 0.00001 + kl_loss: 0.001 + ch_loss: 0.0001 + alpha_loss: true + optim_params: + lr: 0.002 + l2_reg: 0.001 +data: + - + output_directory: ../.archive/Schiz_INDIVIDUAL + dataset: Schiz + labeled_sites: + - h + - nmorph + - ucla + - cobre +experiment_settings: + - + ssl: true + harmonize: true + validation: false +process: + device: 1 + verbose: false + max_epoch: 1000 + patience: 1000 + save_model_condition: + - + dataset: Schiz diff --git a/src/config_templates/individual/VAE-FFN.yml b/src/config_templates/individual/VAE-FFN.yml new file mode 100644 index 0000000..05bf051 --- /dev/null +++ b/src/config_templates/individual/VAE-FFN.yml @@ -0,0 +1,55 @@ +seed: + min: 0 + max: 10 +fold: + min: 0 + max: 5 +model: + - + model_name: VAE-FFN + model_params: + hidden_size: 32 + emb_size: 16 + clf_hidden_1: 0 + clf_hidden_2: 0 + dropout: 0.1 + hyperparameters: + rc_loss: 0.0001 + kl_loss: 0.001 + optim_params: + lr: 0.001 + l2_reg: 0.001 +data: + - + output_directory: ../.archive/Schiz_INDIVIDUAL + dataset: Schiz + labeled_sites: + - h + - nmorph + - ucla + - cobre +experiment_settings: + - + ssl: true + harmonize: true + validation: false + - + ssl: false + harmonize: false + validation: false + - + ssl: false + harmonize: true + validation: false + - + ssl: true + harmonize: false + validation: false +process: + device: 1 + verbose: false + max_epoch: 1000 + patience: 1000 + save_model_condition: + - + dataset: Schiz \ No newline at end of file diff --git a/src/config_templates/whole/FFN.yml b/src/config_templates/whole/FFN.yml new file mode 100644 index 0000000..4527107 --- /dev/null +++ b/src/config_templates/whole/FFN.yml @@ -0,0 +1,44 @@ +seed: + min: 0 + max: 10 +fold: + min: 0 + max: 5 +model: + - + model_name: FFN + model_params: + hidden_1: 75 + hidden_2: 50 + hidden_3: 30 + dropout: 0.5 + optim_params: + lr: 0.0001 + l2_reg: 0.001 +data: + - + output_directory: ../.archive/Schiz_WHOLE + dataset: Schiz + labeled_sites: + - + - h + - nmorph + - ucla + - cobre +experiment_settings: + - + ssl: false + harmonize: false + validation: false + - + ssl: false + harmonize: true + validation: false +process: + device: 3 + verbose: false + max_epoch: 1000 + patience: 1000 + save_model_condition: + - + dataset: Schiz \ No newline at end of file diff --git a/src/config_templates/whole/SHRED-I.yml b/src/config_templates/whole/SHRED-I.yml new file mode 100644 index 0000000..e1ad949 --- /dev/null +++ b/src/config_templates/whole/SHRED-I.yml @@ -0,0 +1,46 @@ +seed: + min: 0 + max: 10 +fold: + min: 0 + max: 5 +model: + - + model_name: SHRED-I + model_params: + hidden_size: 32 + emb_size: 32 + clf_hidden_1: 0 + clf_hidden_2: 0 + dropout: 0.2 + hyperparameters: + rc_loss: 0.0001 + kl_loss: 0.001 + ch_loss: 1.0 + alpha_loss: true + optim_params: + lr: 0.002 + l2_reg: 0.001 +data: + - + output_directory: ../.archive/Schiz_WHOLE + dataset: Schiz + labeled_sites: + - + - h + - nmorph + - ucla + - cobre +experiment_settings: + - + ssl: false + harmonize: false + validation: 
false +process: + device: 3 + verbose: false + max_epoch: 1000 + patience: 1000 + save_model_condition: + - + dataset: Schiz \ No newline at end of file diff --git a/src/config_templates/whole/SHRED-III.yml b/src/config_templates/whole/SHRED-III.yml new file mode 100644 index 0000000..6c52ad5 --- /dev/null +++ b/src/config_templates/whole/SHRED-III.yml @@ -0,0 +1,46 @@ +seed: + min: 0 + max: 10 +fold: + min: 0 + max: 5 +model: + - + model_name: SHRED-III + model_params: + hidden_size: 32 + emb_size: 32 + clf_hidden_1: 0 + clf_hidden_2: 0 + dropout: 0.2 + hyperparameters: + rc_loss: 0.00001 + kl_loss: 0.001 + ch_loss: 0.0001 + alpha_loss: true + optim_params: + lr: 0.002 + l2_reg: 0.001 +data: + - + output_directory: ../.archive/Schiz_WHOLE + dataset: Schiz + labeled_sites: + - + - h + - nmorph + - ucla + - cobre +experiment_settings: + - + ssl: false + harmonize: false + validation: false +process: + device: 1 + verbose: false + max_epoch: 1000 + patience: 1000 + save_model_condition: + - + dataset: Schiz diff --git a/src/config_templates/whole/VAE-FFN.yml b/src/config_templates/whole/VAE-FFN.yml new file mode 100644 index 0000000..718baf2 --- /dev/null +++ b/src/config_templates/whole/VAE-FFN.yml @@ -0,0 +1,48 @@ +seed: + min: 0 + max: 10 +fold: + min: 0 + max: 5 +model: + - + model_name: VAE-FFN + model_params: + hidden_size: 32 + emb_size: 16 + clf_hidden_1: 0 + clf_hidden_2: 0 + dropout: 0.1 + hyperparameters: + rc_loss: 0.0001 + kl_loss: 0.001 + optim_params: + lr: 0.001 + l2_reg: 0.001 +data: + - + output_directory: ../.archive/Schiz_WHOLE + dataset: Schiz + labeled_sites: + - + - h + - nmorph + - ucla + - cobre +experiment_settings: + - + ssl: false + harmonize: false + validation: false + - + ssl: false + harmonize: true + validation: false +process: + device: 3 + verbose: false + max_epoch: 1000 + patience: 1000 + save_model_condition: + - + dataset: Schiz diff --git a/src/data.py b/src/data.py index 2731a31..2311d36 100644 --- a/src/data.py +++ b/src/data.py @@ -1,203 +1,220 @@ -import os -import pandas as pd +from enum import Enum +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Dict, Optional, Sequence, Tuple, Union + import numpy as np -from neuroCombat import neuroCombat -from sklearn.model_selection import StratifiedKFold from sklearn.preprocessing import LabelEncoder -from typing import Dict, Generator, Optional, Sequence, Union import torch from torch_geometric.data import Data +from utils.data import corr_mx_flatten -def corr_mx_flatten(X: np.ndarray) -> np.ndarray: - """ - returns upper triangluar matrix of each sample in X - - option 1: - X.shape == (num_sample, num_feature, num_feature) - X_flattened.shape == (num_sample, num_feature * (num_feature - 1) / 2) - - option 2: - X.shape == (num_feature, num_feature) - X_flattend.shape == (num_feature * (num_feature - 1) / 2,) - """ - upper_triangular_idx = np.triu_indices(X.shape[1], 1) - if len(X.shape) == 3: - X = X[:, upper_triangular_idx[0], upper_triangular_idx[1]] - else: - X = X[upper_triangular_idx[0], upper_triangular_idx[1]] - return X - - -def combat_harmonization(X: np.ndarray, meta_df: pd.DataFrame) -> np.ndarray: - covars = meta_df[["SITE", "AGE", "GENDER"]] - categorical_cols = ["GENDER"] - continuous_cols = ["AGE"] - batch_col = "SITE" - combat = neuroCombat( - dat=X.T, - covars=covars, - batch_col=batch_col, - categorical_cols=categorical_cols, - continuous_cols=continuous_cols, - ) - harmonized_X = combat["data"].T - harmonized_X = np.clip(harmonized_X, -1, 1) - 
return harmonized_X - - -def split_kfoldcv_sbj( - y: np.ndarray, subjects: np.ndarray, num_fold: int, seed: int -): - unique_subjects, first_subject_index = np.unique( - subjects, return_index=True - ) - subject_y = y[first_subject_index] - subject_X = np.zeros_like(subject_y) - skfold = StratifiedKFold(n_splits=num_fold, random_state=seed, shuffle=True) - - result = [] - for train_subject_index, test_subject_index in skfold.split( - subject_X, subject_y - ): - train_subjects = unique_subjects[train_subject_index] - test_subjects = unique_subjects[test_subject_index] - train_index = np.argwhere(np.isin(subjects, train_subjects)).flatten() - test_index = np.argwhere(np.isin(subjects, test_subjects)).flatten() - assert len(np.intersect1d(train_index, test_index)) == 0 - assert ( - len(np.intersect1d(subjects[train_index], subjects[test_index])) - == 0 - ) - result.append((train_index, test_index)) - return result - - -def make_dataset( - x: np.ndarray, - y: np.ndarray, - d: np.ndarray, - age: np.ndarray, - gender: np.ndarray, -) -> Data: - graph = Data() - graph.x = torch.tensor(x).float() - graph.y = torch.tensor(y) - graph.d = torch.tensor(d) - graph.age = torch.tensor(age).float() - graph.gender = torch.tensor(gender).float() - return graph - - -class Dataset: - """ - required columns in meta.csv - - SUBJECT - - AGE - - GENDER - - SITE - - DISEASE_LABEL - - FC_MATRIX_PATH - """ - - def __init__(self, data_csv_path: str, name: str, harmonize: bool = False): - self.data_csv_path = os.path.abspath(data_csv_path) - self.data_folder = os.path.dirname(self.data_csv_path) - self.name = name + +class Dataset(Enum): + Schiz = "Schiz" + + +class DataloaderBase(ABC): + def __init__(self, dataset: Dataset, harmonize: bool = False): + self.dataset = dataset self.harmonize = harmonize - self._init_properties_() - - def _init_properties_(self): - meta_df = pd.read_csv(self.data_csv_path) - X = np.array( - [ - np.load(os.path.join(self.data_folder, path)) - for path in meta_df["FC_MATRIX_PATH"] - ] - ) - X = corr_mx_flatten(np.nan_to_num(X)) - if self.harmonize: - self.X = combat_harmonization(X, meta_df) + self._init_dataset_() + + def _init_dataset_(self) -> Data: + if self.dataset == Dataset.Schiz: + from Schiz import load_data_fmri, get_ages_and_genders, get_sites else: - self.X = X + raise NotImplementedError + + data: Tuple[np.ndarray] = load_data_fmri(harmonized=self.harmonize) + self.X: np.ndarray = data[0] + self.Y: np.ndarray = data[1].argmax(axis=1) + self.X_flattened: np.ndarray = corr_mx_flatten(self.X) - self.subjects = meta_df["SUBJECT"].values - self.ages = meta_df["AGE"].values - self.genders = meta_df["GENDER"].values - self.sites = meta_df["SITE"].values - self.y = meta_df["DISEASE_LABEL"].values + age_gender: Tuple[np.ndarray, np.ndarray] = get_ages_and_genders() + age, gender = age_gender + + mean_age = np.nanmean(age) + age = np.where(np.isnan(age), mean_age, age) + age = np.expand_dims(age, axis=1) + + assert np.all(np.isnan(gender) | (gender >= 0) | (gender <= 1)) + gender = np.where(np.isnan(gender), np.nanmean(gender), gender) + gender = np.expand_dims(gender, axis=1) + + self.age: np.ndarray = age + self.gender: np.ndarray = gender + self.sites: np.ndarray = get_sites() def _get_indices( self, seed: int = 0, - num_fold: int = 5, + fold: int = 0, ssl: bool = False, + validation: bool = False, labeled_sites: Optional[Union[str, Sequence[str]]] = None, - ) -> Generator[Dict[str, np.ndarray], None, None]: - - if labeled_sites is None: - is_labeled = np.ones(len(self.sites), 
dtype=bool) + unlabeled_sites: Optional[Union[str, Sequence[str]]] = None, + num_unlabeled: Optional[int] = None, + ) -> Dict[str, np.ndarray]: + if self.dataset == Dataset.Schiz: + from Schiz import get_splits else: - if isinstance(labeled_sites, str): - labeled_sites = [labeled_sites] - is_labeled = np.isin(self.sites, labeled_sites) - - unlabeled_indices = np.argwhere(~is_labeled).flatten() - labeled_indices = np.argwhere(is_labeled).flatten() - for train, test in split_kfoldcv_sbj( - self.y[is_labeled], self.subjects[is_labeled], num_fold, seed - ): - result = { - "labeled_train": labeled_indices[train], - "test": labeled_indices[test], - } - if ssl: - result["unlabeled_train"] = unlabeled_indices - yield result + raise NotImplementedError + + indices_list = defaultdict(list) + if labeled_sites is None or isinstance(labeled_sites, str): + labeled_sites = [labeled_sites] + for site_id in labeled_sites: + splits = get_splits(site_id, test=validation) + if validation: + test_indices = splits[seed][0] + labeled_train_indices, val_indices = splits[seed][1][fold] + indices_list["labeled_train"].append(labeled_train_indices) + indices_list["valid"].append(val_indices) + indices_list["test"].append(test_indices) + else: + labeled_train_indices, test_indices = splits[seed][1][fold] + indices_list["labeled_train"].append(labeled_train_indices) + indices_list["test"].append(test_indices) + + indices = dict() + for k, v in indices_list.items(): + if len(v) == 1: + indices[k] = v[0] + else: + indices[k] = np.concatenate(v, axis=0) + + if ssl: + if isinstance(unlabeled_sites, str): + unlabeled_sites = [unlabeled_sites] + unlabeled_indices = np.arange(len(self.X)) + if unlabeled_sites is not None: + unlabeled_indices = unlabeled_indices[ + np.isin(self.sites, unlabeled_sites) + ] + for idx in indices.values(): + unlabeled_indices = np.setdiff1d(unlabeled_indices, idx) + if ( + num_unlabeled is not None + and len(unlabeled_indices) > num_unlabeled + ): + unlabeled_indices = np.random.choice( + unlabeled_indices, num_unlabeled + ) + indices["unlabeled_train"] = unlabeled_indices + + keys = list(indices.keys()) + for i in range(len(keys)): + for j in range(i + 1, len(keys)): + assert ( + np.intersect1d(indices[keys[i]], indices[keys[j]]).size == 0 + ) + return indices + @abstractmethod def load_split_data( self, seed: int = 0, - num_fold: int = 5, + fold: int = 0, ssl: bool = False, + validation: bool = False, labeled_sites: Optional[Union[str, Sequence[str]]] = None, - ) -> Generator[Dict[str, Union[int, Data]], None, None]: - for indices in self._get_indices(seed, num_fold, ssl, labeled_sites): - if ssl: - all_train_indices = np.concatenate( - (indices["labeled_train"], indices["unlabeled_train"]) - ) - else: - all_train_indices = indices["labeled_train"] - le = LabelEncoder() - le.fit(self.sites[all_train_indices]) - - all_data: Dict[str, Data] = dict() - for name, idx in indices.items(): - all_data[name] = make_dataset( - x=self.X[idx], - y=self.y[idx], - d=le.transform(self.sites[idx]), - age=self.ages[idx], - gender=self.genders[idx], - ) + unlabeled_sites: Optional[Union[str, Sequence[str]]] = None, + num_unlabeled: Optional[int] = None, + num_process: int = 1, + ) -> Union[ + Dict[str, Union[int, Data]], Dict[str, Union[int, Sequence[Data]]] + ]: + raise NotImplementedError + + @abstractmethod + def load_all_data( + self, + sites: Optional[Union[str, Sequence[str]]] = None, + num_process: int = 1, + ) -> Union[ + Dict[str, Union[int, Data]], Dict[str, Union[int, Sequence[Data]]] + ]: + raise 
NotImplementedError + + +class ModelBaseDataloader(DataloaderBase): + @staticmethod + def make_dataset( + x: np.ndarray, + y: np.ndarray, + d: np.ndarray, + age: np.ndarray, + gender: np.ndarray, + ) -> Data: + graph = Data() + graph.x = torch.tensor(x).float() + graph.y = torch.tensor(y) + graph.d = torch.tensor(d) + graph.age = torch.tensor(age).float() + graph.gender = torch.tensor(gender).float() + return graph - all_data["input_size"] = int(self.X.shape[1]) - all_data["num_sites"] = int(len(le.classes_)) + def load_split_data( + self, + seed: int = 0, + fold: int = 0, + ssl: bool = False, + validation: bool = False, + labeled_sites: Optional[Union[str, Sequence[str]]] = None, + unlabeled_sites: Optional[Union[str, Sequence[str]]] = None, + num_unlabeled: Optional[int] = None, + num_process: int = 1, + ) -> Dict[str, Union[int, Data]]: + indices = self._get_indices( + seed, + fold, + ssl, + validation, + labeled_sites, + unlabeled_sites, + num_unlabeled, + ) - empty = Data(x=torch.tensor([])) - all_data["num_labeled_train"] = all_data.get( - "labeled_train", empty - ).x.size(0) - all_data["num_unlabeled_train"] = all_data.get( - "unlabeled_train", empty - ).x.size(0) - all_data["num_test"] = all_data.get("test", empty).x.size(0) - yield all_data + if ssl: + all_train_indices = np.concatenate( + (indices["labeled_train"], indices["unlabeled_train"]) + ) + else: + all_train_indices = indices["labeled_train"] + le = LabelEncoder() + le.fit(self.sites[all_train_indices]) + + all_data: Dict[str, Data] = dict() + for name, idx in indices.items(): + all_data[name] = self.make_dataset( + x=self.X_flattened[idx], + y=self.Y[idx], + d=le.transform(self.sites[idx]), + age=self.age[idx], + gender=self.gender[idx], + ) + + all_data["input_size"] = int(self.X_flattened.shape[1]) + all_data["num_sites"] = int(len(le.classes_)) + + empty = Data(x=torch.tensor([])) + all_data["num_labeled_train"] = all_data.get( + "labeled_train", empty + ).x.size(0) + all_data["num_unlabeled_train"] = all_data.get( + "unlabeled_train", empty + ).x.size(0) + all_data["num_valid"] = all_data.get("valid", empty).x.size(0) + all_data["num_test"] = all_data.get("test", empty).x.size(0) + return all_data def load_all_data( - self, sites: Optional[Union[str, Sequence[str]]] = None, + self, + sites: Optional[Union[str, Sequence[str]]] = None, + num_process: int = 1, ) -> Dict[str, Union[int, Data]]: if isinstance(sites, str): sites = [sites] @@ -209,13 +226,13 @@ def load_all_data( le.fit(self.sites[all_indices]) return { - "data": make_dataset( - x=self.X[all_indices], - y=self.y[all_indices], + "data": self.make_dataset( + x=self.X_flattened[all_indices], + y=self.Y[all_indices], d=le.transform(self.sites[all_indices]), - age=self.ages[all_indices], - gender=self.genders[all_indices], + age=self.age[all_indices], + gender=self.gender[all_indices], ), - "input_size": int(self.X.shape[1]), + "input_size": int(self.X_flattened.shape[1]), "num_sites": int(len(le.classes_)), } diff --git a/src/evaluate.py b/src/evaluate.py deleted file mode 100644 index 73fc963..0000000 --- a/src/evaluate.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -from typing import Any, Dict -from torch_geometric.data import Data - -from data import Dataset -from utils.metrics import ClassificationMetrics as CM -from models.EDC_VAE import EDC_VAE - - -def load_model(model_params: Dict[str, Any], model_path: str): - model = EDC_VAE(**model_params) - model.load_state_dict( - torch.load(model_path, map_location=torch.device("cpu")) - ) - return model - - 
-def load_data(data_csv_path: str, harmonize: bool = False): - dataset = Dataset(data_csv_path, "", harmonize) - return dataset.load_all_data()["data"] - - -def evaluate_model(model: EDC_VAE, data: Data): - x, y = data.x, data.y - prediction = model.forward(x) - print("accuracy: {:.5f}".format(CM.accuracy(y, prediction["y"]).item())) - print("f1: {:.5f}".format(CM.f1_score(y, prediction["y"]).item())) - print("recall: {:.5f}".format(CM.tpr(y, prediction["y"]).item())) - print("precision: {:.5f}".format(CM.ppv(y, prediction["y"]).item())) - - -if __name__ == "__main__": - model = load_model( - dict( - input_size=34716, - hidden_size=32, - emb_size=16, - clf_hidden_1=0, - clf_hidden_2=0, - ), - model_path="../saved_model/ABIDE_VAE-FFN_0_0_1645419832.pt" - ) - data = load_data("../dataset/ABIDE/meta.csv") - evaluate_model(model, data) \ No newline at end of file diff --git a/src/factory.py b/src/factory.py new file mode 100644 index 0000000..ec5cce8 --- /dev/null +++ b/src/factory.py @@ -0,0 +1,80 @@ +from abc import ABC, abstractclassmethod +from dataclasses import dataclass + +import os +import sys +from typing import Any, Dict, Tuple, Type, Union + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from models.base import ModelBase +from models import ( + FFN, + VAE_FFN, + SHRED, + SHRED_I, + SHRED_III, +) + +from data import DataloaderBase, ModelBaseDataloader + + +class FrameworkFactory(ABC): + @abstractclassmethod + def load_model( + cls, model_name: str, model_param: Dict[str, Any] + ) -> Union[ModelBase, Tuple[ModelBase, ModelBase]]: + raise NotImplementedError + + @abstractclassmethod + def load_dataloader( + cls, model_name: str, dataloader_param: Dict[str, Any] + ) -> DataloaderBase: + raise NotImplementedError + + +class SingleStageFrameworkFactory(FrameworkFactory): + @dataclass + class Mapping: + model_cls: Type[ModelBase] + dataloader_cls: Type[DataloaderBase] + + mapping = { + "FFN": Mapping(FFN, ModelBaseDataloader), + "VAE-FFN": Mapping(VAE_FFN, ModelBaseDataloader), + "SHRED": Mapping(SHRED, ModelBaseDataloader), + "SHRED-I": Mapping(SHRED_I, ModelBaseDataloader), + "SHRED-III": Mapping(SHRED_III, ModelBaseDataloader), + } + + @classmethod + def get_model_class(cls, model_name: str) -> ModelBase: + model_mapping = cls.mapping.get(model_name, None) + if model_mapping is None: + raise NotImplementedError( + "Model {} does not exist".format(model_name) + ) + return model_mapping.model_cls + + @classmethod + def load_model( + cls, model_name: str, model_param: Dict[str, Any] + ) -> ModelBase: + model_mapping = cls.mapping.get(model_name, None) + if model_mapping is None: + raise NotImplementedError( + "Model {} does not exist".format(model_name) + ) + return model_mapping.model_cls(**model_param) + + @classmethod + def load_dataloader( + cls, model_name: str, dataloader_param: Dict[str, Any] + ) -> DataloaderBase: + model_mapping = cls.mapping.get(model_name, None) + if model_mapping is None: + raise NotImplementedError( + "Model {} does not exist".format(model_name) + ) + return model_mapping.dataloader_cls(**dataloader_param) + diff --git a/src/main.py b/src/main.py deleted file mode 100644 index d269fe6..0000000 --- a/src/main.py +++ /dev/null @@ -1,83 +0,0 @@ -import os -import json -import time -import yaml -import logging -import argparse -import pandas as pd -from typing import Any, Dict - -from utils.misc import mkdir, seed_torch -from trainer import EDC_VAE_Trainer, TrainerParams - - -def process(config: Dict[str, Any]): - seed_torch() - 
logging.info("CONFIG:\n{}".format(json.dumps(config, indent=4))) - - script_name = os.path.splitext(os.path.basename(__file__))[0] - experiment_name = "{}_{}_{}".format( - script_name, int(time.time()), os.getpid() - ) - - output_dir = config.get("output_directory") - output_dir = os.path.abspath(os.path.join(output_dir, experiment_name)) - - config_path = os.path.join( - output_dir, "{}.config.json".format(experiment_name), - ) - results_csv_path = os.path.join( - output_dir, "{}.csv".format(experiment_name), - ) - - mkdir(output_dir) - with open(config_path, "w") as f: - json.dump(config, f, indent=4, sort_keys=True) - - trainer = EDC_VAE_Trainer( - trainer_params=TrainerParams( - output_directory=output_dir, - model_name=config.get("model_name"), - model_params=config.get("model_params", dict()), - optim_params=config.get("optim_params", dict()), - hyperparameters=config.get("hyperparameters", dict()), - dataset_path=config.get("dataset_path"), - dataset_name=config.get("dataset_name"), - seeds=config.get("seeds", list(range(10))), - num_fold=config.get("num_fold", 5), - ssl=config.get("ssl", False), - harmonize=config.get("harmonize", False), - labeled_sites=config.get("labeled_sites", None), - device=config.get("device", -1), - verbose=config.get("verbose", False), - patience=config.get("patience", float("inf")), - max_epoch=config.get("max_epoch", 1000), - save_model=config.get("save_model", False), - ), - ) - trainer.run(results_csv_path) - - -def main(args): - with open(os.path.abspath(args.config), "r") as f: - config: Dict[str, Any] = yaml.full_load(f) - process(config) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--config", - type=str, - default="./config.yml", - help="the path to the config file", - ) - args = parser.parse_args() - - logging.basicConfig( - level=logging.DEBUG, - format="[%(asctime)s] - %(filename)s: %(levelname)s: " - "%(funcName)s(): %(lineno)d:\t%(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - main(args) diff --git a/src/models/FFN.py b/src/models/FFN.py new file mode 100644 index 0000000..269f45b --- /dev/null +++ b/src/models/FFN.py @@ -0,0 +1,122 @@ +import os +import sys +import torch +import torch.nn.functional as F +from typing import Any, Optional, Dict +from torch.optim import Optimizer +from torch_geometric.data import Data + +__dir__ = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(__dir__) + +from utils.metrics import ClassificationMetrics as CM +from models.base import LatentSpaceEncoding, ModelBase, FeedForward + + +class FFN(ModelBase, LatentSpaceEncoding): + def __init__( + self, + input_size: int, + hidden_1: int, + hidden_2: int, + hidden_3: int, + output_size: int = 2, + dropout: float = 0.5, + **kwargs + ): + super().__init__() + self.classifier = FeedForward( + input_size, + [h for h in [hidden_1, hidden_2, hidden_3] if h > 0], + output_size, + dropout=dropout, + ) + + def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: + y = self.classifier(x) + return {"y": y} + + def ss_forward(self, x: torch.Tensor) -> torch.Tensor: + return self.forward(x)["y"] + + def ls_forward(self, data: Data) -> torch.Tensor: + raise NotImplementedError + + def is_forward(self, data: Data) -> torch.Tensor: + return data.x + + def get_surface(self, z: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def get_input_surface(self, x: torch.Tensor) -> torch.Tensor: + return self.ss_forward(x) + + def train_step( + self, + device: torch.device, + labeled_data: 
Data, + unlabeled_data: Optional[Data], + optimizer: Optimizer, + hyperparameters: Dict[str, Any], + ) -> Dict[str, float]: + self.to(device) + self.train() + + x: torch.Tensor = labeled_data.x + real_y: torch.Tensor = labeled_data.y + x, real_y = ( + x.to(device), + real_y.to(device), + ) + + with torch.enable_grad(): + optimizer.zero_grad() + pred_y = self.ss_forward(x) + ce_loss = F.cross_entropy(pred_y, real_y) + ce_loss.backward() + optimizer.step() + + accuracy = CM.accuracy(real_y, pred_y) + sensitivity = CM.tpr(real_y, pred_y) + specificity = CM.tnr(real_y, pred_y) + precision = CM.ppv(real_y, pred_y) + f1_score = CM.f1_score(real_y, pred_y) + metrics = { + "ce_loss": ce_loss.item(), + "accuracy": accuracy.item(), + "sensitivity": sensitivity.item(), + "specificity": specificity.item(), + "f1": f1_score.item(), + "precision": precision.item(), + } + return metrics + + def test_step( + self, device: torch.device, test_data: Data + ) -> Dict[str, float]: + self.to(device) + self.eval() + + with torch.no_grad(): + x: torch.Tensor = test_data.x + real_y: torch.Tensor = test_data.y + x, real_y = x.to(device), real_y.to(device) + + pred_y = self.ss_forward(x) + ce_loss = F.cross_entropy(pred_y, real_y) + + accuracy = CM.accuracy(real_y, pred_y) + sensitivity = CM.tpr(real_y, pred_y) + specificity = CM.tnr(real_y, pred_y) + precision = CM.ppv(real_y, pred_y) + f1_score = CM.f1_score(real_y, pred_y) + + metrics = { + "ce_loss": ce_loss.item(), + "accuracy": accuracy.item(), + "sensitivity": sensitivity.item(), + "specificity": specificity.item(), + "f1": f1_score.item(), + "precision": precision.item(), + } + return metrics diff --git a/src/models/SHRED.py b/src/models/SHRED.py new file mode 100644 index 0000000..7e81275 --- /dev/null +++ b/src/models/SHRED.py @@ -0,0 +1,393 @@ +import os +import sys +import torch +import torch.nn.functional as F +from typing import Any, Dict, Optional, OrderedDict, Tuple +from torch.nn import Module, Linear, Parameter, BatchNorm1d +from torch.optim import Optimizer +from torch_geometric.data import Data + +__dir__ = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(__dir__) + +from utils.loss import kl_divergence_loss +from utils.metrics import ClassificationMetrics as CM +from models.base import LatentSpaceEncoding, ModelBase +from models.VAE_FFN import VAE_FFN + + +def init_zero(linear_layer: Linear): + linear_layer.weight.data.fill_(0.0) + linear_layer.bias.data.fill_(0.0) + + +class CH(Module): + def __init__(self, input_size: int, num_sites: int): + super().__init__() + self.num_sites = num_sites + self.alpha = Parameter(torch.zeros(input_size)) + self.age_norm = BatchNorm1d(1) + self.age = Linear(1, input_size) + self.gender = Linear(1, input_size) + self.gamma = Linear(num_sites, input_size) + self.delta = Linear(num_sites, input_size) + init_zero(self.age) + init_zero(self.gender) + init_zero(self.gamma) + init_zero(self.delta) + + def forward( + self, + x: torch.Tensor, + age: torch.Tensor, + gender: torch.Tensor, + site: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + if site.ndim == 2: + pass + elif site.ndim == 1: + site = F.one_hot(site, num_classes=self.num_sites).float() + else: + raise ValueError("invalid shape for site: {}".format(site.size())) + + age_x = self.age(self.age_norm(age)) + gender_x = self.gender(gender) + gamma = self.gamma(site) + delta = torch.exp(self.delta(site)) + eps = (x - self.alpha - age_x - gender_x - gamma) / delta + x_ch = self.alpha + eps + return { + "x_ch": x_ch, + "alpha": 
self.alpha, + "age": age_x, + "gender": gender_x, + "eps": eps, + "gamma": gamma, + "delta": delta, + } + + def inverse( + self, + x_ch: torch.Tensor, + age_x: torch.Tensor, + gender_x: torch.Tensor, + gamma: torch.Tensor, + delta: torch.Tensor, + ) -> torch.Tensor: + eps = x_ch - self.alpha + return self.inverse_eps(eps, age_x, gender_x, gamma, delta) + + def inverse_eps( + self, + eps: torch.Tensor, + age_x: torch.Tensor, + gender_x: torch.Tensor, + gamma, + delta: torch.Tensor, + ) -> torch.Tensor: + x = self.alpha + age_x + gender_x + gamma + delta * eps + return x + + +class SHRED(ModelBase, LatentSpaceEncoding): + def __init__( + self, + num_sites: int, + input_size: int, + hidden_size: int, + emb_size: int, + clf_hidden_1: int, + clf_hidden_2: int, + clf_output_size: int = 2, + dropout: float = 0.25, + **kwargs + ): + super().__init__() + self.ch = CH(input_size, num_sites) + self.vae_ffn = VAE_FFN( + input_size, + hidden_size, + emb_size, + clf_hidden_1, + clf_hidden_2, + clf_output_size, + dropout=dropout, + ) + + def combat( + self, + x: torch.Tensor, + age: torch.Tensor, + gender: torch.Tensor, + site: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + return self.ch(x, age, gender, site) + + def forward( + self, + x: torch.Tensor, + age: torch.Tensor, + gender: torch.Tensor, + site: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + ch_res = self.ch(x, age, gender, site) + vae_res = self.vae_ffn(ch_res["eps"]) + + eps_mu = vae_res["x_mu"] + vae_res["eps_mu"] = eps_mu + vae_res["x_mu"] = self.ch.inverse_eps( + eps_mu, + ch_res["age"], + ch_res["gender"], + ch_res["gamma"], + ch_res["delta"], + ) + return {**ch_res, **vae_res} + + def ss_forward(self, eps: torch.Tensor) -> torch.Tensor: + y = self.vae_ffn.ss_forward(eps) + return y + + def get_baselines_inputs( + self, data: Data + ) -> Tuple[torch.Tensor, torch.Tensor]: + x, y = data.x, data.y + age, gender, site = data.age, data.gender, data.d + ch_res = self.combat(x, age, gender, site) + eps: torch.Tensor = ch_res["eps"] + baselines = eps[y == 0].mean(dim=0).view(1, -1) + inputs = eps[y == 1] + return baselines, inputs + + def ls_forward(self, data: Data) -> torch.Tensor: + x, age, gender, site = data.x, data.age, data.gender, data.d + ch_res = self.combat(x, age, gender, site) + eps: torch.Tensor = ch_res["eps"] + z_mu, _ = self.vae_ffn.encode(eps) + return z_mu + + def is_forward(self, data: Data) -> torch.Tensor: + x, age, gender, site = data.x, data.age, data.gender, data.d + ch_res = self.combat(x, age, gender, site) + eps: torch.Tensor = ch_res["eps"] + return eps + + def get_surface(self, z: torch.Tensor) -> torch.Tensor: + return self.vae_ffn.get_surface(z) + + def get_input_surface(self, x: torch.Tensor) -> torch.Tensor: + z_mu, _ = self.vae_ffn.encode(x) + return self.vae_ffn.get_surface(z_mu) + + def train_step( + self, + device: torch.device, + labeled_data: Data, + unlabeled_data: Optional[Data], + optimizer: Optimizer, + hyperparameters: Dict[str, Any], + ) -> Dict[str, float]: + self.to(device) + self.train() + + labeled_x: torch.Tensor = labeled_data.x + labeled_age: torch.Tensor = labeled_data.age + labeled_gender: torch.Tensor = labeled_data.gender + labeled_site: torch.Tensor = labeled_data.d + real_y: torch.Tensor = labeled_data.y + labeled_x, labeled_age, labeled_gender, labeled_site, real_y = ( + labeled_x.to(device), + labeled_age.to(device), + labeled_gender.to(device), + labeled_site.to(device), + real_y.to(device), + ) + + if unlabeled_data is not None: + unlabeled_x: torch.Tensor = unlabeled_data.x + 
unlabeled_age: torch.Tensor = unlabeled_data.age + unlabeled_gender: torch.Tensor = unlabeled_data.gender + unlabeled_site: torch.Tensor = unlabeled_data.d + unlabeled_x, unlabeled_age, unlabeled_gender, unlabeled_site = ( + unlabeled_x.to(device), + unlabeled_age.to(device), + unlabeled_gender.to(device), + unlabeled_site.to(device), + ) + + with torch.enable_grad(): + optimizer.zero_grad() + + labeled_res = self( + labeled_x, labeled_age, labeled_gender, labeled_site + ) + alpha: torch.Tensor = labeled_res["alpha"] + labeled_age_x = labeled_res["age"] + labeled_gender_x = labeled_res["gender"] + pred_y = labeled_res["y"] + labeled_x_mu = labeled_res["x_mu"] + labeled_x_std = labeled_res["x_std"] + labeled_z_mu = labeled_res["z_mu"] + labeled_z_std = labeled_res["z_std"] + labeled_eps = labeled_res["eps"] + if unlabeled_data is not None: + unlabeled_res = self( + unlabeled_x, unlabeled_age, unlabeled_gender, unlabeled_site + ) + unlabeled_age_x = unlabeled_res["age"] + unlabeled_gender_x = unlabeled_res["gender"] + unlabeled_x_mu = unlabeled_res["x_mu"] + unlabeled_x_std = unlabeled_res["x_std"] + unlabeled_z_mu = unlabeled_res["z_mu"] + unlabeled_z_std = unlabeled_res["z_std"] + unlabeled_eps = unlabeled_res["eps"] + age_x = torch.cat((labeled_age_x, unlabeled_age_x), dim=0) + gender_x = torch.cat( + (labeled_gender_x, unlabeled_gender_x), dim=0 + ) + x = torch.cat((labeled_x, unlabeled_x), dim=0) + x_mu = torch.cat((labeled_x_mu, unlabeled_x_mu), dim=0) + x_std = torch.cat((labeled_x_std, unlabeled_x_std), dim=0) + z_mu = torch.cat((labeled_z_mu, unlabeled_z_mu), dim=0) + z_std = torch.cat((labeled_z_std, unlabeled_z_std), dim=0) + eps = torch.cat((labeled_eps, unlabeled_eps), dim=0) + else: + age_x = labeled_age_x + gender_x = labeled_gender_x + x = labeled_x + x_mu = labeled_x_mu + x_std = labeled_x_std + z_mu = labeled_z_mu + z_std = labeled_z_std + eps = labeled_eps + + ce_loss = F.cross_entropy(pred_y, real_y) + rc_loss = F.gaussian_nll_loss(x_mu, x, x_std ** 2, full=True) + kl_loss = kl_divergence_loss( + z_mu, + z_std ** 2, + torch.zeros_like(z_mu), + torch.ones_like(z_std), + ) + + ch_loss = (eps.mean(dim=0) ** 2).sum() + alpha_loss = ( + F.mse_loss( + alpha.expand(age_x.size()), + x, + reduction="none", + ) + .sum(dim=1) + .mean() + ) + + gamma1 = hyperparameters.get("rc_loss", 1) + gamma2 = hyperparameters.get("kl_loss", 1) + gamma3 = hyperparameters.get("ch_loss", 1) + use_alpha_loss = hyperparameters.get("alpha_loss", True) + + if use_alpha_loss: + total_loss = ( + ce_loss + + gamma1 * rc_loss + + gamma2 * kl_loss + + gamma3 * (ch_loss + alpha_loss) + ) + else: + total_loss = ( + ce_loss + + gamma1 * rc_loss + + gamma2 * kl_loss + + gamma3 * ch_loss + ) + total_loss.backward() + optimizer.step() + + accuracy = CM.accuracy(real_y, pred_y) + sensitivity = CM.tpr(real_y, pred_y) + specificity = CM.tnr(real_y, pred_y) + precision = CM.ppv(real_y, pred_y) + f1_score = CM.f1_score(real_y, pred_y) + metrics = { + "ce_loss": ce_loss.item(), + "rc_loss": rc_loss.item(), + "kl_loss": kl_loss.item(), + "ch_loss": ch_loss.item(), + "alpha_loss": alpha_loss.item(), + "accuracy": accuracy.item(), + "sensitivity": sensitivity.item(), + "specificity": specificity.item(), + "f1": f1_score.item(), + "precision": precision.item(), + } + return metrics + + def test_step( + self, device: torch.device, test_data: Data + ) -> Dict[str, float]: + self.to(device) + self.eval() + + with torch.no_grad(): + x: torch.Tensor = test_data.x + age: torch.Tensor = test_data.age + gender: torch.Tensor = 
test_data.gender + site: torch.Tensor = test_data.d + real_y: torch.Tensor = test_data.y + x, age, gender, site, real_y = ( + x.to(device), + age.to(device), + gender.to(device), + site.to(device), + real_y.to(device), + ) + + res = self(x, age, gender, site) + pred_y = res["y"] + x_mu = res["x_mu"] + x_std = res["x_std"] + z_mu = res["z_mu"] + z_std = res["z_std"] + alpha: torch.Tensor = res["alpha"] + age_x: torch.Tensor = res["age"] + gender_x = res["gender"] + eps: torch.Tensor = res["eps"] + + ce_loss = F.cross_entropy(pred_y, real_y) + rc_loss = F.gaussian_nll_loss(x_mu, x, x_std ** 2, full=True) + kl_loss = kl_divergence_loss( + z_mu, + z_std ** 2, + torch.zeros_like(z_mu), + torch.ones_like(z_std), + ) + + ch_loss = (eps.mean(dim=0) ** 2).sum() + alpha_loss = ( + F.mse_loss( + alpha.expand(age_x.size()), + x, + reduction="none", + ) + .sum(dim=1) + .mean() + ) + + accuracy = CM.accuracy(real_y, pred_y) + sensitivity = CM.tpr(real_y, pred_y) + specificity = CM.tnr(real_y, pred_y) + precision = CM.ppv(real_y, pred_y) + f1_score = CM.f1_score(real_y, pred_y) + metrics = { + "ce_loss": ce_loss.item(), + "rc_loss": rc_loss.item(), + "kl_loss": kl_loss.item(), + "ch_loss": ch_loss.item(), + "alpha_loss": alpha_loss.item(), + "accuracy": accuracy.item(), + "sensitivity": sensitivity.item(), + "specificity": specificity.item(), + "f1": f1_score.item(), + "precision": precision.item(), + } + return metrics diff --git a/src/models/SHRED_I.py b/src/models/SHRED_I.py new file mode 100644 index 0000000..718667a --- /dev/null +++ b/src/models/SHRED_I.py @@ -0,0 +1,282 @@ +import os +import sys +import torch +import torch.nn.functional as F +from torch.nn.utils import clip_grad_norm_ +from typing import Dict, Tuple, Optional, Any +from torch.optim import Optimizer +from torch_geometric.data import Data + +__dir__ = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(__dir__) + +from models.SHRED import SHRED +from utils.loss import kl_divergence_loss +from utils.metrics import ClassificationMetrics as CM + + +class SHRED_I(SHRED): + + def forward( + self, + x: torch.Tensor, + age: torch.Tensor, + gender: torch.Tensor, + site: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + + ch_res = self.ch(x, age, gender, site) + vae_res = self.vae_ffn(ch_res["x_ch"]) + + x_ch_mu = vae_res["x_mu"] + vae_res["x_ch_mu"] = x_ch_mu + vae_res["x_mu"] = self.ch.inverse( + x_ch_mu, + ch_res["age"], + ch_res["gender"], + ch_res["gamma"], + ch_res["delta"], + ) + return {**ch_res, **vae_res} + + def ss_forward(self, x_ch: torch.Tensor) -> torch.Tensor: + y = self.vae_ffn.ss_forward(x_ch) + return y + + def get_baselines_inputs( + self, data: Data + ) -> Tuple[torch.Tensor, torch.Tensor]: + x, y = data.x, data.y + age, gender, site = data.age, data.gender, data.d + ch_res = self.combat(x, age, gender, site) + x_ch: torch.Tensor = ch_res["x_ch"] + baselines = x_ch[y == 0].mean(dim=0).view(1, -1) + inputs = x_ch[y == 1] + return baselines, inputs + + def ls_forward(self, data: Data) -> torch.Tensor: + x, age, gender, site = data.x, data.age, data.gender, data.d + ch_res = self.combat(x, age, gender, site) + x_ch: torch.Tensor = ch_res["x_ch"] + z_mu, _ = self.vae_ffn.encode(x_ch) + return z_mu + + def is_forward(self, data: Data) -> torch.Tensor: + x, age, gender, site = data.x, data.age, data.gender, data.d + ch_res = self.combat(x, age, gender, site) + x_ch: torch.Tensor = ch_res["x_ch"] + return x_ch + + def train_step( + self, + device: torch.device, + labeled_data: Data, + unlabeled_data: 
Optional[Data], + optimizer: Optimizer, + hyperparameters: Dict[str, Any], + ) -> Dict[str, float]: + self.to(device) + self.train() + + labeled_x: torch.Tensor = labeled_data.x + labeled_age: torch.Tensor = labeled_data.age + labeled_gender: torch.Tensor = labeled_data.gender + labeled_site: torch.Tensor = labeled_data.d + real_y: torch.Tensor = labeled_data.y + labeled_x, labeled_age, labeled_gender, labeled_site, real_y = ( + labeled_x.to(device), + labeled_age.to(device), + labeled_gender.to(device), + labeled_site.to(device), + real_y.to(device), + ) + + if unlabeled_data is not None: + unlabeled_x: torch.Tensor = unlabeled_data.x + unlabeled_age: torch.Tensor = unlabeled_data.age + unlabeled_gender: torch.Tensor = unlabeled_data.gender + unlabeled_site: torch.Tensor = unlabeled_data.d + unlabeled_x, unlabeled_age, unlabeled_gender, unlabeled_site = ( + unlabeled_x.to(device), + unlabeled_age.to(device), + unlabeled_gender.to(device), + unlabeled_site.to(device), + ) + + with torch.enable_grad(): + optimizer.zero_grad() + + labeled_res = self( + labeled_x, labeled_age, labeled_gender, labeled_site + ) + alpha: torch.Tensor = labeled_res["alpha"] + labeled_age_x = labeled_res["age"] + labeled_gender_x = labeled_res["gender"] + pred_y = labeled_res["y"] + labeled_x_mu = labeled_res["x_mu"] + labeled_x_std = labeled_res["x_std"] + labeled_z_mu = labeled_res["z_mu"] + labeled_z_std = labeled_res["z_std"] + labeled_eps = labeled_res["eps"] + if unlabeled_data is not None: + unlabeled_res = self( + unlabeled_x, unlabeled_age, unlabeled_gender, unlabeled_site + ) + unlabeled_age_x = unlabeled_res["age"] + unlabeled_gender_x = unlabeled_res["gender"] + unlabeled_x_mu = unlabeled_res["x_mu"] + unlabeled_x_std = unlabeled_res["x_std"] + unlabeled_z_mu = unlabeled_res["z_mu"] + unlabeled_z_std = unlabeled_res["z_std"] + unlabeled_eps = unlabeled_res["eps"] + age_x = torch.cat((labeled_age_x, unlabeled_age_x), dim=0) + gender_x = torch.cat( + (labeled_gender_x, unlabeled_gender_x), dim=0 + ) + x = torch.cat((labeled_x, unlabeled_x), dim=0) + x_mu = torch.cat((labeled_x_mu, unlabeled_x_mu), dim=0) + x_std = torch.cat((labeled_x_std, unlabeled_x_std), dim=0) + z_mu = torch.cat((labeled_z_mu, unlabeled_z_mu), dim=0) + z_std = torch.cat((labeled_z_std, unlabeled_z_std), dim=0) + eps = torch.cat((labeled_eps, unlabeled_eps), dim=0) + else: + age_x = labeled_age_x + gender_x = labeled_gender_x + x = labeled_x + x_mu = labeled_x_mu + x_std = labeled_x_std + z_mu = labeled_z_mu + z_std = labeled_z_std + eps = labeled_eps + + ce_loss = F.cross_entropy(pred_y, real_y) + rc_loss = F.gaussian_nll_loss(x_mu, x, x_std ** 2, full=True) + kl_loss = kl_divergence_loss( + z_mu, + z_std ** 2, + torch.zeros_like(z_mu), + torch.ones_like(z_std), + ) + ch_loss = (eps ** 2).sum(dim=1).mean() + + alpha_loss = ( + F.mse_loss( + alpha.expand(age_x.size()), + x, + reduction="none", + ) + .sum(dim=1) + .mean() + ) + + gamma1 = hyperparameters.get("rc_loss", 1) + gamma2 = hyperparameters.get("kl_loss", 1) + gamma3 = hyperparameters.get("ch_loss", 1) + use_alpha_loss = hyperparameters.get("alpha_loss", True) + + if use_alpha_loss: + total_loss = ( + ce_loss + + gamma1 * rc_loss + + gamma2 * kl_loss + + gamma3 * (ch_loss + alpha_loss) + ) + else: + total_loss = ( + ce_loss + + gamma1 * rc_loss + + gamma2 * kl_loss + + gamma3 * ch_loss + ) + total_loss.backward() + clip_grad_norm_(self.parameters(), 0.1) + optimizer.step() + + accuracy = CM.accuracy(real_y, pred_y) + sensitivity = CM.tpr(real_y, pred_y) + specificity = 
CM.tnr(real_y, pred_y) + precision = CM.ppv(real_y, pred_y) + f1_score = CM.f1_score(real_y, pred_y) + metrics = { + "ce_loss": ce_loss.item(), + "rc_loss": rc_loss.item(), + "kl_loss": kl_loss.item(), + "ch_loss": ch_loss.item(), + "alpha_loss": alpha_loss.item(), + "accuracy": accuracy.item(), + "sensitivity": sensitivity.item(), + "specificity": specificity.item(), + "f1": f1_score.item(), + "precision": precision.item(), + } + return metrics + + def test_step( + self, device: torch.device, test_data: Data + ) -> Dict[str, float]: + self.to(device) + self.eval() + + with torch.no_grad(): + x: torch.Tensor = test_data.x + age: torch.Tensor = test_data.age + gender: torch.Tensor = test_data.gender + site: torch.Tensor = test_data.d + real_y: torch.Tensor = test_data.y + x, age, gender, site, real_y = ( + x.to(device), + age.to(device), + gender.to(device), + site.to(device), + real_y.to(device), + ) + + res = self(x, age, gender, site) + pred_y = res["y"] + x_mu = res["x_mu"] + x_std = res["x_std"] + z_mu = res["z_mu"] + z_std = res["z_std"] + alpha: torch.Tensor = res["alpha"] + age_x: torch.Tensor = res["age"] + gender_x = res["gender"] + eps: torch.Tensor = res["eps"] + + ce_loss = F.cross_entropy(pred_y, real_y) + rc_loss = F.gaussian_nll_loss(x_mu, x, x_std ** 2, full=True) + kl_loss = kl_divergence_loss( + z_mu, + z_std ** 2, + torch.zeros_like(z_mu), + torch.ones_like(z_std), + ) + ch_loss = (eps ** 2).sum(dim=1).mean() + + alpha_loss = ( + F.mse_loss( + alpha.expand(age_x.size()), + x, + reduction="none", + ) + .sum(dim=1) + .mean() + ) + + accuracy = CM.accuracy(real_y, pred_y) + sensitivity = CM.tpr(real_y, pred_y) + specificity = CM.tnr(real_y, pred_y) + precision = CM.ppv(real_y, pred_y) + f1_score = CM.f1_score(real_y, pred_y) + metrics = { + "ce_loss": ce_loss.item(), + "rc_loss": rc_loss.item(), + "kl_loss": kl_loss.item(), + "ch_loss": ch_loss.item(), + "alpha_loss": alpha_loss.item(), + "accuracy": accuracy.item(), + "sensitivity": sensitivity.item(), + "specificity": specificity.item(), + "f1": f1_score.item(), + "precision": precision.item(), + } + return metrics diff --git a/src/models/SHRED_III.py b/src/models/SHRED_III.py new file mode 100644 index 0000000..786379f --- /dev/null +++ b/src/models/SHRED_III.py @@ -0,0 +1,241 @@ +import os +import sys +import torch +import torch.nn.functional as F +from torch.nn.utils import clip_grad_norm_ +from typing import Any, Dict, Optional +from torch.optim import Optimizer +from torch_geometric.data import Data + +__dir__ = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(__dir__) + +from utils.loss import kl_divergence_loss +from utils.metrics import ClassificationMetrics as CM +from models.SHRED_I import SHRED_I + + +class SHRED_III(SHRED_I): + def train_step( + self, + device: torch.device, + labeled_data: Data, + unlabeled_data: Optional[Data], + optimizer: Optimizer, + hyperparameters: Dict[str, Any], + ) -> Dict[str, float]: + self.to(device) + self.train() + + labeled_x: torch.Tensor = labeled_data.x + labeled_age: torch.Tensor = labeled_data.age + labeled_gender: torch.Tensor = labeled_data.gender + labeled_site: torch.Tensor = labeled_data.d + real_y: torch.Tensor = labeled_data.y + labeled_x, labeled_age, labeled_gender, labeled_site, real_y = ( + labeled_x.to(device), + labeled_age.to(device), + labeled_gender.to(device), + labeled_site.to(device), + real_y.to(device), + ) + + if unlabeled_data is not None: + unlabeled_x: torch.Tensor = unlabeled_data.x + unlabeled_age: torch.Tensor = 
unlabeled_data.age + unlabeled_gender: torch.Tensor = unlabeled_data.gender + unlabeled_site: torch.Tensor = unlabeled_data.d + unlabeled_x, unlabeled_age, unlabeled_gender, unlabeled_site = ( + unlabeled_x.to(device), + unlabeled_age.to(device), + unlabeled_gender.to(device), + unlabeled_site.to(device), + ) + + with torch.enable_grad(): + optimizer.zero_grad() + + labeled_res = self( + labeled_x, labeled_age, labeled_gender, labeled_site + ) + alpha: torch.Tensor = labeled_res["alpha"] + labeled_age_x = labeled_res["age"] + labeled_gender_x = labeled_res["gender"] + pred_y = labeled_res["y"] + labeled_x_mu = labeled_res["x_mu"] + labeled_x_std = labeled_res["x_std"] + labeled_z_mu = labeled_res["z_mu"] + labeled_z_std = labeled_res["z_std"] + labeled_eps = labeled_res["eps"] + labeled_gamma = labeled_res["gamma"] + labeled_delta = labeled_res["delta"] + if unlabeled_data is not None: + unlabeled_res = self( + unlabeled_x, unlabeled_age, unlabeled_gender, unlabeled_site + ) + unlabeled_age_x = unlabeled_res["age"] + unlabeled_gender_x = unlabeled_res["gender"] + unlabeled_x_mu = unlabeled_res["x_mu"] + unlabeled_x_std = unlabeled_res["x_std"] + unlabeled_z_mu = unlabeled_res["z_mu"] + unlabeled_z_std = unlabeled_res["z_std"] + unlabeled_eps = unlabeled_res["eps"] + unlabeled_gamma = unlabeled_res["gamma"] + unlabeled_delta = unlabeled_res["delta"] + age_x = torch.cat((labeled_age_x, unlabeled_age_x), dim=0) + gender_x = torch.cat( + (labeled_gender_x, unlabeled_gender_x), dim=0 + ) + x = torch.cat((labeled_x, unlabeled_x), dim=0) + x_mu = torch.cat((labeled_x_mu, unlabeled_x_mu), dim=0) + x_std = torch.cat((labeled_x_std, unlabeled_x_std), dim=0) + z_mu = torch.cat((labeled_z_mu, unlabeled_z_mu), dim=0) + z_std = torch.cat((labeled_z_std, unlabeled_z_std), dim=0) + eps = torch.cat((labeled_eps, unlabeled_eps), dim=0) + gamma = torch.cat((labeled_gamma, unlabeled_gamma), dim=0) + delta = torch.cat((labeled_delta, unlabeled_delta), dim=0) + else: + age_x = labeled_age_x + gender_x = labeled_gender_x + x = labeled_x + x_mu = labeled_x_mu + x_std = labeled_x_std + z_mu = labeled_z_mu + z_std = labeled_z_std + eps = labeled_eps + gamma = labeled_gamma + delta = labeled_delta + + ce_loss = F.cross_entropy(pred_y, real_y) + rc_loss = F.gaussian_nll_loss(x_mu, x, x_std ** 2, full=True) + kl_loss = kl_divergence_loss( + z_mu, + z_std ** 2, + torch.zeros_like(z_mu), + torch.ones_like(z_std), + ) + + stand_mean = alpha.expand(age_x.size()) + ch_loss = F.gaussian_nll_loss( + gamma, x - stand_mean, delta ** 2, full=True + ) + alpha_loss = F.gaussian_nll_loss( + stand_mean, + x, + x.var(dim=0, keepdim=True).expand(x.size()), + full=True, + ) + + gamma1 = hyperparameters.get("rc_loss", 1) + gamma2 = hyperparameters.get("kl_loss", 1) + gamma3 = hyperparameters.get("ch_loss", 1) + use_alpha_loss = hyperparameters.get("alpha_loss", True) + + if use_alpha_loss: + total_loss = ( + ce_loss + + gamma1 * rc_loss + + gamma2 * kl_loss + + gamma3 * (ch_loss + alpha_loss) + ) + else: + total_loss = ( + ce_loss + + gamma1 * rc_loss + + gamma2 * kl_loss + + gamma3 * ch_loss + ) + total_loss.backward() + clip_grad_norm_(self.parameters(), 0.01) + optimizer.step() + + accuracy = CM.accuracy(real_y, pred_y) + sensitivity = CM.tpr(real_y, pred_y) + specificity = CM.tnr(real_y, pred_y) + precision = CM.ppv(real_y, pred_y) + f1_score = CM.f1_score(real_y, pred_y) + metrics = { + "ce_loss": ce_loss.item(), + "rc_loss": rc_loss.item(), + "kl_loss": kl_loss.item(), + "ch_loss": ch_loss.item(), + "alpha_loss": 
alpha_loss.item(), + "accuracy": accuracy.item(), + "sensitivity": sensitivity.item(), + "specificity": specificity.item(), + "f1": f1_score.item(), + "precision": precision.item(), + } + return metrics + + def test_step( + self, device: torch.device, test_data: Data + ) -> Dict[str, float]: + self.to(device) + self.eval() + + with torch.no_grad(): + x: torch.Tensor = test_data.x + age: torch.Tensor = test_data.age + gender: torch.Tensor = test_data.gender + site: torch.Tensor = test_data.d + real_y: torch.Tensor = test_data.y + x, age, gender, site, real_y = ( + x.to(device), + age.to(device), + gender.to(device), + site.to(device), + real_y.to(device), + ) + + res = self(x, age, gender, site) + pred_y = res["y"] + x_mu = res["x_mu"] + x_std = res["x_std"] + z_mu = res["z_mu"] + z_std = res["z_std"] + alpha: torch.Tensor = res["alpha"] + age_x: torch.Tensor = res["age"] + gender_x = res["gender"] + eps: torch.Tensor = res["eps"] + gamma: torch.Tensor = res["gamma"] + delta: torch.Tensor = res["delta"] + + ce_loss = F.cross_entropy(pred_y, real_y) + rc_loss = F.gaussian_nll_loss(x_mu, x, x_std ** 2, full=True) + kl_loss = kl_divergence_loss( + z_mu, + z_std ** 2, + torch.zeros_like(z_mu), + torch.ones_like(z_std), + ) + + stand_mean = alpha.expand(age_x.size()) + ch_loss = F.gaussian_nll_loss( + gamma, x - stand_mean, delta ** 2, full=True + ) + alpha_loss = F.gaussian_nll_loss( + stand_mean, + x, + x.var(dim=0, keepdim=True).expand(x.size()), + full=True, + ) + + accuracy = CM.accuracy(real_y, pred_y) + sensitivity = CM.tpr(real_y, pred_y) + specificity = CM.tnr(real_y, pred_y) + precision = CM.ppv(real_y, pred_y) + f1_score = CM.f1_score(real_y, pred_y) + metrics = { + "ce_loss": ce_loss.item(), + "rc_loss": rc_loss.item(), + "kl_loss": kl_loss.item(), + "ch_loss": ch_loss.item(), + "alpha_loss": alpha_loss.item(), + "accuracy": accuracy.item(), + "sensitivity": sensitivity.item(), + "specificity": specificity.item(), + "f1": f1_score.item(), + "precision": precision.item(), + } + return metrics diff --git a/src/models/EDC_VAE.py b/src/models/VAE_FFN.py similarity index 84% rename from src/models/EDC_VAE.py rename to src/models/VAE_FFN.py index 1064863..0753feb 100644 --- a/src/models/EDC_VAE.py +++ b/src/models/VAE_FFN.py @@ -2,8 +2,9 @@ import sys import torch import torch.nn.functional as F -from typing import Any, Optional, Dict, Tuple -from torch.nn import Softmax, Tanh +from torch.nn.utils import clip_grad_norm_ +from typing import Any, Optional, Dict, OrderedDict, Tuple +from torch.nn import Tanh from torch.optim import Optimizer from torch.distributions import Normal from torch_geometric.data import Data @@ -14,6 +15,7 @@ from utils.loss import kl_divergence_loss from utils.metrics import ClassificationMetrics as CM from models.base import ( + LatentSpaceEncoding, ModelBase, FeedForward, VariationalDecoder, @@ -21,7 +23,7 @@ ) -class EDC_VAE(ModelBase): +class VAE_FFN(ModelBase, LatentSpaceEncoding): def __init__( self, input_size: int, @@ -51,7 +53,6 @@ def __init__( emb_size, [h for h in [clf_hidden_1, clf_hidden_2] if h > 0], clf_output_size, - Softmax(dim=1), dropout=dropout, ) @@ -82,6 +83,26 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: "z_std": z_std, } + def ss_forward(self, x: torch.Tensor) -> torch.Tensor: + z_mu, _ = self.encode(x) + y = self.classify(z_mu) + return y + + def ls_forward(self, data: Data) -> torch.Tensor: + x: torch.Tensor = data.x + z_mu, _ = self.encode(x) + return z_mu + + def is_forward(self, data: Data) -> torch.Tensor: + return 
data.x + + def get_surface(self, z: torch.Tensor) -> torch.Tensor: + y = self.classify(z) + return y + + def get_input_surface(self, x: torch.Tensor) -> torch.Tensor: + return self.ss_forward(x) + def train_step( self, device: torch.device, @@ -133,7 +154,7 @@ def train_step( z_std = labeled_z_std ce_loss = F.cross_entropy(pred_y, real_y) - ll_loss = F.gaussian_nll_loss(x_mu, x, x_std ** 2, full=True) + rc_loss = F.gaussian_nll_loss(x_mu, x, x_std ** 2, full=True) kl_loss = kl_divergence_loss( z_mu, z_std ** 2, @@ -141,10 +162,11 @@ def train_step( torch.ones_like(z_std), ) - gamma1 = hyperparameters.get("ll_loss", 1) + gamma1 = hyperparameters.get("rc_loss", 1) gamma2 = hyperparameters.get("kl_loss", 1) - total_loss = ce_loss + gamma1 * ll_loss + gamma2 * kl_loss + total_loss = ce_loss + gamma1 * rc_loss + gamma2 * kl_loss total_loss.backward() + clip_grad_norm_(self.parameters(), 0.01) #5.0) optimizer.step() accuracy = CM.accuracy(real_y, pred_y) @@ -154,7 +176,7 @@ def train_step( f1_score = CM.f1_score(real_y, pred_y) metrics = { "ce_loss": ce_loss.item(), - "ll_loss": ll_loss.item(), + "rc_loss": rc_loss.item(), "kl_loss": kl_loss.item(), "accuracy": accuracy.item(), "sensitivity": sensitivity.item(), @@ -183,7 +205,7 @@ def test_step( z_std = res["z_std"] ce_loss = F.cross_entropy(pred_y, real_y) - ll_loss = F.gaussian_nll_loss(x_mu, x, x_std ** 2, full=True) + rc_loss = F.gaussian_nll_loss(x_mu, x, x_std ** 2, full=True) kl_loss = kl_divergence_loss( z_mu, z_std ** 2, @@ -199,7 +221,7 @@ def test_step( metrics = { "ce_loss": ce_loss.item(), - "ll_loss": ll_loss.item(), + "rc_loss": rc_loss.item(), "kl_loss": kl_loss.item(), "accuracy": accuracy.item(), "sensitivity": sensitivity.item(), diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..23cb3d4 --- /dev/null +++ b/src/models/__init__.py @@ -0,0 +1,24 @@ +__all__ = [ + "FFN", + "VAE_FFN", + "SHRED", + "SHRED_I", + "SHRED_III", + "count_parameters", +] + +import os +import sys +from torch.nn import Module + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from .FFN import FFN +from .VAE_FFN import VAE_FFN +from .SHRED import SHRED +from .SHRED_I import SHRED_I +from .SHRED_III import SHRED_III + + +def count_parameters(model: Module) -> int: + return sum(p.numel() for p in model.parameters() if p.requires_grad) diff --git a/src/models/base.py b/src/models/base.py index abaef54..12715d4 100644 --- a/src/models/base.py +++ b/src/models/base.py @@ -1,6 +1,13 @@ from __future__ import annotations -from abc import abstractmethod -from typing import Sequence, Optional, Dict, Any +from typing import OrderedDict, Sequence, Optional, Tuple, Dict, Any +from abc import ABC, abstractmethod + +import numpy as np +from captum.attr import IntegratedGradients +from scipy.spatial.distance import squareform +from sklearn.decomposition import PCA, TruncatedSVD +from sklearn.pipeline import Pipeline, make_pipeline +from sklearn.preprocessing import StandardScaler import torch from torch_geometric.data import Data @@ -83,7 +90,114 @@ def forward(self, z: torch.Tensor): return mu, std -class ModelBase(Module): +class SaliencyScoreForward(ABC): + @abstractmethod + def ss_forward(self, *args) -> torch.Tensor: + raise NotImplementedError + + def get_baselines_inputs( + self, data: Data + ) -> Tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor = data.x + y: torch.Tensor = data.y + baselines = x[y == 0].mean(dim=0).view(1, -1) + inputs = x[y == 1] + return baselines, inputs + + def 
saliency_score(self, data: Data) -> torch.Tensor:
+        baselines, inputs = self.get_baselines_inputs(data)
+        ig = IntegratedGradients(self.ss_forward, True)
+        scores: torch.Tensor = ig.attribute(
+            inputs=inputs, baselines=baselines, target=1
+        )
+
+        scores = scores.detach().cpu().numpy()
+        # each (num_features,) score vector is folded back into a square ROI x ROI matrix
+        scores = np.array([squareform(score) for score in scores])
+        return scores
+
+class LatentSpaceEncoding(ABC):
+    @abstractmethod
+    def ls_forward(self, data: Data) -> torch.Tensor:
+        """
+        return z
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def is_forward(self, data: Data) -> torch.Tensor:
+        """
+        return x
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_surface(self, z: torch.Tensor) -> torch.Tensor:
+        """
+        return y value for each z
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_input_surface(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        return y value for each x
+        """
+        raise NotImplementedError
+
+    @staticmethod
+    def _prepare_grid(
+        x: np.ndarray, pipeline: Pipeline, grid_points_dist: float = 0.1
+    ) -> Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray]:
+        min1, max1 = x[:, 0].min() - 1, x[:, 0].max() + 1
+        min2, max2 = x[:, 1].min() - 1, x[:, 1].max() + 1
+        x1grid = np.arange(min1, max1, grid_points_dist)
+        x2grid = np.arange(min2, max2, grid_points_dist)
+        xx, yy = np.meshgrid(x1grid, x2grid)
+        r1, r2 = xx.flatten(), yy.flatten()
+        r1, r2 = r1.reshape((len(r1), 1)), r2.reshape((len(r2), 1))
+        grid_xt = np.hstack((r1, r2))
+        emb_grid = pipeline.inverse_transform(grid_xt)
+        return (xx, yy), emb_grid
+
+    def get_latent_space_encoding(self, data: Data) -> Dict[str, np.ndarray]:
+        self.eval()
+        with torch.no_grad():
+            z = self.ls_forward(data).detach().numpy()
+            pipeline = make_pipeline(StandardScaler(), TruncatedSVD(2, random_state=0))
+            x = pipeline.fit_transform(z)
+
+            surface, emb_grid = self._prepare_grid(x, pipeline)
+            emb_grid = torch.tensor(emb_grid, dtype=data.x.dtype)
+            zz: np.ndarray = self.get_surface(emb_grid)[:, 1].detach().numpy()
+
+        xx: np.ndarray = surface[0]
+        yy: np.ndarray = surface[1]
+        zz = zz.reshape(xx.shape)
+        return {"x": x, "xx": xx, "yy": yy, "zz": zz}
+
+    def get_input_space_encoding(self, data: Data) -> Dict[str, np.ndarray]:
+        self.eval()
+        with torch.no_grad():
+            x = self.is_forward(data).detach().numpy()
+            pipeline = make_pipeline(StandardScaler(), PCA(2, random_state=0))
+            x = pipeline.fit_transform(x)
+
+            surface, emb_grid = self._prepare_grid(
+                x, pipeline, grid_points_dist=1.5
+            )
+            emb_grid = torch.tensor(emb_grid, dtype=data.x.dtype)
+            zz: np.ndarray = self.get_input_surface(emb_grid)[
+                :, 1
+            ].detach().numpy()
+
+        xx: np.ndarray = surface[0]
+        yy: np.ndarray = surface[1]
+        zz = zz.reshape(xx.shape)
+        return {"x": x, "xx": xx, "yy": yy, "zz": zz}
+
+
+class ModelBase(Module, SaliencyScoreForward):
     def __init__(self):
         super().__init__()
@@ -95,6 +209,20 @@ def get_optimizer(self, param: dict) -> Optimizer:
         )
         return optim
 
+    @classmethod
+    def load_from_state_dict(
+        cls,
+        path: str,
+        model_params: Dict[str, Any],
+        device: torch.device = torch.device("cpu"),
+    ) -> ModelBase:
+        state_dict: OrderedDict[str, torch.Tensor] = torch.load(
+            path, map_location=device
+        )
+        model = cls(**model_params)
+        model.load_state_dict(state_dict)
+        return model
+
     @abstractmethod
     def train_step(
         self,
diff --git a/src/single_stage_framework.py b/src/single_stage_framework.py
new file mode 100644
index 0000000..09f1304
--- /dev/null
+++ b/src/single_stage_framework.py
@@ -0,0 +1,119 @@
+import os
+import json
+import time +import yaml +import logging +import argparse +import pandas as pd +from itertools import product +from typing import Any, Dict + +from data import Dataset +from config import EXPERIMENT_DIR, FrameworkConfigParser +from utils import mkdir, on_error, seed_torch +from factory import SingleStageFrameworkFactory +from trainer import SingleStageFrameworkTrainer, TrainerParams + + +@on_error(dict(), True) +def experiment(trainer: SingleStageFrameworkTrainer): + trainer_results = trainer.run() + return trainer_results.to_dict() + + +def process(config: Dict[str, Any]): + seed_torch() + logging.info("CONFIG:\n{}".format(json.dumps(config, indent=4))) + + script_name = os.path.splitext(os.path.basename(__file__))[0] + experiment_name = "{}_{}_{}".format( + script_name, + int(time.strftime('%Y%m%d%H%M%S',time.localtime(time.time())) + str(time.time()).split('.')[-1][:3]), + os.getpid() + ) + + output_dir = ( + config.get("output_directory", EXPERIMENT_DIR) or EXPERIMENT_DIR + ) + output_dir = os.path.abspath(os.path.join(output_dir, experiment_name)) + + config_path = os.path.join( + output_dir, "{}.config.json".format(experiment_name), + ) + results_path = os.path.join(output_dir, "{}.csv".format(experiment_name),) + + mkdir(output_dir) + with open(config_path, "w") as f: + json.dump(config, f, indent=4, sort_keys=True) + + dataloader = SingleStageFrameworkFactory.load_dataloader( + model_name=config["model_name"], + dataloader_param={ + "dataset": Dataset(config["dataset"]), + "harmonize": config["harmonize"], + }, + ) + all_results = list() + + for seed, fold in product(config["seed"], config["fold"]): + trainer = SingleStageFrameworkTrainer( + dataloader=dataloader, + trainer_params=TrainerParams( + output_dir, + config.get("model_name"), + config.get("model_params", dict()), + config.get("optim_params", dict()), + config.get("hyperparameters", dict()), + Dataset(config["dataset"]), + seed, + fold, + config.get("ssl", False), + config.get("harmonize", False), + config.get("validation", False), + config.get("labeled_sites", None), + config.get("unlabeled_sites", None), + config.get("num_unlabeled", None), + config.get("device", -1), + config.get("verbose", False), + config.get("patience", float("inf")), + config.get("max_epoch", 1000), + config.get("save_model", False), + config.get("dataloader_num_process", 10), + ), + ) + result = experiment(trainer) + all_results.append(result) + + logging.info("RESULT:\n{}".format(json.dumps(result, indent=4))) + + df = pd.DataFrame(all_results).dropna(how="all") + if not df.empty: + df.to_csv(results_path, index=False) + + +def main(args): + with open(os.path.abspath(args.config), "r") as f: + configs: Dict[str, Any] = yaml.full_load(f) + + parser: FrameworkConfigParser = FrameworkConfigParser.parse(**configs) + for config in parser.generate(): + process(config) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + type=str, + default="config_templates/single_stage_framework/config.yml", + help="the path to the config file", + ) + args = parser.parse_args() + + logging.basicConfig( + level=logging.DEBUG, + format="[%(asctime)s] - %(filename)s: %(levelname)s: " + "%(funcName)s(): %(lineno)d:\t%(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + main(args) diff --git a/src/trainer.py b/src/trainer.py index 16bd7d0..55e4db5 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -1,27 +1,20 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Dict, 
List, Optional, Sequence, Union
+from typing import Any, Dict, Optional, Sequence, Union
 
 import os
 import json
 import time
 import copy
 import logging
-import pandas as pd
 
 import numpy as np
 import torch
 from torch_geometric.data import Data
 
-from data import Dataset
-from models.EDC_VAE import EDC_VAE
-from utils.misc import (
-    get_device,
-    get_pbar,
-    mkdir,
-    on_error,
-    seed_torch,
-    count_parameters,
-)
+from data import DataloaderBase, Dataset
+from factory import SingleStageFrameworkFactory
+from models import count_parameters
+from utils import get_device, get_pbar, mkdir, seed_torch
 
 
 @dataclass(frozen=True)
@@ -31,62 +24,73 @@ class TrainerParams:
     model_params: Dict[str, Any]
     optim_params: Dict[str, Any]
     hyperparameters: Dict[str, Any]
-    dataset_path: str
-    dataset_name: str
-    seeds: Sequence[int]
-    num_fold: int
+    dataset: Dataset
+    seed: int
+    fold: int
     ssl: bool
     harmonize: bool
+    validation: bool
     labeled_sites: Optional[Union[str, Sequence[str]]] = field(default=None)
+    unlabeled_sites: Optional[Union[str, Sequence[str]]] = field(default=None)
+    num_unlabeled: Optional[int] = field(default=None)
     device: int = field(default=-1)
     verbose: bool = field(default=False)
     patience: int = field(default=np.inf)
     max_epoch: int = field(default=1000)
     save_model: bool = field(default=False)
-    time_id: bool = field(init=False, default_factory=lambda: int(time.time()))
+    dataloader_num_process: int = 1
+    time_id: int = field(init=False, default_factory=lambda: int(time.strftime("%Y%m%d%H%M%S") + str(time.time()).split(".")[-1][:3]))
 
-    @property
-    def dataset(self) -> Dataset:
-        return Dataset(self.dataset_path, self.dataset_name, self.harmonize)
-
-    def to_dict(self, seed: int, fold: int) -> Dict[str, Any]:
+    def to_dict(self) -> Dict[str, Any]:
         return {
             "model_name": self.model_name,
             "model_params": str(self.model_params),
             "optim_params": str(self.optim_params),
             "hyperparameters": str(self.hyperparameters),
-            "dataset": self.dataset_name,
-            "seed": seed,
-            "fold": fold,
+            "dataset": self.dataset.value,
+            "seed": self.seed,
+            "fold": self.fold,
             "ssl": self.ssl,
             "harmonize": self.harmonize,
+            "validation": self.validation,
             "labeled_sites": self.labeled_sites,
+            "unlabeled_sites": self.unlabeled_sites,
             "device": self.device,
-            "epochs_log_path": self.epochs_log_path(seed, fold),
+            "epochs_log_path": self.epochs_log_path,
         }
 
-    def model_path(self, seed: int, fold: int) -> str:
+    @property
+    def model_path(self):
         return os.path.join(
             os.path.abspath(self.output_directory),
             "models",
             "{}_{}_{}_{}_{}.pt".format(
-                self.dataset_name, self.model_name, seed, fold, self.time_id,
+                self.dataset.value,
+                self.model_name,
+                self.seed,
+                self.fold,
+                self.time_id,
             ),
         )
 
-    def epochs_log_path(self, seed: int, fold: int) -> str:
+    @property
+    def epochs_log_path(self):
         return os.path.join(
             os.path.abspath(self.output_directory),
             "epochs_log",
             "{}_{}_{}_{}_{}.log".format(
-                self.dataset_name, self.model_name, seed, fold, self.time_id,
+                self.dataset.value,
+                self.model_name,
+                self.seed,
+                self.fold,
+                self.time_id,
             ),
         )
 
 
 @dataclass(frozen=True)
 class TrainerResults:
-    trainer_params_dict: Dict[str, Any]
+    trainer_params: TrainerParams
     num_labeled_train: int
     num_unlabeled_train: int
     num_valid: int
@@ -99,7 +103,7 @@ class TrainerResults:
 
     def to_dict(self) -> Dict[str, Any]:
         return {
-            **self.trainer_params_dict,
+            **self.trainer_params.to_dict(),
             "num_labeled_train": self.num_labeled_train,
             "num_unlabeled_train": self.num_unlabeled_train,
             "num_valid": 
self.num_valid, @@ -114,11 +118,23 @@ def to_dict(self) -> Dict[str, Any]: class Trainer(ABC): def __init__( - self, trainer_params: TrainerParams, + self, dataloader: DataloaderBase, trainer_params: TrainerParams, ): super().__init__() + if dataloader.dataset != trainer_params.dataset: + raise Exception( + "dataloader.dataset != trainer_params.dataset, {} != {}".format( + dataloader.dataset.value, trainer_params.dataset.value + ) + ) + if dataloader.harmonize != trainer_params.harmonize: + raise Exception( + "dataloader.harmonize != trainer_params.harmonize, {} != {}".format( + dataloader.harmonize, trainer_params.harmonize + ) + ) + self.dataloader = dataloader self.trainer_params = trainer_params - self.dataset = trainer_params.dataset self.__called = False def _set_called(self): @@ -149,24 +165,44 @@ def run(self): raise NotImplementedError -class EDC_VAE_Trainer(Trainer): - @on_error(None, True) - def _run_single_seed_fold( - self, seed: int, fold: int, data_dict: Dict[str, Union[Data, int]] - ) -> TrainerResults: +class SingleStageFrameworkTrainer(Trainer): + def run(self) -> TrainerResults: + self._set_called() + seed_torch() device = get_device(self.trainer_params.device) verbose = self.trainer_params.verbose start = time.time() + data_dict = self.dataloader.load_split_data( + seed=self.trainer_params.seed, + fold=self.trainer_params.fold, + ssl=self.trainer_params.ssl, + validation=self.trainer_params.validation, + labeled_sites=self.trainer_params.labeled_sites, + unlabeled_sites=self.trainer_params.unlabeled_sites, + num_unlabeled=self.trainer_params.num_unlabeled, + num_process=self.trainer_params.dataloader_num_process, + ) + num_labeled_train = data_dict.get("num_labeled_train", 0) num_unlabeled_train = data_dict.get("num_unlabeled_train", 0) - num_valid = data_dict.get("num_test", 0) - baseline_accuracy = self._get_baseline_accuracy(data_dict.get("test")) + if self.trainer_params.validation: + num_valid = data_dict.get("num_valid", 0) + baseline_accuracy = self._get_baseline_accuracy( + data_dict.get("valid") + ) + else: + num_valid = data_dict.get("num_test", 0) + baseline_accuracy = self._get_baseline_accuracy( + data_dict.get("test") + ) self.trainer_params.model_params["input_size"] = data_dict["input_size"] self.trainer_params.model_params["num_sites"] = data_dict["num_sites"] - model = EDC_VAE(**self.trainer_params.model_params) + model = SingleStageFrameworkFactory.load_model( + self.trainer_params.model_name, self.trainer_params.model_params + ) model_size = count_parameters(model) optimizer = model.get_optimizer(self.trainer_params.optim_params) @@ -181,7 +217,7 @@ def _run_single_seed_fold( save_model = self.trainer_params.save_model best_model_state_dict = None - epochs_log_path = self.trainer_params.epochs_log_path(seed, fold) + epochs_log_path = self.trainer_params.epochs_log_path mkdir(os.path.dirname(epochs_log_path)) with open(epochs_log_path, "w") as f: f.write("") @@ -196,11 +232,16 @@ def _run_single_seed_fold( optimizer, self.trainer_params.hyperparameters, ) - valid_metrics = model.test_step( - device, data_dict.get("test", None) - ) + if self.trainer_params.validation: + valid_metrics = model.test_step( + device, data_dict.get("valid", None) + ) + else: + valid_metrics = model.test_step( + device, data_dict.get("test", None) + ) except Exception as e: - logging.error(e) + logging.error(e, exc_info=True) with open(epochs_log_path, "a") as f: f.write( json.dumps( @@ -210,11 +251,6 @@ def _run_single_seed_fold( + "\n" ) - """ - save priority: - 1. 
accuracy
-            2. ce_loss
-        """
         save = valid_metrics["accuracy"] > best_metrics["accuracy"] or (
             valid_metrics["accuracy"] == best_metrics["accuracy"]
             and valid_metrics["ce_loss"] < best_metrics["ce_loss"]
@@ -237,7 +273,7 @@
 
         if save_model and best_model_state_dict is not None:
             try:
-                model_path = self.trainer_params.model_path(seed, fold)
+                model_path = self.trainer_params.model_path
                 mkdir(os.path.dirname(model_path))
                 torch.save(best_model_state_dict, model_path)
             except Exception as e:
@@ -248,7 +284,7 @@
         end = time.time()
 
         return TrainerResults(
-            trainer_params_dict=self.trainer_params.to_dict(seed, fold),
+            trainer_params=self.trainer_params,
             num_labeled_train=num_labeled_train,
             num_unlabeled_train=num_unlabeled_train,
             num_valid=num_valid,
@@ -259,35 +295,3 @@
             model_size=model_size,
             model_path=model_path,
         )
-
-    def run(self, results_csv_path: str) -> List[Dict[str, Any]]:
-        self._set_called()
-        all_seed_fold_results = list()
-
-        for seed in self.trainer_params.seeds:
-            data_dict_generator = self.dataset.load_split_data(
-                seed=seed,
-                num_fold=self.trainer_params.num_fold,
-                ssl=self.trainer_params.ssl,
-                labeled_sites=self.trainer_params.labeled_sites,
-            )
-
-            for fold, data_dict in enumerate(data_dict_generator):
-                fold_result: Optional[
-                    TrainerResults
-                ] = self._run_single_seed_fold(seed, fold, data_dict)
-                if fold_result is None:
-                    continue
-
-                all_seed_fold_results.append(fold_result.to_dict())
-                logging.info(
-                    "RESULT:\n{}".format(
-                        json.dumps(all_seed_fold_results[-1], indent=4)
-                    )
-                )
-
-        df = pd.DataFrame(all_seed_fold_results).dropna(how="all")
-        if not df.empty:
-            df.to_csv(results_csv_path, index=False)
-
-        return all_seed_fold_results
diff --git a/src/utils/misc.py b/src/utils/__init__.py
similarity index 74%
rename from src/utils/misc.py
rename to src/utils/__init__.py
index ccedc9c..b5f3ebd 100644
--- a/src/utils/misc.py
+++ b/src/utils/__init__.py
@@ -3,10 +3,10 @@
 import random
 import logging
 import traceback
+import subprocess
 
 import numpy as np
 import torch
 from tqdm import tqdm
-from torch.nn import Module
 
 
 def mkdir(path):
@@ -16,7 +16,7 @@
 
 def on_error(value, print_error_stack=True):
     """
-    returns a wrapper which catches error within a function
+    returns a wrapper which catches errors within a function and returns a default value on error
 
     value: the default value to be returned when error occured
     """
@@ -83,5 +83,19 @@
     return epoch_gen(max_epoch)
 
 
-def count_parameters(model: Module) -> int:
-    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+def get_gpu(num_polling: int = 20):
+    polled_memory_free = list()
+    for _ in range(num_polling):
+        command = "nvidia-smi --query-gpu=memory.free --format=csv"
+        memory_free_info = (
+            subprocess.check_output(command.split())
+            .decode("ascii")
+            .split("\n")[:-1][1:]
+        )
+        memory_free_values = [
+            int(x.split()[0]) for x in memory_free_info
+        ]
+        polled_memory_free.append(memory_free_values)
+    # average free memory across all polls, then pick the GPU with the most
+    gpu = np.argmax(np.mean(polled_memory_free, axis=0))
+    return get_device(gpu)
diff --git a/src/utils/data.py b/src/utils/data.py
new file mode 100644
index 0000000..0034735
--- /dev/null
+++ b/src/utils/data.py
@@ -0,0 +1,30 @@
+import torch
+import numpy as np
+
+from tqdm import tqdm
+from scipy.spatial import distance
+from sklearn.preprocessing import LabelEncoder
+from joblib import Parallel, delayed
+from torch_sparse.tensor import SparseTensor
+from 
torch_geometric.data import Data
+from torch_geometric.loader import DataLoader
+
+
+def corr_mx_flatten(X):
+    """
+    returns the flattened upper triangle of each correlation matrix in X
+
+    option 1:
+    X.shape == (num_sample, num_feature, num_feature)
+    X_flattened.shape == (num_sample, num_feature * (num_feature - 1) / 2)
+
+    option 2:
+    X.shape == (num_feature, num_feature)
+    X_flattened.shape == (num_feature * (num_feature - 1) / 2,)
+    """
+    upper_triangular_idx = np.triu_indices(X.shape[1], 1)
+    if len(X.shape) == 3:
+        X = X[:, upper_triangular_idx[0], upper_triangular_idx[1]]
+    else:
+        X = X[upper_triangular_idx[0], upper_triangular_idx[1]]
+    return X
diff --git a/src/utils/loss.py b/src/utils/loss.py
index 8619c6c..3fa3d5b 100644
--- a/src/utils/loss.py
+++ b/src/utils/loss.py
@@ -22,3 +22,15 @@ def kl_divergence_loss(
     kl = 0.5 * (var2.log() - var1.log() + (var1 + (mu1 - mu2) ** 2) / var2 - 1)
     kl = kl.sum(dim=1)
     return reduce(kl, reduction)
+
+
+def entropy_loss(
+    pred_y: torch.Tensor, reduction: str = "mean",
+) -> torch.Tensor:
+    eps = 1e-12
+    pred_y = pred_y.clamp_min(eps)
+    uni_dist = torch.ones(pred_y.size(0), device=pred_y.device) / pred_y.size(1)
+    max_entropy = -uni_dist.log()
+    entropy = -pred_y * pred_y.log()
+    entropy = entropy.sum(dim=1)
+    return reduce(max_entropy - entropy, reduction=reduction)
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
index fd01718..8fc42bf 100644
--- a/src/utils/metrics.py
+++ b/src/utils/metrics.py
@@ -3,6 +3,8 @@
 class ClassificationMetrics:
     """
+    https://www.labtestsonline.org.au/understanding/test-accuracy-and-reliability/how-reliable-is-pathology-testing
+
     Only for 2 class classification (diseased and controlled).
     - class 1 = diseased
     - class 0 = controlled
@@ -167,3 +169,86 @@ def f1_score(y_true, y_pred) -> torch.Tensor:
     fp = ((y_true == 0) & (y_pred == 1)).float().sum()
     fn = ((y_true == 1) & (y_pred == 0)).float().sum()
     return 2 * tp / (2 * tp + fp + fn)
+
+
+class CummulativeClassificationMetrics:
+    """
+    Only for 2 class classification (diseased and controlled).
+ - class 1 = diseased + - class 0 = controlled + """ + + def __init__(self): + self.reset() + + def reset(self): + self.tp = 0 + self.tn = 0 + self.fp = 0 + self.fn = 0 + self.total = 0 + + @staticmethod + def _check_y(y): + assert 1 <= y.ndim <= 2, "y.ndim must be 1 or 2, but given {}".format( + y.ndim + ) + if y.ndim == 2: + assert ( + y.size(1) == 2 + ), "dim 1 of y must have size 2, but given size {}".format( + y.size(1) + ) + y = y.argmax(dim=1) + elif y.ndim == 1: + assert torch.all((y == 0) | (y == 1)), "y can only contain 0 or 1" + return y + + def update_batch(self, y_true, y_pred): + y_pred = self._check_y(y_pred) + y_true = self._check_y(y_true) + self.tp += ((y_true == 1) & (y_pred == 1)).float().sum() + self.tn += ((y_true == 0) & (y_pred == 0)).float().sum() + self.fp += ((y_true == 0) & (y_pred == 1)).float().sum() + self.fn += ((y_true == 1) & (y_pred == 0)).float().sum() + self.total += y_true.size(0) + + @property + def accuracy(self) -> torch.Tensor: + return (self.tp + self.tn) / self.total + + @property + def tnr(self) -> torch.Tensor: + return self.tn / (self.tn + self.fp) + + @property + def tpr(self) -> torch.Tensor: + return self.tp / (self.tp + self.fn) + + @property + def ppv(self) -> torch.Tensor: + return self.tp / (self.tp + self.fp) + + @property + def npv(self) -> torch.Tensor: + return self.tn / (self.tn + self.fn) + + @property + def fpr(self) -> torch.Tensor: + return self.fp / (self.tn + self.fp) + + @property + def fnr(self) -> torch.Tensor: + return self.fn / (self.tp + self.fn) + + @property + def fdr(self) -> torch.Tensor: + return self.fp / (self.tp + self.fp) + + @property + def fomr(self) -> torch.Tensor: + return self.fn / (self.tn + self.fn) + + @property + def f1_score(self) -> torch.Tensor: + return 2 * self.tp / (2 * self.tp + self.fp + self.fn)
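
The new modules above compose into a standalone inference check as follows. This is a minimal sketch, not part of the diff: the checkpoint path is hypothetical, and the atlas size (264 ROIs, i.e. 264 * 263 / 2 = 34716 features), the number of sites and the hyperparameter values are assumptions; only `corr_mx_flatten`, `SingleStageFrameworkFactory`, `ModelBase.load_from_state_dict`, `test_step` and `CummulativeClassificationMetrics` are APIs defined in this diff.

    import numpy as np
    import torch
    from torch_geometric.data import Data

    from factory import SingleStageFrameworkFactory
    from utils.data import corr_mx_flatten
    from utils.metrics import CummulativeClassificationMetrics

    num_subjects, num_rois, num_sites = 8, 264, 3  # assumed sizes

    # stand-in for real functional connectivity matrices
    corr = np.random.rand(num_subjects, num_rois, num_rois).astype(np.float32)
    data = Data(
        x=torch.from_numpy(corr_mx_flatten(corr)),  # (8, 34716)
        y=torch.randint(0, 2, (num_subjects,)),
    )
    # SHRED variants also read age, gender and a site indicator `d` off the Data object
    data.age = torch.rand(num_subjects, 1)
    data.gender = torch.randint(0, 2, (num_subjects, 1)).float()
    data.d = torch.randint(0, num_sites, (num_subjects,))

    # model_params mirrors the model_params block of a config .yml (values assumed)
    model = SingleStageFrameworkFactory.get_model_class("SHRED").load_from_state_dict(
        "path/to/saved_model.pt",  # hypothetical checkpoint from a run with save_model: True
        dict(
            num_sites=num_sites,
            input_size=data.x.size(1),
            hidden_size=32,
            emb_size=16,
            clf_hidden_1=0,
            clf_hidden_2=0,
        ),
    )

    # one-shot evaluation via the model's own test_step
    metrics = model.test_step(torch.device("cpu"), data)
    print(metrics["accuracy"], metrics["f1"])

    # or accumulate counts batch by batch
    ccm = CummulativeClassificationMetrics()
    model.eval()
    with torch.no_grad():
        pred_y = model(data.x, data.age, data.gender, data.d)["y"]
    ccm.update_batch(data.y, pred_y)
    print(ccm.accuracy.item(), ccm.f1_score.item())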