kuleshov-group/proseco

Learn from Your Mistakes: Self-Correcting Masked Diffusion Models


(Figure: graphical abstract)

This repository contains code for reproducing the experiments in the paper Learn from Your Mistakes: Self-Correcting Masked Diffusion Models.

We also share trained models on HuggingFace 🤗 and support integration with these models. See the "Using HuggingFace Models" section below.

Code Organization

  1. main.py: Routines for training (language models and classifiers)
  2. noise_schedule.py: Noise schedules
  3. diffusion.py: Forward/reverse diffusion
    • Absorbing state / uniform noise diffusion
    • AR
  4. dataloader.py: Dataloaders
  5. utils.py: LR scheduler, logging, fsspec handling
  6. models/: Denoising network architectures.
  7. configs/: Config files for datasets/denoising networks/noise schedules/LR schedules
  8. scripts/: Shell scripts for training/evaluation
  9. guidance_eval/: Guidance evaluation scripts
  10. llada/: Code to reproduce evaluation of LLaDA SFT models

ProSeCo Training

To enable ProSeCo training, set the corrector_training flag in config.yaml to True.

Additional parameters that can be tuned include the following:

corrector_training: True  # Enable ProSeCo corrector training
use_weighted_corrector_loss: True  # Whether to apply the α_t' / (1 - α_t) weight to the corrector loss
use_model_outputs_as_corrector_input: False  # Whether to pass denoiser outputs at all positions, or just masked ones
use_argmax_for_corrector: True   # Whether to use argmax sampling to create corrector inputs 
corrector_training_start_step: 0  # What (global) step to start applying corrector loss
mdlm_loss_weight: 1.0  # Additional optional weighting for MDLM loss
corrector_loss_weight: 1.0  # Additional optional weighting for corrector loss
corrector_loss_errors_upweighted: False  # Whether to prioritize mistakes in corrector loss (see Appendix C.3 for details)
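To make the weighting concrete, here is a minimal sketch of how the α_t' / (1 - α_t) factor and the loss-weight parameters above might combine. This is an illustrative assumption, not the repository's actual implementation: the function name `corrector_loss` and its signature are hypothetical, and sign conventions for α_t' depend on the noise schedule.

```python
import torch

def corrector_loss(logits, targets, alpha_t, alpha_t_prime,
                   use_weighted_corrector_loss=True,
                   corrector_loss_weight=1.0):
    """Hypothetical sketch of the weighted corrector loss.

    logits:  (batch, seq_len, vocab) corrector predictions
    targets: (batch, seq_len) clean token ids
    alpha_t, alpha_t_prime: noise-schedule value and its derivative at t
    """
    # Token-level cross-entropy between corrector predictions and clean tokens.
    ce = torch.nn.functional.cross_entropy(
        logits.transpose(1, 2), targets, reduction="none")  # (batch, seq_len)
    loss = ce.mean()
    if use_weighted_corrector_loss:
        # The schedule weight from `use_weighted_corrector_loss` in the config.
        loss = (alpha_t_prime / (1.0 - alpha_t)) * loss
    # Optional extra weighting, mirroring `corrector_loss_weight`.
    return corrector_loss_weight * loss
```

In a real training loop this term would be added to the (optionally `mdlm_loss_weight`-scaled) MDLM loss.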

ProSeCo Sampling

Below we detail the parameters one can use when applying corrector steps during inference. These parameters can be found under sampling in config.yaml:

corrector_prior_is_argmax: True  # Use argmax from denoiser as corrector input
corrector_sampling: 'argmax'  # Sampling scheme for corrector steps
corrector_every_n_steps: 1  # Frequency for applying corrector loops
corrector_steps: 0  # Max number of corrector steps per loop
corrector_start_iter: 0  # Can be used to delay when corrector steps are eligible to start
corrector_top_k: 0  # Used in conjunction with `select_top_k` strategy for corrector sampling
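The scheduling parameters above (`corrector_every_n_steps`, `corrector_steps`, `corrector_start_iter`) can be sketched as a simple eligibility rule over denoising steps. This is a hypothetical illustration of the schedule only; the function name `sampling_schedule` and its return format are assumptions, and it does not model `corrector_sampling` or `corrector_top_k`.

```python
def sampling_schedule(num_steps, corrector_every_n_steps=1,
                      corrector_steps=0, corrector_start_iter=0):
    """Hypothetical sketch of when corrector loops fire during sampling.

    Returns a list of (denoise_step, num_corrector_steps) pairs.
    """
    schedule = []
    for i in range(num_steps):
        n_corr = 0
        # Corrector loops become eligible at corrector_start_iter and then
        # fire every corrector_every_n_steps denoising steps, running up to
        # corrector_steps corrector updates each time.
        if i >= corrector_start_iter and i % corrector_every_n_steps == 0:
            n_corr = corrector_steps
        schedule.append((i, n_corr))
    return schedule
```

Note that with the default `corrector_steps: 0`, no corrector steps are applied at all.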

LLaDA experiments

We also provide code for reproducing the evaluations with our LLaDA-SFT model in the llada directory. See the README file there for more details, and download the model from HuggingFace.

Getting started in this repository

To get started, create a conda environment containing the required dependencies.

conda env create -f requirements.yaml
conda activate discdiff

Create the following directories to store saved models and slurm logs:

mkdir outputs
mkdir watch_folder

We rely on wandb integration to log experiments and eval curves.

Reproducing Experiments

Throughout, the main entry point for running experiments is the main.py script. We also provide sample slurm scripts for launching pre-training and evaluation experiments in the scripts/ directory.

Using HuggingFace Models

We provide pre-trained models on HuggingFace 🤗, including kuleshov-group/proseco-llada-sft and kuleshov-group/proseco-owt.

Please see the README pages for these models on HuggingFace or our paper for more details about the training of these models.

To use these models, you can load them using the HuggingFace API, e.g.,

from transformers import AutoModelForCausalLM, AutoModelForMaskedLM

# LLaDA-SFT model
model = AutoModelForCausalLM.from_pretrained("kuleshov-group/proseco-llada-sft")
# OWT (OpenWebText) masked diffusion model
model = AutoModelForMaskedLM.from_pretrained("kuleshov-group/proseco-owt")

To use these models in our repository, set the following config parameters:

backbone="hf_dit"
model="hf"
model.pretrained_model_name_or_path="kuleshov-group/proseco-owt"
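If the repository follows Hydra-style dotted command-line overrides (an assumption based on the config keys above; check main.py for the actual interface), the same parameters could be passed directly when launching:

```shell
python main.py \
  backbone=hf_dit \
  model=hf \
  model.pretrained_model_name_or_path="kuleshov-group/proseco-owt"
```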

Acknowledgements

This repository was built off of UDLM and MDLM.

Citation

@article{schiff2026learn,
  title={Learn from Your Mistakes: Self-Correcting Masked Diffusion Models},
  author={Schiff, Yair and Belhasin, Omer and Uziel, Roy and Wang, Guanghan and Arriola, Marianne and Turok, Gilad and Elad, Michael and Kuleshov, Volodymyr},
  journal={arXiv preprint arXiv:2602.11590},
  year={2026}
}
