
Lightweight-CNN-Pytorch

TinyVGG Image Classification Project

This project implements a TinyVGG model for image classification, based on the architecture from the CNN Explainer website. The code is structured for modularity, allowing for easy configuration and execution of training and prediction tasks.

Project Structure


.
├── config.py        # All configurations (hyperparameters, paths, etc.)
├── data_setup.py    # For creating PyTorch DataLoaders
├── model.py         # TinyVGG model definition (the TinyVGG class lives here)
├── engine.py        # Training loop (train_step, test_step, train_model functions)
├── utils.py         # Utility functions (e.g., save_model)
├── train.py         # Main script to run model training
├── predict.py       # Script to make predictions on new images
├── data/            # Directory for your image datasets
│   └── pizza_steak_sushi/    # Example dataset name
│       ├── train/            # Training images
│       │   ├── class1/       # e.g., pizza
│       │   ├── class2/       # e.g., steak
│       │   └── ...
│       └── test/             # Testing images
│           ├── class1/
│           ├── class2/
│           └── ...
├── models/          # Directory where trained models are saved
└── README.md        # This file

Features

  • Modular Design: Code is separated into logical modules for data setup, model building, training engine, and utilities.
  • Configurable: Most parameters, including hyperparameters, file paths, and model settings, are managed through config.py.
  • TinyVGG Implementation: A PyTorch implementation of the TinyVGG convolutional neural network.
  • Training Script: train.py handles the complete training pipeline, including data loading, model training, and saving the trained model.
  • Prediction Script: predict.py loads a trained model and makes predictions on a specified image.
  • Device Agnostic: The code uses a CUDA-enabled GPU if one is available and falls back to the CPU otherwise, as shown in the snippet below.
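
The device selection follows the standard PyTorch pattern; a minimal sketch of what the code likely does (not necessarily the repository's exact lines):

import torch

# Use a CUDA-enabled GPU when available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")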

Prerequisites

  • Python 3.7+
  • PyTorch
  • TorchVision
  • Pillow (PIL)
  • tqdm (for progress bars)

You can install the necessary Python packages using pip:

pip install torch torchvision pillow tqdm

Setup

  1. Clone the repository (if applicable) or create the project files: Ensure all the .py files (config.py, data_setup.py, model.py, engine.py, utils.py, train.py, predict.py) are in the root directory of your project.

  2. Create Data Directories:

    mkdir -p data/pizza_steak_sushi/train
    mkdir -p data/pizza_steak_sushi/test
    mkdir -p models

    Replace pizza_steak_sushi with your dataset's name if different.

  3. Prepare Your Dataset:

    • Place your training images in the data/YOUR_DATASET_NAME/train/ directory, organized into subfolders named after each class (e.g., data/pizza_steak_sushi/train/pizza/, data/pizza_steak_sushi/train/steak/).
    • Place your testing images in the data/YOUR_DATASET_NAME/test/ directory, with the same class-based subfolder structure.
  4. Configure config.py: Open config.py and review/adjust the settings (a sketch of such a file follows this list):

    • Data Settings:
      • TRAIN_DIR: Path to your training data.
      • TEST_DIR: Path to your testing data.
      • IMAGE_SIZE: Target image size for resizing (default: (64, 64)).
      • BATCH_SIZE: Batch size for training and testing.
    • Model Parameters:
      • INPUT_SHAPE: Number of input channels (e.g., 3 for RGB).
      • HIDDEN_UNITS: Number of hidden units in convolutional layers.
    • Training Hyperparameters:
      • NUM_EPOCHS: Number of training epochs.
      • LEARNING_RATE: Learning rate for the optimizer.
    • Model Saving:
      • MODEL_SAVE_DIR: Directory to save trained models.
      • MODEL_NAME_PREFIX: Prefix for saved model filenames.
    • Prediction Settings (for predict.py):
      • IMAGE_PATH_FOR_PREDICTION: Path to the image you want to predict.
      • MODEL_PATH_FOR_PREDICTION: Path to the trained .pth model file to use for prediction.
      • CLASS_NAMES_FOR_PREDICTION: Crucially, list your class names here in the exact order the model was trained on (ImageFolder sorts class folders alphabetically by default, so this order is usually alphabetical).
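
Taken together, a config.py that follows these settings might look roughly like this (a sketch with placeholder values, not the repository's exact file):

# config.py -- sketch with example values; adjust for your setup.

# Data settings
TRAIN_DIR = "data/pizza_steak_sushi/train"
TEST_DIR = "data/pizza_steak_sushi/test"
IMAGE_SIZE = (64, 64)      # target (height, width) for resizing
BATCH_SIZE = 32

# Model parameters
INPUT_SHAPE = 3            # number of input channels (3 for RGB)
HIDDEN_UNITS = 10          # hidden units in the convolutional layers

# Training hyperparameters
NUM_EPOCHS = 5
LEARNING_RATE = 0.001

# Model saving
MODEL_SAVE_DIR = "models"
MODEL_NAME_PREFIX = "tinyvgg_model"

# Prediction settings
IMAGE_PATH_FOR_PREDICTION = "data/pizza_steak_sushi/test/pizza/example.jpg"
MODEL_PATH_FOR_PREDICTION = "models/tinyvgg_model_v1.pth"
# Must match the training order; ImageFolder sorts class folders alphabetically.
CLASS_NAMES_FOR_PREDICTION = ["pizza", "steak", "sushi"]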

Training the Model

  1. Ensure your dataset is prepared and config.py (especially TRAIN_DIR, TEST_DIR, and training hyperparameters) is correctly configured.
  2. Run the training script from the project's root directory:
    python train.py
  3. The script will (see the sketch after this list):
    • Load and preprocess the data.
    • Initialize the TinyVGG model.
    • Train the model for the specified number of epochs, printing progress and metrics.
    • Save the trained model's state_dict to the directory specified by MODEL_SAVE_DIR in config.py (e.g., models/tinyvgg_model_v1.pth).
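
These steps map onto a small amount of glue code. A hedged sketch of the core of train.py, assuming module interfaces matching the project structure (the exact names create_dataloaders, train_model, and save_model are assumptions):

import torch
import config
from data_setup import create_dataloaders   # assumed helper in data_setup.py
from model import TinyVGG
from engine import train_model
from utils import save_model

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Load and preprocess the data.
train_dataloader, test_dataloader, class_names = create_dataloaders(
    train_dir=config.TRAIN_DIR,
    test_dir=config.TEST_DIR,
    batch_size=config.BATCH_SIZE,
)

# 2. Initialize the TinyVGG model.
model = TinyVGG(
    input_shape=config.INPUT_SHAPE,
    hidden_units=config.HIDDEN_UNITS,
    output_shape=len(class_names),
).to(device)

# 3. Train for the configured number of epochs.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
train_model(model, train_dataloader, test_dataloader,
            optimizer, loss_fn, config.NUM_EPOCHS, device)

# 4. Save the trained model's state_dict.
save_model(model, config.MODEL_SAVE_DIR, f"{config.MODEL_NAME_PREFIX}_v1.pth")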

Making Predictions

  1. Ensure you have a trained model: A .pth file should exist in your models/ directory (or the path specified in MODEL_PATH_FOR_PREDICTION).
  2. Configure predict.py via config.py:
    • Open config.py.
    • Set IMAGE_PATH_FOR_PREDICTION to the full path of the new image you want to classify.
    • Set MODEL_PATH_FOR_PREDICTION to the path of your trained model file (e.g., models/tinyvgg_model_v1.pth).
    • Verify CLASS_NAMES_FOR_PREDICTION: This list must match the order of classes the model was trained on. For example, if your training data subfolders were bird, cat, and dog, ImageFolder loads them in alphabetical order, so CLASS_NAMES_FOR_PREDICTION should be ["bird", "cat", "dog"].
  3. Run the prediction script:
    python predict.py
  4. The script will output the predicted class label and the confidence score for the specified image.
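
Internally, this is the usual load-transform-forward sequence; a hedged sketch of what predict.py likely does (details may differ from the repository's file):

import torch
from PIL import Image
from torchvision import transforms
import config
from model import TinyVGG

device = "cuda" if torch.cuda.is_available() else "cpu"

# Rebuild the model architecture and load the trained weights.
model = TinyVGG(
    input_shape=config.INPUT_SHAPE,
    hidden_units=config.HIDDEN_UNITS,
    output_shape=len(config.CLASS_NAMES_FOR_PREDICTION),
).to(device)
model.load_state_dict(
    torch.load(config.MODEL_PATH_FOR_PREDICTION, map_location=device))
model.eval()

# Preprocess the image exactly as during training.
transform = transforms.Compose([
    transforms.Resize(config.IMAGE_SIZE),
    transforms.ToTensor(),
])
image = Image.open(config.IMAGE_PATH_FOR_PREDICTION).convert("RGB")
batch = transform(image).unsqueeze(0).to(device)  # add a batch dimension

# Predict the class label and its confidence.
with torch.inference_mode():
    probs = torch.softmax(model(batch), dim=1)
pred = probs.argmax(dim=1).item()
print(f"Predicted: {config.CLASS_NAMES_FOR_PREDICTION[pred]} "
      f"(confidence {probs[0, pred].item():.3f})")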

Customization

  • Dataset: To use a different dataset, update TRAIN_DIR and TEST_DIR in config.py and ensure your data is structured correctly in the data/ directory. Also, update CLASS_NAMES_FOR_PREDICTION for prediction.
  • Model Architecture: Modify the TinyVGG class in model.py to change the network architecture. If you make significant changes, remember to adjust INPUT_SHAPE, HIDDEN_UNITS, or how OUTPUT_SHAPE is determined. The in_features of the final nn.Linear layer in TinyVGG depends on the output size of the convolutional blocks and on the input IMAGE_SIZE, so you may need to recalculate it if you change the architecture or IMAGE_SIZE (see the helper snippet at the end of Potential Issues & Troubleshooting).
  • Hyperparameters: Adjust BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, etc., in config.py to experiment with different training settings.
  • Image Transformations: Modify the transforms.Compose([...]) sections in data_setup.py (for training) and predict.py (for prediction) if you need different image preprocessing steps (e.g., data augmentation, normalization). Ensure the prediction transforms match the training transforms, as in the sketch below.
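
For example, training-time augmentation could be added in data_setup.py like this (a sketch; TrivialAugmentWide requires a reasonably recent torchvision, and normalization is omitted here):

from torchvision import transforms
import config

# Training transform with simple augmentation.
train_transform = transforms.Compose([
    transforms.Resize(config.IMAGE_SIZE),
    transforms.TrivialAugmentWide(),  # randomly applies one augmentation per image
    transforms.ToTensor(),
])

# Evaluation/prediction transform: no augmentation, but the same resize and
# tensor conversion, so inputs match what the model saw during training.
eval_transform = transforms.Compose([
    transforms.Resize(config.IMAGE_SIZE),
    transforms.ToTensor(),
])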

Potential Issues & Troubleshooting

  • FileNotFoundError: Double-check all paths in config.py (TRAIN_DIR, TEST_DIR, IMAGE_PATH_FOR_PREDICTION, MODEL_PATH_FOR_PREDICTION). Ensure the files/directories exist.
  • Incorrect Predictions/Low Accuracy:
    • Verify that CLASS_NAMES_FOR_PREDICTION in config.py exactly matches the order of classes the model was trained on.
    • Ensure image transformations in predict.py are identical to those used during training (especially IMAGE_SIZE and normalization if used).
    • The model might need more training (more epochs, larger dataset) or hyperparameter tuning.
    • The in_features for the classifier's nn.Linear layer in model.py might be incorrect if IMAGE_SIZE or the convolutional architecture has changed (see the snippet after this list).
  • CUDA Errors: Ensure PyTorch was installed with CUDA support if you have a compatible NVIDIA GPU. If not, the code should fall back to CPU.
  • RuntimeError: Mismatch in shape...: This often happens if the OUTPUT_SHAPE of the model (derived from len(CLASS_NAMES_FOR_PREDICTION) or number of classes in training data) doesn't match what the loaded model expects, or if the in_features of the classifier layer is wrong.
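
A common way to find the correct in_features value (relevant to the last bullet above and the Model Architecture note under Customization) is to pass a dummy batch through the convolutional part of the model and inspect the flattened size. A sketch, assuming the TinyVGG class exposes its blocks as conv_block_1 and conv_block_2 (those attribute names are an assumption):

import torch
import config
from model import TinyVGG

model = TinyVGG(input_shape=config.INPUT_SHAPE,
                hidden_units=config.HIDDEN_UNITS,
                output_shape=3)

# Run a dummy image through the conv blocks to see the flattened size
# the final nn.Linear layer must accept as in_features.
dummy = torch.randn(1, config.INPUT_SHAPE, *config.IMAGE_SIZE)
out = model.conv_block_2(model.conv_block_1(dummy))
print(f"in_features should be: {out.flatten(start_dim=1).shape[1]}")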

Contributing

Feel free to fork this project and submit pull requests for improvements or bug fixes.


Remember to replace placeholders like YOUR_DATASET_NAME and ensure the paths and class names in config.py are accurate for your specific setup.



Running Your PyTorch Project with Docker

This document outlines the steps to build and run this PyTorch application using Docker and Docker Compose. This ensures a consistent and reproducible environment for development and deployment.

Prerequisites

  1. Docker: Ensure Docker Desktop (for Mac/Windows) or Docker Engine (for Linux) is installed and running. You can download it from docker.com.
  2. Docker Compose: Docker Compose V2 is typically included with Docker Desktop. For Linux, you might need to install it separately.
  3. (Optional) NVIDIA GPU Support:
    • If you intend to use NVIDIA GPUs, ensure you have the latest NVIDIA drivers installed on your host machine.
    • Install the NVIDIA Container Toolkit on your host machine. This allows Docker containers to access NVIDIA GPUs.
  4. Project Files:
    • Dockerfile: Defines the Docker image for the application (a sketch follows this list).
    • docker-compose.yml: Defines how to run the application services (including GPU support).
    • Pipfile: Specifies Python package dependencies.
    • Pipfile.lock: Locks package versions for reproducible builds.
    • Your application code (e.g., inference.py).
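
For orientation, a Dockerfile matching this description might look roughly like the following (a sketch, not the repository's exact file; the base image and Python version are assumptions):

# Sketch of a possible Dockerfile; the real file may differ.
FROM python:3.9-slim

WORKDIR /app

# Install pipenv and the locked dependencies into the system environment.
RUN pip install pipenv
COPY Pipfile Pipfile.lock ./
RUN pipenv install --deploy --system

# Copy the application code and set the default command.
COPY . .
CMD ["python3", "inference.py"]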

Building and Running the Application

We will use Docker Compose to manage the build and run process.

Step 1: Clone the Repository (if applicable)

If you haven't already, clone the project repository to your local machine:

git clone <your-repository-url>
cd <your-project-directory>

Step 2: Check/Generate Pipfile.lock

The Dockerfile uses pipenv install --deploy, which requires Pipfile.lock to be up-to-date with Pipfile.

Troubleshooting Pipfile.lock out-of-date error: If, during the Docker build process (Step 3), you encounter an error similar to:

Your Pipfile.lock (...) is out of date. Expected: (...).
ERROR:: Aborting deploy

This means your Pipfile.lock is not synchronized with your Pipfile. To fix this, run the following command in your project's root directory (where Pipfile is located) on your host machine:

pipenv lock

This will update Pipfile.lock. After running this command, proceed to Step 3.

Step 3: Build and Run with Docker Compose

Open your terminal in the root directory of the project (where docker-compose.yml and Dockerfile are located).

To build the image and run the application (e.g., execute inference.py):

docker-compose up --build

  • --build: This flag tells Docker Compose to build the Docker image using the Dockerfile. You can omit it on subsequent runs if the Dockerfile and its dependencies haven't changed and an image already exists.
  • The application (defined by CMD in the Dockerfile, e.g., python3 inference.py) will start, and its output will be displayed in your terminal.

To run in detached mode (in the background):

docker-compose up --build -d

Step 4: Interacting with the Application

  • Viewing Logs (if running in detached mode):

    docker-compose logs -f app

    (Replace app with your service name if it's different in docker-compose.yml). Press Ctrl+C to stop following logs.

  • Accessing a Shell Inside the Container (for debugging): If you need to explore the container's environment or run commands manually:

    1. Ensure the container is running (e.g., using docker-compose up -d).
    2. Open a shell:
      docker-compose exec app bash
      (Replace app with your service name if it's different).
    3. Inside the container, you can navigate to /app (the working directory) and run Python scripts or other commands.
  • Port Mapping (if applicable): If your application (inference.py) runs a web server (e.g., on port 8000) and you have configured port mapping in docker-compose.yml (e.g., ports: - "8000:8000"), you can access it via http://localhost:8000 in your web browser.

Step 5: Stopping the Application

To stop and remove the containers and networks created by Docker Compose:

docker-compose down

If you want to remove the volumes as well:

docker-compose down -v

Important Notes

  • PyTorch Versions & CUDA: The Pipfile specifies PyTorch versions and a CUDA source (pytorch-cu111). Ensure these versions are valid and available from the specified PyTorch wheel index. If pipenv install fails during the Docker build due to version conflicts or "Could not find a version" errors, you will need to:
    1. Consult PyTorch Previous Versions to find compatible torch, torchvision, and torchaudio versions for your desired CUDA version (e.g., CUDA 11.1).
    2. Update the versions in your Pipfile.
    3. Run pipenv lock locally to regenerate Pipfile.lock.
    4. Re-run docker-compose up --build.
  • GPU Usage: The docker-compose.yml is configured to attempt GPU access using NVIDIA (see the sketch after this list). This requires the prerequisites mentioned above (NVIDIA drivers and the NVIDIA Container Toolkit on the host). If GPUs are not available or not configured correctly, PyTorch will typically fall back to CPU mode.
  • Development Mode Volume Mount: The docker-compose.yml includes volumes: - .:/app. This mounts your local project directory into the container. Code changes made locally will be reflected inside the container, which is useful for development. For production, you might remove this volume mount to rely solely on the code baked into the image.
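
Putting the last two points together, the relevant parts of docker-compose.yml might look like this (a sketch; the service name app is an assumption):

# Sketch of a possible docker-compose.yml; the real file may differ.
services:
  app:
    build: .
    volumes:
      - .:/app   # development mount: local code changes are visible in the container
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]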

Further Actions

  • Cleaning up Docker Resources:
    • To remove unused Docker images: docker image prune
    • To remove unused Docker volumes: docker volume prune
    • To remove unused Docker networks: docker network prune
    • To remove all unused Docker resources (images, containers, volumes, networks): docker system prune -a (Use with caution!)
