Skip to content

Commit

Permalink
adding demo details
Browse files Browse the repository at this point in the history
  • Loading branch information
theofanis-insitro committed Oct 20, 2023
1 parent ea89fba commit 8a35f47
Show file tree
Hide file tree
Showing 3 changed files with 533 additions and 0 deletions.
13 changes: 13 additions & 0 deletions demo/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
## Model Training Demo

We provide the following basic demo as a starting point to training perturbation models using this repo.

To train a SAMS-VAE model on the Replogle dataset, run:
`python ../train.py --config sams_vae_replogle.yaml`

The example config file, `sams_vae_replogle.yaml`, has been annotated with additional explanation regarding config structure
and can be used as a starting point for setting up new training runs.

As the model trains, the training metrics and the model checkpoints can be visualized using `visualize_results.ipynb`

For examples of training model sweeps, see the sweep configs and instructions in `../paper/experiments/`
61 changes: 61 additions & 0 deletions demo/sams_vae_replogle.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Experiment hyperparameters
name: sams_vae_replogle
seed: 0
max_epochs: 1000

# WandB hyperparameters
# can set to True to save metrics and checkpoints to WandB
use_wandb: False
wandb_kwargs.name: sams_vae_replogle_filtered_example
wandb_kwargs.project: debug

# Data module class + hyperparameters
# current data module options are:
# - ReplogleDataModule
# - NormanOODCombinationDataModule
# - NormanDataEfficiency
# - NormanDataEfficiencyDataModule
# - SAMSVAESimulationDataModule
# data_module_kwargs are the arguments for the __init__
# function in the corresponding data module classes (see sams_vae/data/)
data_module: ReplogleDataModule
data_module_kwargs.batch_size: 512

# Model class + hyperparameters
# current model options are:
# - SAMSVAEModel
# - CPAVAEModel
# - SVAEPlusModel
# - ConditionalVAEModel
# model_kwargs are the arguments for the __init__ function
# of the corresponding model classes (see sams_vae/models/)
model: SAMSVAEModel
model_kwargs.n_latent: 100
model_kwargs.mask_prior_prob: 0.01
model_kwargs.embedding_prior_scale: 1
model_kwargs.likelihood_key: library_nb
model_kwargs.decoder_n_layers: 1
model_kwargs.decoder_n_hidden: 350

# Guide class + hyperparameters
# Like the models, these correspond to the guide classes
# in sams_vae/models/
# Note that the guide must match the model (will be under the same subdirectory,
# eg sams_vae/models/sams_vae/ or sams_vae/models/cpa_vae/)
guide: SAMSVAEMeanFieldNormalGuide
guide_kwargs.n_latent: 100
guide_kwargs.basal_encoder_n_layers: 1
guide_kwargs.basal_encoder_n_hidden: 350
guide_kwargs.basal_encoder_input_normalization: log_standardize

# Loss module class + hyperparameters
# from same subdirectory as model and guide
loss_module: SAMSVAE_ELBOLossModule

# Lightning module hyperparameters
lightning_module_kwargs.lr: 0.001
lightning_module_kwargs.n_particles: 5

# Predictor class + hyperparameters (used to evaluation)
# also much match model / guide, in same subdirectory
predictor: SAMSVAEPredictor
459 changes: 459 additions & 0 deletions demo/visualize_results.ipynb

Large diffs are not rendered by default.

0 comments on commit 8a35f47

Please sign in to comment.