This repository contains a PyTorch implementation of a Variational Autoencoder (VAE) for the MNIST dataset. The VAE is a generative model that learns to encode input data into a latent space and decode it back to the original space.
To run this code, you need to have Python and PyTorch installed. You can install the required packages using pip:
```bash
pip install torch torchvision
```
- Clone this repository:

```bash
git clone https://github.com/your-username/vae-mnist.git
cd vae-mnist
```
- Run the script to train the VAE:

```bash
python train_vae.py
```
The VAE consists of two main parts: an encoder and a decoder.
The encoder maps the input data to the parameters of a latent distribution. It consists of:
- A fully connected layer that transforms the input to a hidden dimension.
- Two fully connected layers that output the mean (`mu`) and log variance (`logvar`) of the latent distribution.
The decoder reconstructs the input data from the latent space. It consists of:
- A fully connected layer that transforms the latent space to a hidden dimension.
- A fully connected layer that outputs the reconstructed data.
The VAE combines the encoder and decoder. It also includes a reparameterization step, which samples from the latent distribution while keeping the sampling differentiable with respect to `mu` and `logvar`.
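As a concrete illustration, here is a minimal sketch of such a model. The class name, layer names (`fc1`, `fc_mu`, `fc_logvar`, `fc2`, `fc3`), and dimensions are illustrative assumptions, not necessarily what `train_vae.py` uses:

```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        # Encoder: input -> hidden -> (mu, logvar)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder: latent -> hidden -> reconstruction
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = torch.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        # Sample z = mu + sigma * eps with eps ~ N(0, I), which keeps the
        # sampling step differentiable with respect to mu and logvar.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = torch.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h))  # pixel values in [0, 1]

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
```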
The training loop involves the following steps (a code sketch follows the list):
- Load the MNIST dataset.
- Initialize the VAE model and the optimizer.
- For each epoch:
  - For each batch of data:
    - Flatten the input data.
    - Perform a forward pass through the model.
    - Compute the loss (Binary Cross Entropy + Kullback-Leibler Divergence).
    - Perform backpropagation and update the model parameters.
  - Print the average loss for the epoch.
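In code, the loop might look roughly like this. It is a sketch using the `VAE` class from the earlier snippet; the batch size, learning rate, and epoch count are illustrative and may differ from `train_vae.py`. The KL term uses the standard closed form for a Gaussian posterior against a standard normal prior:

```python
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def vae_loss(recon_x, x, mu, logvar):
    # Reconstruction term: Bernoulli log-likelihood over pixels.
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between N(mu, sigma^2) and N(0, I), in closed form:
    # -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

dataset = datasets.MNIST('./data', train=True, download=True,
                         transform=transforms.ToTensor())
loader = DataLoader(dataset, batch_size=128, shuffle=True)

model = VAE()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    total_loss = 0.0
    for x, _ in loader:
        x = x.view(x.size(0), -1)  # flatten 28x28 images to 784
        recon, mu, logvar = model(x)
        loss = vae_loss(recon, x, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch + 1}: avg loss {total_loss / len(dataset):.4f}')
```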
After training, the model's state dictionary is saved to a file named `vae.pth`.
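Saving and later restoring the weights follows the standard PyTorch pattern (assuming the `VAE` sketch above):

```python
# Save the trained weights.
torch.save(model.state_dict(), 'vae.pth')

# Later, restore them into a freshly constructed model.
model = VAE()
model.load_state_dict(torch.load('vae.pth'))
model.eval()
```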
- Kingma, Diederik P., and Max Welling. "Auto-encoding variational Bayes." arXiv preprint arXiv:1312.6114 (2013).
- PyTorch Documentation
Feel free to contribute to this repository by opening issues or submitting pull requests.