Commit 8a35f47 (parent ea89fba): 3 changed files with 533 additions and 0 deletions.

## Model Training Demo

We provide the following basic demo as a starting point for training perturbation models using this repo.

To train a SAMS-VAE model on the Replogle dataset, run:
`python ../train.py --config sams_vae_replogle.yaml`

The example config file, `sams_vae_replogle.yaml`, has been annotated with additional explanation of the config structure and can be used as a starting point for setting up new training runs.

As the model trains, the training metrics and the model checkpoints can be visualized using `visualize_results.ipynb`.

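If you prefer to inspect metrics outside the notebook, the sketch below shows one way to plot a logged loss curve, assuming the run writes a PyTorch Lightning-style `metrics.csv`; the log path and column names here are hypothetical and not part of this repo.

```python
# Minimal sketch (not the repo's notebook): plot a logged validation loss.
# Assumes metrics were written to a Lightning-style CSV; adjust the path
# and column names to match your actual run outputs.
import pandas as pd
import matplotlib.pyplot as plt

metrics_path = "results/sams_vae_replogle/metrics.csv"  # hypothetical path
metrics = pd.read_csv(metrics_path)

# Keep only the rows where the validation loss was actually logged.
val = metrics.dropna(subset=["val/loss"])  # hypothetical column name
plt.plot(val["epoch"], val["val/loss"])
plt.xlabel("epoch")
plt.ylabel("validation loss")
plt.show()
```
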
For examples of training model sweeps, see the sweep configs and instructions in `../paper/experiments/`.
The second changed file is the example config, `sams_vae_replogle.yaml`:

```yaml
# Experiment hyperparameters
name: sams_vae_replogle
seed: 0
max_epochs: 1000

# WandB hyperparameters
# can be set to True to save metrics and checkpoints to WandB
use_wandb: False
wandb_kwargs.name: sams_vae_replogle_filtered_example
wandb_kwargs.project: debug

# Data module class + hyperparameters
# current data module options are:
# - ReplogleDataModule
# - NormanOODCombinationDataModule
# - NormanDataEfficiency
# - NormanDataEfficiencyDataModule
# - SAMSVAESimulationDataModule
# data_module_kwargs are the arguments for the __init__
# function in the corresponding data module classes (see sams_vae/data/)
data_module: ReplogleDataModule
data_module_kwargs.batch_size: 512

# Model class + hyperparameters
# current model options are:
# - SAMSVAEModel
# - CPAVAEModel
# - SVAEPlusModel
# - ConditionalVAEModel
# model_kwargs are the arguments for the __init__ function
# of the corresponding model classes (see sams_vae/models/)
model: SAMSVAEModel
model_kwargs.n_latent: 100
model_kwargs.mask_prior_prob: 0.01
model_kwargs.embedding_prior_scale: 1
model_kwargs.likelihood_key: library_nb
model_kwargs.decoder_n_layers: 1
model_kwargs.decoder_n_hidden: 350

# Guide class + hyperparameters
# Like the models, these correspond to the guide classes
# in sams_vae/models/
# Note that the guide must match the model (it will be under the same subdirectory,
# e.g. sams_vae/models/sams_vae/ or sams_vae/models/cpa_vae/)
guide: SAMSVAEMeanFieldNormalGuide
guide_kwargs.n_latent: 100
guide_kwargs.basal_encoder_n_layers: 1
guide_kwargs.basal_encoder_n_hidden: 350
guide_kwargs.basal_encoder_input_normalization: log_standardize

# Loss module class + hyperparameters
# from the same subdirectory as the model and guide
loss_module: SAMSVAE_ELBOLossModule

# Lightning module hyperparameters
lightning_module_kwargs.lr: 0.001
lightning_module_kwargs.n_particles: 5

# Predictor class + hyperparameters (used for evaluation)
# must also match the model / guide, in the same subdirectory
predictor: SAMSVAEPredictor
```
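
The `*_kwargs.*` entries in the config above are described as mapping directly onto the `__init__` arguments of the named classes. As a loose illustration of that convention (not the repo's actual config parsing in `train.py`), dot-notation keys sharing a prefix can be collected into a kwargs dict:

```python
# Illustrative sketch only: group flat dot-notation config keys into
# constructor kwargs. The repo's own config parsing may differ, and the
# class lookup at the end is hypothetical.
import yaml


def collect_kwargs(config: dict, prefix: str) -> dict:
    """Gather keys like 'model_kwargs.n_latent' into {'n_latent': 100}."""
    return {
        key[len(prefix) + 1:]: value
        for key, value in config.items()
        if key.startswith(prefix + ".")
    }


with open("sams_vae_replogle.yaml") as f:
    config = yaml.safe_load(f)

model_kwargs = collect_kwargs(config, "model_kwargs")
# model_kwargs == {"n_latent": 100, "mask_prior_prob": 0.01, ...}
# The named class would then be instantiated roughly as:
#   model_cls = getattr(models_module, config["model"])  # hypothetical lookup
#   model = model_cls(**model_kwargs)
```

Calling the same helper with the `data_module_kwargs`, `guide_kwargs`, or `lightning_module_kwargs` prefix covers the other sections of the config in the same way.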