This repository contains code to reproduce results from our paper on using sparse autoencoders (SAEs) to analyze and interpret the internal representations of text-to-image diffusion models, specifically SDXL Turbo.
```
|-- SAE/                      # Core sparse autoencoder implementation
|-- SDLens/                   # Tools for analyzing diffusion models
|   `-- hooked_sd_pipeline.py # Modified Stable Diffusion pipeline
|-- scripts/
|   |-- collect_latents_dataset.py # Generate training data
|   `-- train_sae.py               # Train SAE models
|-- utils/
|   `-- hooks.py              # Hook utility functions
|-- checkpoints/              # Pretrained SAE model checkpoints
|-- app.py                    # Demo application
|-- app.ipynb                 # Interactive notebook demo
|-- example.ipynb             # Usage examples
`-- requirements.txt          # Python dependencies
```
Install the Python dependencies:

```bash
pip install -r requirements.txt
```
You can try our Gradio demo application (app.ipynb) to browse and experiment with 20K+ features of our trained SAEs out of the box. The same notebook is also available on Google Colab.
1. Collect latent data from SDXL Turbo:

   ```bash
   python scripts/collect_latents_dataset.py --save_path={your_save_path}
   ```

2. Train sparse autoencoders:

   2.1. Set the path to the stored latents and the directory for checkpoints in SAE/config.json (a sketch of doing this programmatically follows this list).

   2.2. Run the training script:

   ```bash
   python scripts/train_sae.py
   ```
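For step 2.1, the snippet below is only a rough illustration of filling in the config from Python; the key names `latents_path` and `checkpoints_dir` are assumptions, so use whatever keys are already present in SAE/config.json.

```python
import json

# Hypothetical illustration of step 2.1: the key names below are assumptions,
# not necessarily the ones defined in SAE/config.json.
with open("SAE/config.json") as f:
    config = json.load(f)

config["latents_path"] = "/path/to/collected_latents"   # where step 1 saved the latents (assumed key)
config["checkpoints_dir"] = "/path/to/sae_checkpoints"  # where trained SAEs will be stored (assumed key)

with open("SAE/config.json", "w") as f:
    json.dump(config, f, indent=2)
```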
We provide pretrained SAE checkpoints for four key transformer blocks in SDXL Turbo's U-Net in the checkpoints folder. See example.ipynb for analysis examples and visualizations of learned features. More pretrained SAEs trained with different parameters are available through our HuggingFace repo.
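As a rough sketch of what such an analysis can look like, the snippet below loads the hooked pipeline and applies a pretrained SAE to cached activations. The class names, loader, cache method, and block identifier are assumptions made for illustration; example.ipynb shows the repository's actual API.

```python
# Sketch only: import paths, loaders, and the block name are assumptions;
# see example.ipynb for the actual API of this repository.
from SDLens import HookedStableDiffusionXLPipeline  # assumed import path
from SAE import SparseAutoencoder                   # assumed import path

pipe = HookedStableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo")
sae = SparseAutoencoder.load_from_checkpoint("checkpoints/<block_name>")  # assumed loader

# Run a prompt while caching activations at one transformer block (name assumed).
block = "down_blocks.2.attentions.1"
image, cache = pipe.run_with_cache(
    "a photo of a corgi wearing sunglasses",
    positions_to_cache=[block],
    num_inference_steps=1,
    guidance_scale=0.0,
)

# Encode the cached activations into sparse SAE features.
features = sae.encode(cache[block])
print(features.shape, (features != 0).float().mean())  # feature tensor and its sparsity
```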
If you find this code useful in your research, please cite our paper:
```bibtex
@misc{surkov2024unpackingsdxlturbointerpreting,
  title={Unpacking SDXL Turbo: Interpreting Text-to-Image Models with Sparse Autoencoders},
  author={Viacheslav Surkov and Chris Wendler and Mikhail Terekhov and Justin Deschenaux and Robert West and Caglar Gulcehre},
  year={2024},
  eprint={2410.22366},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2410.22366},
}
```
The SAE component was implemented based on the openai/sparse_autoencoder repository.
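For intuition, the following is a minimal conceptual sketch of a top-k sparse autoencoder forward pass in the style of openai/sparse_autoencoder; it is not the repository's actual code, and the dimensions and value of k are illustrative.

```python
import torch
import torch.nn as nn

class TopKSAE(nn.Module):
    """Conceptual sketch of a top-k sparse autoencoder (illustrative only)."""

    def __init__(self, d_model: int, n_latents: int, k: int):
        super().__init__()
        self.k = k
        self.pre_bias = nn.Parameter(torch.zeros(d_model))
        self.encoder = nn.Linear(d_model, n_latents, bias=True)
        self.decoder = nn.Linear(n_latents, d_model, bias=False)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # Project activations into an overcomplete latent space.
        latents = self.encoder(x - self.pre_bias)
        # Keep only the k largest activations per sample; zero out the rest.
        topk = torch.topk(latents, self.k, dim=-1)
        sparse = torch.zeros_like(latents)
        sparse.scatter_(-1, topk.indices, torch.relu(topk.values))
        return sparse

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reconstruct the input from the sparse code.
        return self.decoder(self.encode(x)) + self.pre_bias

# Training minimizes the reconstruction error on collected activations, e.g.:
#   loss = ((sae(latents_batch) - latents_batch) ** 2).mean()
```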