Skip to content

jamesdhope/VAE-examples

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

4 Commits
 
 
 
 
 
 
 
 

Repository files navigation

Variational Autoencoder (VAE) for MNIST

This repository contains a PyTorch implementation of a Variational Autoencoder (VAE) for the MNIST dataset. The VAE is a generative model that learns to encode input data into a latent space and decode it back to the original space.

Table of Contents

Installation

To run this code, you need to have Python and PyTorch installed. You can install the required packages using pip:

pip install torch torchvision

Usage

  1. Clone this repository:

    git clone https://github.com/your-username/vae-mnist.git
    cd vae-mnist
  2. Run the script to train the VAE:

    python train_vae.py

Model Architecture

The VAE consists of two main parts: an encoder and a decoder.

Encoder

The encoder maps the input data to a latent space. It consists of:

  • A fully connected layer that transforms the input to a hidden dimension.
  • Two fully connected layers that output the mean (mu) and log variance (logvar) of the latent space.

Decoder

The decoder reconstructs the input data from the latent space. It consists of:

  • A fully connected layer that transforms the latent space to a hidden dimension.
  • A fully connected layer that outputs the reconstructed data.

VAE

The VAE combines the encoder and decoder. It also includes a reparameterization step to sample from the latent space.

Training

The training loop involves the following steps:

  1. Load the MNIST dataset.
  2. Initialize the VAE model and the optimizer.
  3. For each epoch:
    • For each batch of data:
      • Flatten the input data.
      • Perform a forward pass through the model.
      • Compute the loss (Binary Cross Entropy + Kullback-Leibler Divergence).
      • Perform backpropagation and update the model parameters.
    • Print the average loss for the epoch.

Saving the Model

After training, the model's state dictionary is saved to a file named vae.pth.

References

Feel free to contribute to this repository by opening issues or submitting pull requests.

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages