This is the official code repository for the project:

**A Deep Representation Learning Model to Predict Response to Vagus Nerve Stimulation**

Hrishikesh Suresh MD, et al., George M Ibrahim MD PhD FRCSC

# VQ-VAE Based Predictive Modelling of VNS Response
- [Overview](#overview)
- [Demo](#demo)
- [Repository Structure](#repository-structure)
- [Setup & Installation](#setup--installation)
- [Data Preparation](#data-preparation)
- [Single Subject Inference](#single-subject-inference)
- [Training the VQ-VAE](#training-the-vq-vae)
- [Using the Trained VQ-VAE](#using-the-trained-vq-vae)
- [Training the SVM Classifier](#training-the-svm-classifier)
- [Evaluating & Using the SVM](#evaluating--using-the-svm)
- [Utilities & Preprocessing](#utilities--preprocessing)
- [Weights & Checkpoints](#weights--checkpoints)
## Overview

This repository provides a framework for predictive modeling of VNS (Vagus Nerve Stimulation) response using a Vector Quantized Variational Autoencoder (VQ-VAE) and an SVM classifier. It includes scripts for data preprocessing, model training, evaluation, and utilities for MRI intensity normalization.
## Demo

To run a demo of the inference pipeline, see the Jupyter notebook in the `demo/` folder. All instructions for running the notebook are inside the notebook file.
## Repository Structure

- `configs/` — YAML configuration files for VQ-VAE and SVM training/evaluation
- `datasets/` — Data loading and preprocessing modules
- `evaluation/` — Scripts for evaluating reconstructions and classifier performance
- `models/` — Model definitions (VQ-VAE, quantizers, etc.)
- `training/` — Training scripts for VQ-VAE and SVM
- `utils/` — Utility scripts (metrics, loss functions, normalization, etc.)
- `weights/` — Directory for saving trained model checkpoints
- `preprocessing/` — Scripts for MRI preprocessing
## Setup & Installation

- Clone the repository:

  ```bash
  git clone <repo_url>
  cd VQVNS
  ```

- Install all dependencies directly from `pyproject.toml` using `uv`:

  ```bash
  uv pip install -r pyproject.toml
  ```

> [!NOTE]
> `uv` is a fast Python package installer. Install it with `pip install uv` if it is not already available.
## Data Preparation

Preprocessing your MRI data is required before training or inference. The preprocessing pipeline is intended to be run locally.
- Use the `utils/preprocessing/generic_t1_preprocess.py` script to preprocess your MRI data.
- Specify the following when running the script:
  - Path to the input MRI images
  - Output directory for preprocessed files
  - Path to the MNI 1mm brain reference template (included with FSL v6)
- Example usage:

  ```bash
  python utils/preprocessing/generic_t1_preprocess.py \
      --input_dir path/to/raw_mri/ \
      --output_dir path/to/preprocessed_mri/ \
      --mni_template path/to/MNI152_T1_1mm_brain.nii.gz
  ```
- Ensure you have FSL v6 installed to access the required MNI template.
- FreeSurfer is also required for the resampling steps of the preprocessing pipeline.
> [!NOTE]
> All downstream training and inference scripts expect preprocessed and cropped MRI files as input. Do not use raw MRI data directly. The input MRI must be processed with the provided pipeline, and the cropped output should be used.
## Single Subject Inference

Download the weights file here. Please open an issue if the link does not work, as OneDrive links expire every 30 days.
> [!WARNING]
> We are unable to provide images to test against, as confidential patient data cannot be shared.
Example usage:

```bash
python infer_single_subject.py \
    --mri path/to/subject_cropped.nii.gz \
    --vqvae_checkpoint weights/vqvae.ckpt \
    --svm_weights weights/svm_weights.pkl \
    --features_to_keep weights/features_to_keep.npy
```

- The script will automatically pad or crop the MRI to 176×208×176 if needed.
- The output will be a prediction of VNS response for the subject.
> [!NOTE]
> Ensure your input MRI is preprocessed and cropped using the project's pipeline before running inference.
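For intuition, the automatic pad-or-crop step can be sketched in NumPy. This is a minimal illustration, not the repository's actual implementation: `pad_or_crop` is a hypothetical helper, assuming center cropping and symmetric zero-padding.

```python
import numpy as np

TARGET_SHAPE = (176, 208, 176)  # expected input shape for the model

def pad_or_crop(volume: np.ndarray, target=TARGET_SHAPE) -> np.ndarray:
    """Center-crop or zero-pad a 3D volume to the target shape."""
    out = volume
    # Crop each axis that is too large, keeping the center.
    for axis, (size, want) in enumerate(zip(out.shape, target)):
        if size > want:
            start = (size - want) // 2
            out = np.take(out, range(start, start + want), axis=axis)
    # Zero-pad each axis that is too small, splitting padding evenly.
    pads = []
    for size, want in zip(out.shape, target):
        total = max(want - size, 0)
        pads.append((total // 2, total - total // 2))
    return np.pad(out, pads, mode="constant")

vol = np.random.rand(180, 200, 176)  # too large in x, too small in y
print(pad_or_crop(vol).shape)        # (176, 208, 176)
```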
## Training the VQ-VAE

- Configure training: edit or use an existing config in `configs/vqvae/` (e.g., `vqvae_healthy_gradnorm_hpc.yaml`).
- Start training:

  ```bash
  python training/train_vqvae.py --config configs/vqvae/vqvae_healthy_gradnorm_hpc.yaml
  ```

- Checkpoint and intermediate file locations are specified in the config file.
## Using the Trained VQ-VAE

Download the weights file here.

- To evaluate or reconstruct using a trained VQ-VAE:

  ```bash
  python evaluation/eval_vqvae.py --config configs/vqvae/vqvae_healthy_gradnorm_hpc.yaml --checkpoint weights/vqvae.ckpt
  ```

- To evaluate reconstruction fidelity:

  ```bash
  python evaluation/eval_recon_fidelity.py --config configs/vqvae/vqvae_healthy_gradnorm_hpc.yaml --checkpoint weights/vqvae.ckpt
  ```
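Reconstruction fidelity is commonly summarized with metrics such as mean squared error or PSNR. The snippet below is an illustrative NumPy sketch of PSNR, not the repository's `eval_recon_fidelity.py` implementation, and it assumes intensities normalized to [0, 1].

```python
import numpy as np

def psnr(original: np.ndarray, reconstruction: np.ndarray, data_range: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB; higher means a more faithful reconstruction."""
    mse = np.mean((original - reconstruction) ** 2)
    if mse == 0:
        return float("inf")  # identical volumes
    return 10.0 * np.log10(data_range ** 2 / mse)

# Simulated volume and a slightly noisy "reconstruction"
rng = np.random.default_rng(0)
x = rng.random((176, 208, 176)).astype(np.float32)
x_hat = np.clip(x + rng.normal(0, 0.01, x.shape), 0, 1).astype(np.float32)
print(f"PSNR: {psnr(x, x_hat):.1f} dB")
```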
## Training the SVM Classifier

- Configure training: edit or use an existing config in `configs/classifier/` (e.g., `vns_svm.yaml`).
- Start training:

  ```bash
  python training/train_svm_classifier.py --config configs/classifier/vns_svm.yaml
  ```

- The classifier will be trained on VQ-VAE latent codes or other features, as specified in the config.
## Evaluating & Using the SVM

- To evaluate the SVM classifier:

  ```bash
  python evaluation/eval_svm.py --config configs/classifier/vns_svm.yaml
  ```

- Adjust the config file to point to the correct feature and label files as needed.
## Utilities & Preprocessing

- MRI intensity normalization: utilities are in `utils/packages/intensity-normalization/`. See the official documentation for usage.
- Other utilities: `utils/metrics.py`, `utils/loss_functions.py`, etc., provide supporting functions for training and evaluation.
## Weights & Checkpoints

- Trained model weights and checkpoints are saved in the `weights/` directory. Use these checkpoints for evaluation or further training.
- For SVM evaluation and prediction, ensure the following files are present in `weights/`:
  - `svm_weights.pkl`: trained SVM model weights
  - `features_to_keep.npy`: NumPy array specifying the features to use for SVM classification
- Update your config files to reference these files as needed for SVM evaluation and prediction.
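The role of `features_to_keep.npy` at prediction time can be sketched as follows. The file names match the repository, but the masking logic is an illustrative assumption, with simulated data standing in for real VQ-VAE latent features and the pickled SVM.

```python
import numpy as np

# Hypothetical stand-in for latent features extracted from the VQ-VAE
# (in practice these would come from encoding preprocessed MRIs).
rng = np.random.default_rng(0)
latent_features = rng.random((4, 512))  # 4 subjects, 512 latent features

# features_to_keep.npy would hold the indices of features selected during
# training; simulated here with a random subset of 64 indices.
features_to_keep = np.sort(rng.choice(512, size=64, replace=False))

# The same subset must be applied before calling the SVM's predict():
X = latent_features[:, features_to_keep]
print(X.shape)  # (4, 64)

# With the real artifacts, prediction would look roughly like:
#   import pickle
#   with open("weights/svm_weights.pkl", "rb") as f:
#       svm = pickle.load(f)
#   prediction = svm.predict(X)
```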
For further details, refer to the code and configuration files in each directory. For questions or issues, please open an issue in the repository.