A framework for training diffusion models with stable, self-consistent scores near the data distribution.
We have implemented a minimal working version of our paper in a notebook, in both JAX and PyTorch. It can be used to create a figure similar to the one above, and to run molecular dynamics (MD) simulations with a diffusion model on the Müller-Brown potential. If you don't have a strong preference, use the JAX version: it is much faster and closer to the implementation found in this repository.
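At its core, the notebook's MD component simulates overdamped Langevin dynamics on the Müller-Brown surface. Below is a minimal, self-contained JAX sketch of that idea (an illustration, not the notebook's code): it uses the standard Müller-Brown parameters and jax.grad for the force, whereas in the actual setup the score/energy learned by the diffusion model takes the place of the analytic gradient.

import jax
import jax.numpy as jnp

# Standard Mueller-Brown parameters: a sum of four (an)isotropic Gaussians.
A  = jnp.array([-200.0, -100.0, -170.0, 15.0])
a  = jnp.array([-1.0, -1.0, -6.5, 0.7])
b  = jnp.array([0.0, 0.0, 11.0, 0.6])
c  = jnp.array([-10.0, -10.0, -6.5, 0.7])
x0 = jnp.array([1.0, 0.0, -0.5, -1.0])
y0 = jnp.array([0.0, 0.5, 1.5, 1.0])

def mueller_brown(pos):
    x, y = pos
    return jnp.sum(A * jnp.exp(a * (x - x0) ** 2
                               + b * (x - x0) * (y - y0)
                               + c * (y - y0) ** 2))

@jax.jit
def langevin_step(pos, key, dt=1e-5, kT=15.0):
    # Euler-Maruyama discretization of dx = -grad U(x) dt + sqrt(2 kT dt) dW.
    force = -jax.grad(mueller_brown)(pos)
    noise = jax.random.normal(key, pos.shape)
    return pos + dt * force + jnp.sqrt(2.0 * kT * dt) * noise

key = jax.random.PRNGKey(0)
pos = jnp.array([-0.5, 1.5])  # start near one of the minima
for _ in range(10_000):
    key, subkey = jax.random.split(key)
    pos = langevin_step(pos, subkey)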
All dependencies are managed with pixi, ensuring fully reproducible environments. To set up the environment, simply run:
pixi install --frozen
Prefer using your own toolchain, such as conda? Check out our pyproject.toml and install the dependencies with your preferred tool. In that case, replace pixi run with standard Python execution, for example:
python train.py
We use Hydra for configuration management. You can override any config via command-line arguments or config files.
Train on example toy systems using the provided configurations:
pixi run train dataset=double_well +architecture=mlp/small_score
pixi run train dataset=double_well_2d +architecture=mlp/small_score
Outputs will be saved to the outputs/ directory.
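A quick note on the override syntax used throughout this README (standard Hydra behavior, not specific to this repository): key=value overrides an existing config entry, +key=value adds a new one, and dots address nested fields. For example, with illustrative values:

# Select config groups, add (+) an entry, and override a nested value:
pixi run train dataset=double_well +architecture=mlp/small_score \
    training_schedule.epochs=50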
Below, we list the exact commands used to generate the main results from the paper. Note that this version of the repository uses slightly different names for some parameters than the paper.
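# Müller-Brown, potential model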
pixi run train dataset=mueller_brown +architecture=mlp/small_potential \
training_schedule.epochs=180 \
training_schedule.losses.0.time_weighting.midpoint=0.5
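# Müller-Brown, mixture of three models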
pixi run train dataset=mueller_brown +architecture=mlp/small_mixture_fair \
weighting_function=ranged \
training_schedule=vp_three_models \
training_schedule.epochs.0=30 \
training_schedule.epochs.1=30 \
training_schedule.epochs.2=120 \
training_schedule.losses.2.time_weighting.midpoint=0.05
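# Müller-Brown, potential model with FP regularization (beta=0.0005)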
pixi run train dataset=mueller_brown +architecture=mlp/small_potential \
training_schedule.epochs=180 \
training_schedule.losses.0.loss.div_est=2 \
training_schedule.losses.0.loss.beta=0.0005 \
+training_schedule.losses.0.loss.residual_fp=True \
training_schedule.losses.0.loss.partial_t_approx=True \
+training_schedule.losses.0.loss.single_gamma=True \
training_schedule.losses.0.time_weighting.midpoint=0.5
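# Müller-Brown, mixture of three models with FP regularization (beta=0.0005)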
pixi run train dataset=mueller_brown +architecture=mlp/small_mixture_fair \
weighting_function=ranged \
training_schedule=vp_three_models \
training_schedule.losses.2.loss.div_est=2 \
training_schedule.losses.2.loss.beta=0.0005 \
+training_schedule.losses.2.loss.residual_fp=True \
training_schedule.losses.2.loss.partial_t_approx=True \
+training_schedule.losses.2.loss.single_gamma=True \
training_schedule.epochs.0=30 \
training_schedule.epochs.1=30 \
training_schedule.epochs.2=120 \
training_schedule.losses.2.time_weighting.midpoint=0.05
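# ALDP, baseline potential model (aldp-baseline)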
pixi run train dataset=aldp \
dataset.limit_samples=50_000 \
dataset.validation=False \
+architecture=transformer/potential \
training_schedule.epochs.0=10000 \
training_schedule.BS=1024 \
checkpoint_options.save_interval_steps=1000 \
training_schedule.losses.0.loss.alpha=0 \
training_schedule.losses.0.loss.beta=0 \
+training_schedule/augment=random_rotations \
dataset.coarse_graining_level=full \
evaluation.num_parallel_langevin_samples=100 \
evaluation.langevin_dt=2e-3 \
evaluation.num_langevin_intermediate_steps=50 \
evaluation.num_langevin_samples=120000 \
evaluation.eval_t=1e-5 \
training_schedule.losses.0.loss.div_est=2 \
+training_schedule.losses.0.loss.residual_fp=True \
training_schedule.losses.0.loss.partial_t_approx=True \
+training_schedule.losses.0.loss.single_gamma=True \
wandb.enabled=True \
+wandb.name=aldp-baseline
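# ALDP, mixture of three models (aldp-mixture)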
pixi run train dataset=aldp \
dataset.limit_samples=50_000 \
dataset.validation=False \
training_schedule=vp_three_models \
+architecture=transformer/score_score_potential \
weighting_function=ranged \
training_schedule.epochs.0=1000 \
training_schedule.epochs.1=2000 \
training_schedule.epochs.2=7000 \
training_schedule.BS=1024 \
checkpoint_options.save_interval_steps=100000 \
training_schedule.losses.2.loss.alpha=0 \
+training_schedule/augment=random_rotations \
dataset.coarse_graining_level=full \
evaluation.num_parallel_langevin_samples=100 \
evaluation.langevin_dt=2e-3 \
evaluation.num_langevin_intermediate_steps=50 \
evaluation.num_langevin_samples=120000 \
evaluation.eval_t=1e-5 \
training_schedule.losses.2.loss.div_est=2 \
+training_schedule.losses.2.loss.residual_fp=True \
training_schedule.losses.2.loss.partial_t_approx=True \
+training_schedule.losses.2.loss.single_gamma=True \
training_schedule.losses.2.loss.beta=0 \
training_schedule.losses.2.time_weighting.midpoint=0.03 \
wandb.enabled=True \
+wandb.name=aldp-mixture
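# ALDP, potential model with FP regularization (aldp-fp-0.0005)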
pixi run train dataset=aldp \
dataset.limit_samples=50_000 \
dataset.validation=False \
+architecture=transformer/potential \
training_schedule.epochs.0=10000 \
training_schedule.BS=1024 \
checkpoint_options.save_interval_steps=1000 \
training_schedule.losses.0.loss.alpha=0 \
training_schedule.losses.0.loss.beta=0.0005 \
+training_schedule/augment=random_rotations \
dataset.coarse_graining_level=full \
evaluation.num_parallel_langevin_samples=100 \
evaluation.langevin_dt=2e-3 \
evaluation.num_langevin_intermediate_steps=50 \
evaluation.num_langevin_samples=120000 \
evaluation.eval_t=1e-5 \
training_schedule.losses.0.loss.div_est=2 \
+training_schedule.losses.0.loss.residual_fp=True \
training_schedule.losses.0.loss.partial_t_approx=True \
+training_schedule.losses.0.loss.single_gamma=True \
wandb.enabled=True \
+wandb.name=aldp-fp-0.0005
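# ALDP, mixture of three models with FP regularization (aldp-both-fp-0.0001)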
pixi run train dataset=aldp \
dataset.limit_samples=50_000 \
dataset.validation=False \
training_schedule=vp_three_models \
+architecture=transformer/score_score_potential \
weighting_function=ranged \
training_schedule.epochs.0=1000 \
training_schedule.epochs.1=2000 \
training_schedule.epochs.2=7000 \
training_schedule.BS=1024 \
checkpoint_options.save_interval_steps=100000 \
training_schedule.losses.2.loss.alpha=0 \
+training_schedule/augment=random_rotations \
dataset.coarse_graining_level=full \
evaluation.num_parallel_langevin_samples=100 \
evaluation.langevin_dt=2e-3 \
evaluation.num_langevin_intermediate_steps=50 \
evaluation.num_langevin_samples=120000 \
evaluation.eval_t=1e-5 \
training_schedule.losses.2.loss.div_est=2 \
+training_schedule.losses.2.loss.residual_fp=True \
training_schedule.losses.2.loss.partial_t_approx=True \
+training_schedule.losses.2.loss.single_gamma=True \
training_schedule.losses.2.loss.beta=0.0001 \
training_schedule.losses.2.time_weighting.midpoint=0.03 \
wandb.enabled=True \
+wandb.name=aldp-both-fp-0.0001
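# Minipeptides, baseline potential model (minipeptide-baseline)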
pixi run train dataset=minipeptides \
+architecture=transformer/large_potential \
training_schedule.epochs=120 \
training_schedule.BS=1024 \
checkpoint_options.save_interval_steps=10 \
+training_schedule/augment=random_rotations \
evaluation.inference_bs=1024 \
evaluation.fp_inference_bs=256 \
"evaluation.limit_inference_peptides=[KA, RP]" \
training_schedule.losses.0.loss.beta=0.000 \
training_schedule.losses.0.loss.div_est=2 \
+training_schedule.losses.0.loss.residual_fp=True \
training_schedule.losses.0.loss.partial_t_approx=True \
+training_schedule.losses.0.loss.single_gamma=True \
evaluation.eval_t=1e-5 \
wandb.enabled=True \
+wandb.name=minipeptide-baseline
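# Minipeptides, mixture of three models (minipeptide-mixture)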
pixi run train dataset=minipeptides \
training_schedule=vp_three_models \
+architecture=transformer/score_score_large_potential \
weighting_function=ranged \
training_schedule.epochs.0=10 \
training_schedule.epochs.1=20 \
training_schedule.epochs.2=100 \
training_schedule.BS=1024 \
checkpoint_options.save_interval_steps=10000 \
+training_schedule/augment=random_rotations \
evaluation.inference_bs=1024 \
evaluation.fp_inference_bs=256 \
"evaluation.limit_inference_peptides=[KA, RP]" \
evaluation.eval_t=1e-5 \
training_schedule.losses.2.loss.div_est=2 \
+training_schedule.losses.2.loss.residual_fp=True \
training_schedule.losses.2.loss.partial_t_approx=True \
+training_schedule.losses.2.loss.single_gamma=True \
training_schedule.losses.2.loss.beta=0.000 \
training_schedule.losses.2.time_weighting.midpoint=0.03 \
wandb.enabled=True \
+wandb.name=minipeptide-mixture
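# Minipeptides, potential model with FP regularization (minipeptide-fp-0.0005)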
pixi run train dataset=minipeptides \
+architecture=transformer/large_potential \
training_schedule.epochs=120 \
training_schedule.BS=1024 \
checkpoint_options.save_interval_steps=10 \
+training_schedule/augment=random_rotations \
evaluation.inference_bs=1024 \
evaluation.fp_inference_bs=256 \
"evaluation.limit_inference_peptides=[KA, RP]" \
training_schedule.losses.0.loss.beta=0.0005 \
training_schedule.losses.0.loss.div_est=2 \
+training_schedule.losses.0.loss.residual_fp=True \
training_schedule.losses.0.loss.partial_t_approx=True \
+training_schedule.losses.0.loss.single_gamma=True \
wandb.enabled=True \
+wandb.name=minipeptide-fp-0.0005
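# Minipeptides, mixture of three models with FP regularization (minipeptide-both-fp-0.0001)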
pixi run train dataset=minipeptides \
training_schedule=vp_three_models \
+architecture=transformer/score_score_large_potential \
weighting_function=ranged \
training_schedule.epochs.0=10 \
training_schedule.epochs.1=20 \
training_schedule.epochs.2=100 \
training_schedule.BS=1024 \
checkpoint_options.save_interval_steps=10000 \
+training_schedule/augment=random_rotations \
evaluation.inference_bs=1024 \
evaluation.fp_inference_bs=256 \
"evaluation.limit_inference_peptides=[KA, RP]" \
evaluation.eval_t=1e-5 \
training_schedule.losses.2.loss.div_est=2 \
+training_schedule.losses.2.loss.residual_fp=True \
training_schedule.losses.2.loss.partial_t_approx=True \
+training_schedule.losses.2.loss.single_gamma=True \
training_schedule.losses.2.loss.beta=0.0001 \
training_schedule.losses.2.time_weighting.midpoint=0.03 \
wandb.enabled=True \
+wandb.name=minipeptide-both-fp-0.0001
For the minipeptides, the training script can also be used for evaluation on the test set. For that, point the validation set to the test set (dataset.val_path) and load a trained model with load_dir:
pixi run train dataset=minipeptides \
+architecture=transformer/large_potential \
training_schedule.epochs=120 \
training_schedule.BS=1024 \
checkpoint_options.save_interval_steps=10 \
+training_schedule/augment=random_rotations \
evaluation.inference_bs=1024 \
evaluation.fp_inference_bs=256 \
training_schedule.losses.0.loss.beta=0.000 \
training_schedule.losses.0.loss.div_est=2 \
+training_schedule.losses.0.loss.residual_fp=True \
training_schedule.losses.0.loss.partial_t_approx=True \
+training_schedule.losses.0.loss.single_gamma=True \
evaluation.eval_t=1e-5 \
"dataset.val_path=./storage/minipeptides/test.npy" \
"evaluation.limit_inference_peptides=[KS, KG, AT, LW, KQ, NY, IM, TD, HT, NF, RL, ET, AC, RV, HP, AP]" \
evaluation.num_parallel_langevin_samples=10 \
evaluation.num_langevin_intermediate_steps=50 \
evaluation.num_langevin_samples=120000 \
evaluation.num_iid_samples=100000 \
load_dir=./outputs/minipeptides/DATE/TIME \
wandb.enabled=True \
+wandb.name=minipeptide-test
Feel free to open an issue if you encounter any problems or have questions.
If you found our work useful, please cite:
@article{plainer2025consistent,
author = {Plainer, Michael and Wu, Hao and Klein, Leon and G{\"u}nnemann, Stephan and No{\'e}, Frank},
title = {Consistent Sampling and Simulation: Molecular Dynamics with Energy-Based Diffusion Models},
eprint = {arXiv:2506.17139},
year = {2025},
}