HiDiNet is a state-of-the-art deep learning framework for modeling healthcare time series data, combining Variational Autoencoders (VAE), Stochastic Differential Equations (SDE), and Transformer architectures to predict health trajectories and survival probabilities.
HiDiNet integrates multiple advanced deep learning components:
- Variational Autoencoders (VAE) with Normalizing Flows for latent representation learning
- Stochastic Differential Equations (SDE) for modeling disease dynamics and temporal evolution
- Transformer Architecture for capturing long-range temporal dependencies
- Survival Analysis components for mortality risk prediction
- Clone the repository:

  ```bash
  git clone <repository-url>
  cd HiDiNet
  ```

- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```

This project uses the English Longitudinal Study of Ageing (ELSA) dataset. To obtain the data:
- Visit the ELSA website
- Register for data access
- Download the required ELSA data files
- Place the data files in the appropriate folders as described below
- Create Required Folders:

  ```bash
  mkdir Data
  mkdir Data/ELSA
  mkdir Parameters
  ```

- Prepare Data:

  - Place ELSA data files in the `Data/ELSA/` folder.
  - Run the data parser to generate processed files:

    ```bash
    cd Data_Parser
    ./run_parser.sh
    ```

    All generated files will be saved in the `Data/` folder.
- Train the Model:

  ```bash
  python train.py --job_id <unique_id> --batch_size 800 --niters 1000
  ```

- Generate Predictions:

  ```bash
  python predict.py --job_id <unique_id>
  ```

- Create Visualizations:

  ```bash
  cd Plotting_code
  python <plotting_script>.py
  ```
- Navigate to the Transformer Directory:

  ```bash
  cd transformer
  ```

- Create Required Folders:

  ```bash
  mkdir Analysis_Data_elsa
  mkdir Output_elsa
  mkdir Parameters_elsa
  ```

- Train the Transformer Model:

  ```bash
  python train.py --job_id <unique_id> --batch_size 800 --niters 1000
  ```

- Generate Predictions:

  ```bash
  python predict.py --job_id <unique_id>
  ```

- Create Visualizations:

  ```bash
  cd Plotting
  python <plotting_script>.py
  ```
- Mixed Precision Training: Optimized for faster training and reduced memory usage
- KL Divergence Scheduling: Gradual increase of the KL term's weight for training stability (see the sketch after this list)
- Monte Carlo Simulations: For uncertainty quantification in predictions
- Autoregressive Decoding: Sequential prediction in transformer decoder
- Memory-Efficient Processing: Optimized for large healthcare datasets
- Comprehensive Evaluation: C-index, Brier scores, and longitudinal RMSE metrics
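The KL scheduling mentioned above typically amounts to a warm-up on the weight of the KL term. A minimal sketch, assuming a linear anneal; the schedule shape and the `anneal_iters` value are illustrative, not taken from the repository:

```python
# Hypothetical linear KL warm-up; HiDiNet's actual schedule may differ
# (e.g. cyclical or sigmoid annealing).
def kl_weight(iteration: int, anneal_iters: int = 300, beta_max: float = 1.0) -> float:
    """Linearly ramp the weight of the KL term from 0 up to beta_max."""
    return beta_max * min(1.0, iteration / anneal_iters)

# Inside the training loop, the ELBO would then be weighted as:
#   loss = reconstruction_loss + kl_weight(it) * kl_divergence
```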
HiDiNet combines several advanced components, illustrated with a schematic sketch after this list:
- VAE with Normalizing Flows: Handles missing data and learns latent representations
- SDE Dynamics: Models temporal evolution of health variables
- Transformer Encoder: Captures complex temporal dependencies
- Transformer Decoder: Generates survival predictions autoregressively
- Memory Networks: Maintains patient-specific context
- Survival Analysis: Predicts mortality risk and survival probabilities
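To make the data flow concrete, here is a minimal, self-contained PyTorch sketch of how such components can fit together. Everything in it (the `HealthTrajectoryModel` name, layer sizes, and the Euler-Maruyama rollout) is an illustrative assumption, not the repository's actual implementation:

```python
import torch
import torch.nn as nn

class HealthTrajectoryModel(nn.Module):
    """Illustrative composition: VAE encoder -> latent SDE rollout ->
    transformer encoder -> survival head. All sizes are placeholders."""

    def __init__(self, x_dim=30, z_dim=20, steps=10, dt=1.0):
        super().__init__()
        self.steps, self.dt = steps, dt
        self.encoder = nn.Linear(x_dim, 2 * z_dim)            # -> (mu, logvar)
        self.drift = nn.Sequential(nn.Linear(z_dim, z_dim), nn.Tanh())
        self.diffusion = nn.Linear(z_dim, z_dim)              # diagonal noise scale
        layer = nn.TransformerEncoderLayer(d_model=z_dim, nhead=4, batch_first=True)
        self.temporal = nn.TransformerEncoder(layer, num_layers=2)
        self.hazard = nn.Linear(z_dim, 1)                     # per-step mortality hazard

    def forward(self, x0):
        mu, logvar = self.encoder(x0).chunk(2, dim=-1)
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()  # reparameterization trick
        traj = []
        for _ in range(self.steps):                           # Euler-Maruyama rollout
            dW = torch.randn_like(z) * self.dt ** 0.5
            z = z + self.drift(z) * self.dt + self.diffusion(z) * dW
            traj.append(z)
        h = self.temporal(torch.stack(traj, dim=1))           # (batch, steps, z_dim)
        hazards = torch.sigmoid(self.hazard(h)).squeeze(-1)
        survival = torch.cumprod(1 - hazards, dim=1)          # S(t) per time step
        return survival, mu, logvar

surv, mu, logvar = HealthTrajectoryModel()(torch.randn(8, 30))  # toy batch
print(surv.shape)                                               # torch.Size([8, 10])
```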
Key hyperparameters:

- `batch_size`: Training batch size (default: 800)
- `learning_rate`: Learning rate for optimization
- `niters`: Number of training iterations
- `corruption`: Data corruption rate for robustness
- `gamma_size`: Size of gamma parameters
- `z_size`: Latent space dimensionality
- `decoder_size`: Decoder network size
- `Nflows`: Number of normalizing flow layers
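A hedged sketch of how these options might be exposed via `argparse`; apart from `batch_size` (documented default 800), the default values below are placeholders, not values taken from the repository:

```python
import argparse

# Hypothetical CLI wiring for the hyperparameters above; only the
# batch_size default (800) is documented, the other defaults are placeholders.
parser = argparse.ArgumentParser(description="HiDiNet training options")
parser.add_argument("--job_id", type=str, required=True, help="unique run identifier")
parser.add_argument("--batch_size", type=int, default=800)
parser.add_argument("--learning_rate", type=float, default=1e-3)
parser.add_argument("--niters", type=int, default=1000)
parser.add_argument("--corruption", type=float, default=0.0)
parser.add_argument("--gamma_size", type=int, default=25)
parser.add_argument("--z_size", type=int, default=20)
parser.add_argument("--decoder_size", type=int, default=65)
parser.add_argument("--Nflows", type=int, default=3)
args = parser.parse_args()  # e.g. python train.py --job_id run1 --batch_size 800
```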
The training loop additionally uses:

- Gradient clipping for training stability
- Learning rate scheduling with ReduceLROnPlateau
- Weights & Biases integration for experiment tracking
- Mixed precision training with automatic mixed precision (AMP)
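As an illustration, a minimal PyTorch training step combining all four features; the model, objective, and loop length are stand-ins, and HiDiNet's actual loop will differ:

```python
import torch
from torch.cuda.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = torch.nn.Linear(10, 1).to(device)      # stand-in for the HiDiNet model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10)
scaler = GradScaler(enabled=use_amp)           # loss scaling for mixed precision

for it in range(100):
    x = torch.randn(800, 10, device=device)    # stand-in batch
    with autocast(enabled=use_amp):            # forward pass in float16 where safe
        loss = model(x).pow(2).mean()          # stand-in objective
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                 # unscale gradients before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()
    scheduler.step(loss.item())                # ReduceLROnPlateau tracks the metric
    # wandb.log({"loss": loss.item()})         # optional Weights & Biases logging
```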
The model is evaluated using multiple metrics:
- C-index: Concordance index for survival prediction accuracy
- Brier Score: Calibration measure for survival predictions
- Longitudinal RMSE: Root mean square error for trajectory predictions
- Survival Probability: Time-varying survival estimates
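For illustration, a minimal pairwise implementation of the concordance index (with a simple tie rule; production code would typically use a library such as lifelines or scikit-survival):

```python
def concordance_index(times, events, risk_scores):
    """C-index: fraction of comparable pairs in which the subject with the
    higher predicted risk is the one who dies first.
    times: observed times; events: 1 = death observed, 0 = censored."""
    concordant, comparable = 0.0, 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue                      # only an observed death anchors a pair
        for j in range(n):
            if times[i] < times[j]:       # j outlived i, so the pair is comparable
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0     # model ranked the pair correctly
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5     # ties in predicted risk count as half
    return concordant / comparable

# Toy check: earlier deaths get higher risk scores -> perfect concordance
print(concordance_index([2.0, 5.0, 7.0], [1, 1, 0], [0.9, 0.5, 0.1]))  # 1.0
```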
We welcome contributions! Please feel free to submit issues, feature requests, or pull requests.
For questions or support, please contact [yzhang@trinity.edu](mailto:yzhang@trinity.edu).