
Commit

update trash detection folder
YaelBenShalom committed Mar 21, 2021
1 parent e75b264 commit 20d982a
Showing 17 changed files with 1,429 additions and 0 deletions.
63 changes: 63 additions & 0 deletions trash_detection/trash_classification/README.md
@@ -0,0 +1,63 @@
# Trash Classification
## ME499 - Independent Project, Winter 2021
Yael Ben Shalom, Northwestern University.<br>
This module is a part of a [Traffic-Sign Recognition and Classification](https://github.com/YaelBenShalom/Traffic-Sign-Recognition-and-Classification) project.


## Module Overview
In this module, I built and trained a neural network to classify different recyclable objects using PyTorch.<br>
I based my program on the [Garbage Classification Dataset](https://www.kaggle.com/asdasdasasdas/garbage-classification).
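
The network itself is defined in `code/cnn.py`, which is not shown in this commit. As a rough sketch only (the real `ResNet` class may differ), a classifier for the six garbage classes could wrap a pretrained torchvision backbone and replace its final layer; the choice of `resnet18` and the layer sizes below are assumptions made for illustration:
```
import torch.nn as nn
from torchvision import models


class ResNet(nn.Module):
    """Illustrative sketch only: pretrained ResNet-18 backbone with a 6-class head."""

    def __init__(self, num_classes=6):
        super().__init__()
        # Load an ImageNet-pretrained backbone and replace its final fully-connected layer
        self.backbone = models.resnet18(pretrained=True)
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

    def forward(self, x):
        return self.backbone(x)
```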


## User Guide
### Program Installation

1. Clone the repository and navigate to this module, using the following commands:
```
git clone https://github.com/YaelBenShalom/Traffic-Sign-Recognition-and-Classification.git
cd Traffic-Sign-Recognition-and-Classification/trash_detection/trash_classification
```
2. Extract the dataset located in the `./data` directory.

### Quickstart Guide
Run the classification program:
1. To train and test the program on the dataset, run the following command from the root directory:
```
python code/classification.py
```
2. To train the program on the dataset and test it on a specific image, copy the image to the root directory and run the following command from the root directory:
```
python code/classification.py --image <image-name>
```
Where `<image-name>` is the name of the image (including the file extension).<br>
The trained model will be saved in the root directory as `model`.
3. To use an existing model and test it on a specific image, copy the image to the root directory and run the following command from the root directory:
```
python code/classification.py --image <image-name> --model <model-name>
```
Where `<model-name>` is the name of the trained model (a Python equivalent of this flow is sketched below, after the example output).
<br>The program output when run on an example image:
The loss plot:<br>
![Loss Graph](https://github.com/YaelBenShalom/Traffic-Sign-Recognition-and-Classification/blob/master/trash_recognition/trash_classification/images/Losses%20(100%20Epochs).png)
The accuracy plot:<br>
![Accuracy Graph](https://github.com/YaelBenShalom/Traffic-Sign-Recognition-and-Classification/blob/master/trash_recognition/trash_classification/images/Accuracy%20(100%20Epochs).png)
The output image (with the correct prediction):<br>
![Output Image](https://github.com/YaelBenShalom/Traffic-Sign-Recognition-and-Classification/blob/master/trash_recognition/trash_classification/images/Image_Classification.png)
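
For reference, passing `--model` makes the script load a saved `state_dict` instead of retraining. A minimal Python sketch of the equivalent steps, run from the `code` directory so that `cnn.py` and `run_model.py` are importable (the file names `model` and `example.png` are placeholders):
```
import cv2
import torch
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

from cnn import ResNet
from run_model import predict

# Rebuild the network and load the saved weights (placeholder path "model")
model = ResNet()
model.load_state_dict(torch.load("model", map_location="cpu"))

# Load and resize the image the same way classification.py does (placeholder name "example.png")
image = cv2.resize(plt.imread("example.png"), (32, 32))

# Convert to a tensor, add a batch dimension, and predict the class index
batch = transforms.ToTensor()(image).unsqueeze(0)
prediction = int(predict(model, batch)[0])
print(f"Predicted class index: {prediction}")
```
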
## Dataset
The Garbage Classification Dataset contains 6 classes: cardboard (393 images), glass (491), metal (400), paper (584), plastic (472), and trash (127).
The dataset can be found on the [Garbage Classification](https://www.kaggle.com/asdasdasasdas/garbage-classification) page on Kaggle, or downloaded directly from [here](https://www.kaggle.com/asdasdasasdas/garbage-classification/download).
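
The dataset is organized as one sub-directory per class, which is the layout `classification.py` expects when it loads the data with `torchvision.datasets.ImageFolder`. A minimal sketch, assuming the archive has been extracted to `data/Garbage classification`:
```
import numpy as np
from torchvision import transforms
from torchvision.datasets import ImageFolder

dataset = ImageFolder("data/Garbage classification", transform=transforms.ToTensor())
print(dataset.classes)                # e.g. ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
print(np.bincount(dataset.targets))   # number of images per class, e.g. [393 491 400 584 472 127]
```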
282 changes: 282 additions & 0 deletions trash_detection/trash_classification/code/classification.py
@@ -0,0 +1,282 @@
import argparse
import os
import numpy as np
import random
import cv2
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split
from torchvision import models

from load_data import load_dataset
from read_data import ReadDataset
from run_model import run_model, predict
from cnn import BaselineNet, ResNet


def load_arguments(args):
"""
This function loads the image argument.
If an image was added as argument, the program will predict its class after training and testing the model.
If no image was added as argument, the program will only train and test the results on the datasets.
Inputs:
args: the input arguments of the program in the form of a dictionary {"image" : <argument>}.
if args exist, <argument> is the input image, else <argument> is None.
Output:
test_image: the input image that should be tested by the model.
"""
test_image = plt.imread(args["image"])
plt.figure()
plt.imshow(test_image)
plt.show()

return test_image


def dataset_properties(trainset_name, validset_name, testset_name, class_names, data_dir):
"""
This function finds the dataset properties.
This function is for information only.
Inputs:
trainset_name: the name of the training set file.
validset_name: the name of the validation set file.
testset_name: the name of the testing set file.
Output: None
"""
train_features, train_labels = load_dataset(trainset_name, base_folder=data_dir)
valid_features, valid_labels = load_dataset(validset_name, base_folder=data_dir)
test_features, test_labels = load_dataset(testset_name, base_folder=data_dir)

# print(f"train_features shape: {train_features.shape}")
# print(f"train_labels shape: {train_labels.shape}")
print(f"train dataset size: {len(train_features)}")

# print(f"valid_features: {valid_features.shape}")
# print(f"valid_labels: {valid_labels.shape}")
print(f"validation dataset size: {len(valid_features)}")

# print(f"test_features: {test_features.shape}")
# print(f"test_labels: {test_labels.shape}")
print(f"test dataset size: {len(test_labels)}")

# Finding the number of classes in the dataset
classes_num = len(set(train_labels))
print(f"Number of classes: {classes_num}")

# Finding the size of the images in the dataset
image_shape = train_features[0].shape[:2]
print(f"images shape: {image_shape}")

# Plotting class distribution for training set
fig, ax = plt.subplots()
ax.bar(range(classes_num), np.bincount(train_labels))
ax.set_title('Class Distribution in the Train Set', fontsize=20)
ax.set_xlabel('Class Number')
ax.set_ylabel('Number of Events')
plt.savefig('images/Class_Distribution.png')
plt.show()

# Plotting random 40 images from train set
plt.figure(figsize=(12, 12))
for i in range(40):
feature_index = random.randint(0, train_labels.shape[0])
plt.subplot(6, 8, i + 1)
plt.subplots_adjust(left=0.1, bottom=0.03, right=0.9, top=0.92, wspace=0.2, hspace=0.2)
plt.axis('off')
plt.imshow(train_features[feature_index])
plt.suptitle('Random Training Images', fontsize=20)
plt.savefig('images/Random_Training_Images.png')
plt.show()

# Plotting images for every class from train set
plt.figure(figsize=(14, 14))
for i in range(classes_num):
feature_index = random.choice(np.where(train_labels == i)[0])
plt.subplot(6, 8, i + 1)
plt.subplots_adjust(left=0.1, bottom=0.03, right=0.9, top=0.92, wspace=0.2, hspace=0.2)
plt.axis('off')
plt.title(class_names[i], fontsize=10)
plt.imshow(train_features[feature_index])
plt.suptitle('Random training images from different classes', fontsize=20)
plt.savefig('images/Random_Training_Images_Different_Class.png')
plt.show()


def class_names_fun(data_dir):
"""
This function returns a dictionary with the classes numbers and names.
Inputs: None
Output:
class_names: a dictionary with the classes numbers and names.
"""
# Class names
classes = os.listdir(data_dir)

# Defining class names dictionary
class_names = {}
for i in range(len(classes)):
class_names[i] = classes[i]
print("class_names: ", class_names)
return class_names


def plot_training_results(train_loss_list, valid_loss_list, valid_accuracy_list, epoch_num):
"""
This function plots the results of training the network.
Inputs:
train_loss_list: list of loss value on the entire training dataset.
valid_loss_list: list of loss value on the entire validation dataset.
valid_accuracy_list: list of accuracy on the entire validation dataset.
Output: None
"""
# Plotting training and validation loss vs. epoch number
plt.figure()
plt.plot(range(len(train_loss_list)), train_loss_list, label='Training Loss')
plt.plot(range(len(valid_loss_list)), valid_loss_list, label='Validation Loss')
plt.title(f'Training and Validation Loss Vs. Epoch Number ({epoch_num} Epochs)')
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.legend(loc="best")
plt.savefig(f"images/Losses ({epoch_num} Epochs).png")
plt.show()

# Plotting validation accuracy vs. epoch number
plt.figure()
plt.plot(range(len(valid_accuracy_list)), valid_accuracy_list, label='Validation Accuracy')
plt.title(f'Validation Accuracy Vs. Epoch Number ({epoch_num} Epochs)')
plt.xlabel('Epoch Number')
plt.ylabel('Accuracy')
plt.xlim([0, len(train_loss_list)])
plt.ylim([0, 100])
plt.legend(loc="best")
plt.savefig(f"images/Accuracy ({epoch_num} Epochs).png")
plt.show()


def main(args):
""" Main function of the program
Inputs:
args: the input arguments of the program in the form of a dictionary {"image" : <argument>}.
if args exist, <argument> is the input image, else <argument> is None.
Output: None
"""

# Define dataset directory
data_dir = "data/Garbage classification"

# Finding dataset properties
class_names = class_names_fun(data_dir)

# Define the device parameters
torch.manual_seed(17)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Define the model
model = ResNet()

# Define the training properties
epoch_num = 100
criterion = nn.CrossEntropyLoss()
learning_rate = 5.5e-5
batch_size = 32
stop_threshold = 1e-4

# Computing data transformation to normalize data
mean = (0.485, 0.456, 0.406) # from https://pytorch.org/docs/stable/torchvision/transforms.html
std = (0.229, 0.224, 0.225) # -"-
transform = transforms.Compose([transforms.ToTensor(),
transforms.Resize((32, 32)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(degrees=15),
transforms.CenterCrop(size=224),
transforms.Normalize(mean=mean, std=std)])

dataset = ImageFolder(data_dir, transform=transform)

# Split the dataset to 3 groups - train, validation, test
dataset_len = len(dataset)
dataset_ratio = [0.64, 0.26 , 0.1]
dataset_split = [round(element * dataset_len) for element in dataset_ratio]
train_dataset, valid_dataset, test_dataset = random_split(dataset, dataset_split)

    # If no input model - train a new model
    if not args["model"]:
        # Define the path where the trained model will be saved
        model_path = os.path.abspath("model")

        # Train the network
        model, train_loss_list, valid_loss_list, valid_accuracy_list = run_model(model, running_mode='train',
                                                                                 train_set=train_dataset,
                                                                                 valid_set=valid_dataset,
                                                                                 test_set=test_dataset,
                                                                                 batch_size=batch_size, epoch_num=epoch_num,
                                                                                 learning_rate=learning_rate,
                                                                                 stop_thr=stop_threshold,
                                                                                 criterion=criterion, device=device)
        # Plot the results of training the network
        plot_training_results(train_loss_list, valid_loss_list, valid_accuracy_list, epoch_num)

        # Save the trained model
        torch.save(model.state_dict(), model_path)

    # If an input model was given - load the existing model
    else:
        # Define the path of the existing model
        model_path = os.path.abspath(args["model"])

        # Load the trained model
        model.load_state_dict(torch.load(model_path, map_location=device))

    # Test the network
    test_loss, test_accuracy = run_model(model, running_mode='test', train_set=train_dataset,
                                         valid_set=valid_dataset, test_set=test_dataset,
                                         batch_size=batch_size, epoch_num=epoch_num,
                                         learning_rate=learning_rate, stop_thr=stop_threshold,
                                         criterion=criterion, device=device)

    print(f"Test loss: {test_loss:.3f}")
    print(f"Test accuracy: {test_accuracy:.2f}%")

    # Check if image argument exists
    if args["image"]:
        # Load and resize the image argument
        test_image = load_arguments(args)
        test_image_resized = cv2.resize(test_image, (32, 32))

        # Convert the resized image to a tensor
        test_image_tensor = transforms.ToTensor()(np.array(test_image_resized))

        # Add a batch dimension to the tested image
        test_image_transform4d = test_image_tensor.unsqueeze(0)

        # Predict the class of the tested image
        prediction = int(predict(model, test_image_transform4d)[0])
        print(f"Test prediction: {prediction} -> Class: {class_names[prediction]}")

        # Plot the image with the predicted class
        plt.figure()
        plt.axis('off')
        plt.title(class_names[prediction], fontsize=10)
        plt.imshow(test_image)
        plt.suptitle('Image Classification', fontsize=18)
        plt.savefig('images/Image_Classification.png')
        plt.show()


if __name__ == "__main__":
    # Construct the argument parser and parse the arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--image", help="path to the input image")
    parser.add_argument("-m", "--model", help="path to the trained model")
    args = vars(parser.parse_args())
    main(args)