Comparison of different transfer learning Gaussian process methods on synthetic data. This comparison was conducted for the paper Transfer Learning Bayesian Optimization to Design Competitor DNA Molecules for Use in Diagnostic Assays [1].
To clone this repo, use the command:
git clone https://github.com/RSedgwick/TLGPs.git
The code is written in Python 3.6. To install the required packages, make sure conda is installed, then run the following command in the root directory of the project:
conda env create -f environment.yml
In this repo, we run experiments to compare four different transfer learning methods:
- Independent multioutput Gaussian process (MOGP) [2]
- Average Gaussian process where all data is considered to be from the same surface (AvgGP)
- Linear Model of Coregionalisation (LMC) [2]
- Latent Variable Multioutput Gaussian Process (LVMOGP) [3]
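To make the differences between the four models concrete, the sketch below writes down the covariance each one places between surface i at input x and surface j at input x'. It is a minimal pure-Python illustration, not the GPflow code used in this repo. The squared-exponential kernel, the single shared input kernel for the LMC (the intrinsic coregionalisation special case) and the names `W`, `kappa` and `latents` are assumptions made for illustration.

```python
import math

def rbf(x, x2, lengthscale=1.0, variance=1.0):
    """Squared-exponential kernel on scalar inputs."""
    return variance * math.exp(-0.5 * (x - x2) ** 2 / lengthscale ** 2)

def mogp_cov(x, i, x2, j):
    # Independent MOGP: different surfaces share no information.
    return rbf(x, x2) if i == j else 0.0

def avggp_cov(x, i, x2, j):
    # AvgGP: all data is treated as one surface, so the surface index is ignored.
    return rbf(x, x2)

def lmc_coregion_matrix(W, kappa):
    """B = W W^T + diag(kappa), with W a P x Q list of lists and kappa a length-P list."""
    P = len(W)
    return [[sum(W[i][q] * W[j][q] for q in range(len(W[i])))
             + (kappa[i] if i == j else 0.0)
             for j in range(P)] for i in range(P)]

def lmc_cov(x, i, x2, j, B):
    # LMC with one shared input kernel: surfaces are related through a fixed matrix B.
    return B[i][j] * rbf(x, x2)

def lvmogp_cov(x, i, x2, j, latents, latent_lengthscale=1.0):
    # LVMOGP: each surface i has a latent vector h_i; the covariance is the product
    # of an input kernel and a (non-linear) kernel on the latent space.
    sq_dist = sum((a - b) ** 2 for a, b in zip(latents[i], latents[j]))
    return rbf(x, x2) * math.exp(-0.5 * sq_dist / latent_lengthscale ** 2)
```

In this sketch, setting `W` to zeros and `kappa` to ones gives B = I, at which point the LMC covariance coincides with the independent MOGP; this is the relationship the `lmc_setting_W_and_kappa.ipynb` notebook is described as demonstrating.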
We run experiments for all these methods on three test function scenarios, each chosen to reflect a situation in which we expect particular models to perform well:
- Unrelated test functions
  - If there is no negative transfer, we would expect the MOGP, LMC and LVMOGP to perform similarly here
- Linearly-related test functions
  - We expect the LMC and LVMOGP to outperform the MOGP here
- Non-linearly-related test functions
  - We expect the LVMOGP to outperform the MOGP and LMC here

We expect the MOGP, LMC and LVMOGP to outperform the AvgGP on all test scenarios.
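The three scenarios differ only in how the true surfaces relate to one another. The sketch below generates toy one-dimensional surfaces of each kind using random Fourier features as approximate GP draws. It is a hypothetical illustration of the idea, not the generator in `test_functions.py`; the function names, feature counts and latent ranges are assumptions.

```python
import math
import random

def rff_function(rng, input_dim, n_features=50, lengthscale=1.0):
    """Approximate draw from a zero-mean GP with an RBF kernel via random Fourier features."""
    freqs = [[rng.gauss(0.0, 1.0 / lengthscale) for _ in range(input_dim)]
             for _ in range(n_features)]
    phases = [rng.uniform(0.0, 2.0 * math.pi) for _ in range(n_features)]
    weights = [rng.gauss(0.0, 1.0) for _ in range(n_features)]
    scale = math.sqrt(2.0 / n_features)

    def f(*z):
        return scale * sum(a * math.cos(sum(w_d * z_d for w_d, z_d in zip(w, z)) + b)
                           for a, w, b in zip(weights, freqs, phases))
    return f

def make_surfaces(scenario, n_surfaces, n_latent=2, seed=0):
    """Return one callable f_i(x) per surface for a (hypothetical) test scenario."""
    rng = random.Random(seed)
    if scenario == "unrelated":
        # Every surface is an independent random function.
        return [rff_function(rng, 1) for _ in range(n_surfaces)]
    if scenario == "linear":
        # Every surface is a fixed linear combination of a few shared latent functions.
        latent_fns = [rff_function(rng, 1) for _ in range(n_latent)]
        W = [[rng.gauss(0.0, 1.0) for _ in range(n_latent)] for _ in range(n_surfaces)]
        return [lambda x, w=w: sum(w_q * g(x) for w_q, g in zip(w, latent_fns)) for w in W]
    if scenario == "nonlinear":
        # One shared function of (x, h); each surface is a slice at its own latent value h_i.
        g = rff_function(rng, 2)
        hs = [rng.uniform(-2.0, 2.0) for _ in range(n_surfaces)]
        return [lambda x, h=h: g(x, h) for h in hs]
    raise ValueError(f"unknown scenario: {scenario}")
```

The linear scenario matches the LMC's modelling assumption, while the non-linear one matches the LVMOGP's (a shared function of the input and a per-surface latent variable), which is why those are the cases where each model is expected to do best.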
Below is a plot of the mean root mean squared error (RMSE) and negative log predictive density (NLPD) of each method on the three test function scenarios for one seed. For each scenario, 10 new surfaces are learnt and 5 different random data sets are used. This plot appears as Figure 4 of Sedgwick et al. [1].
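For reference, both metrics can be computed from per-point Gaussian predictive means and variances. This is a minimal sketch assuming independent Gaussian predictions at each test point; it does not reproduce the metric functions in `utils.py`, and whether the predictive variance includes observation noise is left to the caller.

```python
import math

def rmse(y_true, mu):
    """Root mean squared error between targets and predictive means."""
    n = len(y_true)
    return math.sqrt(sum((y - m) ** 2 for y, m in zip(y_true, mu)) / n)

def nlpd(y_true, mu, var):
    """Mean negative log predictive density under independent Gaussian predictions."""
    n = len(y_true)
    total = 0.0
    for y, m, v in zip(y_true, mu, var):
        # -log N(y | m, v) = 0.5 * log(2 * pi * v) + (y - m)^2 / (2 * v)
        total += 0.5 * math.log(2.0 * math.pi * v) + 0.5 * (y - m) ** 2 / v
    return total / n
```

Lower is better for both. Unlike RMSE, NLPD also penalises over- and under-confident predictive variances, so the two together assess both accuracy and uncertainty calibration.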
Below is a gif of each model's predictions for one test scenario, with one random data point added at each frame.
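Adding one random point at a time produces nested training sets: each frame's model sees everything the previous frame saw plus one new observation. A minimal sketch of that sampling, assuming a fixed pool of candidate points and a per-run data seed (the function name is hypothetical):

```python
import random

def nested_training_sets(pool, max_points, data_seed=0):
    """Yield training sets of size 1..max_points, each extending the previous one by a random point."""
    rng = random.Random(data_seed)
    order = list(pool)
    rng.shuffle(order)  # a single random ordering fixes the whole sequence
    for n in range(1, max_points + 1):
        yield order[:n]
```

Because the ordering is fixed by `data_seed`, changing the seed gives a different random data set while each learning curve stays internally consistent.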
models
- initializations.py - Contains the initialization functions for the different transfer learning methods
- lvmogp.py - Contains the code for the latent variable multi-output Gaussian process model, adapted from the GPflow Bayesian GPLVM code [4]
- test_functions.py - Code for generating the synthetic data
utils
- utils.py - Useful functions for initialising models, fitting them, getting performance metrics and saving results
- plotting_utils.py - Useful functions for plotting functions and performance metrics
- analysis_utils.py - Useful functions for analysing the results
notebooks
- Comparing_LVMOGP_Prediction_Methods.ipynb - Notebook for comparing different methods for prediction using the LVMOGP
- fitting_all_models.ipynb - Notebook for fitting all the models and saving plots of the predictions
- lmc_fitting_and_intialisation.ipynb - Notebook demonstrating the fitting of the LMC
- lmc_setting_W_and_kappa.ipynb - Notebook demonstrating how the LMC can recreate the independent MOGP
analysis
- plots - Various plots that have been generated
- model_comparison.ipynb - Notebook comparing the RMSE and NLPD of the different models across the many learning curve runs
- plot_predictions.ipynb - Plots the predictions of the models for a given seed, data seed and number of training points, and the log marginal likelihood of the different initialisations at each number of training points for all models
- plot_predictions.py - Plots the predictions for all runs, to be made into gifs
- animating_plots.ipynb - Notebook for making gifs out of the predictions
experiments
- learning_curves.py - Script for fitting each of the models, analysing the results and saving them to a file
- The .pbs scripts can be used to run this script many times on a cluster, for different seeds and numbers of training points
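When many (seed, number of training points) combinations are run as a cluster job array, one common pattern is to map the array index to a single combination. The sketch below assumes a PBS Pro-style `PBS_ARRAY_INDEX` environment variable and hypothetical seed and training-size grids; it is not taken from the repository's `.pbs` scripts or `learning_curves.py`.

```python
import itertools
import os

# Hypothetical experiment grid; the values used in the paper may differ.
SEEDS = list(range(5))
N_TRAIN_POINTS = [5, 10, 20, 30]

def job_config(array_index):
    """Map a 0-based job-array index to one (seed, n_train) combination."""
    grid = list(itertools.product(SEEDS, N_TRAIN_POINTS))
    if not 0 <= array_index < len(grid):
        raise IndexError(f"array index {array_index} outside 0..{len(grid) - 1}")
    return grid[array_index]

if __name__ == "__main__":
    # PBS Pro exposes the sub-job index as PBS_ARRAY_INDEX; default to 0 for local runs.
    idx = int(os.environ.get("PBS_ARRAY_INDEX", "0"))
    seed, n_train = job_config(idx)
    print(f"seed={seed} n_train={n_train}")
```

With this grid, a PBS Pro array job submitted with `qsub -J 0-19` would cover every combination exactly once.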
When using the code in this repository, please reference our journal paper:
@article{sedgwick_transfer_2023,
title={Transfer Learning Bayesian Optimization for Competitor DNA Molecule Design for Use in Diagnostic Assays},
author={Sedgwick, Ruby and Goertz, John and Stevens, Molly and Misener, Ruth and van der Wilk, Mark},
journal={},
volume={},
pages={},
year={},
publisher={}
}
This work was supported by the UKRI CDT in AI for Healthcare, Grant No. EP/S023283/1.
