diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..0459a5d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+model/*
+!model/.gitkeep
+
+data/*
+!data/*.zip
+
+__pycache__
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..74d48e9
--- /dev/null
+++ b/README.md
@@ -0,0 +1,31 @@
+# FINNger
+
+FINNger is a CNN intended to detect how many raised fingers you are showing through your webcam (or any other image-capturing device). The final goal of this work is a mobile app with which children can learn basic arithmetic. This repository mostly contains the code for model generation, with a small proof of concept to check that it really works.
+
+More information about the work itself and the model can be found in the paper: _link unavailable at the moment_
+
+## Installing
+
+You need `Python3` to run this code. To install the library dependencies, run `pip3 install -r requirements.txt`.
+
+## Dataset
+
+By default, one of the datasets we used is already available in this repository. From the repository root, run `./extract_dataset.sh` and the custom dataset will be extracted into `data/`.
+
+To download the [koryakinp/fingers](https://www.kaggle.com/koryakinp/fingers) dataset, refer to the Kaggle website for instructions on how to download it.
+
+## Model
+
+A trained model is not tracked in the repository. However, the final model and optimizer state are available on the releases tab for demonstration purposes.
+
+## Results
+
+As stated above, the full results can be found in the paper. However, here is a small demonstration of the high accuracy of the trained model on the validation images. In this matrix, each row is the expected value and each column is the value output by the FINNger model.
+
+![Correlation Matrix for our model](images/nn_detection_corr.png)
+
+
+## Authors
+
+- [Rafael Baldasso Audibert](https://www.rafaaudibert.dev)
+- Vinicius Maschio
\ No newline at end of file
diff --git a/calculator.py b/calculator.py
new file mode 100644
index 0000000..41a6042
--- /dev/null
+++ b/calculator.py
@@ -0,0 +1,27 @@
+class Calculator():
+    def __init__(self):
+        self.a = None
+        self.b = None
+
+    def add_number(self, number):
+        if self.a is None:
+            self.a = number
+        elif self.b is None:
+            self.b = number
+        else:
+            self.a = None
+            self.b = None
+
+    @property
+    def result(self):
+        if self.a is None or self.b is None:
+            return None
+
+        return self.a + self.b
+
+    def __str__(self):
+        a = "?" if self.a is None else str(self.a)
+        b = "?" if self.b is None else str(self.b)
+        c = "?" if self.a is None or self.b is None else str(self.result)
+
+        return f"{a} + {b} = {c}"
diff --git a/data/custom_fingers.zip b/data/custom_fingers.zip
new file mode 100644
index 0000000..00bac70
Binary files /dev/null and b/data/custom_fingers.zip differ
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..dfbbdcd
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,49 @@
+import torch
+import numpy as np
+from torch.utils.data import Dataset
+import glob
+import cv2
+from tqdm import tqdm
+
+
+def Identity(x): return x
+
+
+class FINNgerDataset(Dataset):
+    """Hand Images dataset available at https://www.kaggle.com/koryakinp/fingers."""
+
+    NUM_CLASSES = 6
+
+    def __init__(self, data_dir, transform=Identity):
+        """
+        Args:
+            data_dir (string): Directory with all the images.
+            transform (callable, optional): Optional transform to be applied on a sample; defaults to Identity.
+        """
+        self.data_dir = data_dir
+        self.transform = transform
+
+        self.glob_path = glob.glob(data_dir)
+        self.dataset = []
+        for img_path in tqdm(self.glob_path, desc="Import data"):
+            # Image names end with the label digit, the hand letter, and ".png", so the digit sits at [-6:-5]
+            image_label = int(img_path[-6:-5])
+
+            image = cv2.imread(img_path)
+
+            self.dataset.append({'image': image, 'label': image_label})
+        self.dataset = np.array(self.dataset)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        if torch.is_tensor(idx):
+            idx = idx.tolist()
+
+        # Return in a good format for testing
+        sample = self.dataset[idx]
+        return (
+            self.transform(sample['image']),
+            sample['label'],
+        )
diff --git a/dataset_generator.py b/dataset_generator.py
new file mode 100644
index 0000000..049d153
--- /dev/null
+++ b/dataset_generator.py
@@ -0,0 +1,87 @@
+from threading import Thread
+import uuid
+
+import cv2
+import PySimpleGUI as sg
+import click
+
+
+key_pressed: bool = False
+pictures_taken = 0
+
+
+def detect_key_press():
+    global key_pressed
+    didnt_press_before = True
+
+    while True:
+        input("Press anything to take a screenshot")
+        key_pressed = True
+        if didnt_press_before:
+            print("Thank you!")
+        didnt_press_before = False
+
+
+def save_image(image, identifier: str, path: str, size: int) -> None:
+    global pictures_taken
+
+    generated_uuid = uuid.uuid4()
+    full_name = f"{generated_uuid}_{identifier}.png"
+
+    resized_array = cv2.resize(image, (size, size))
+    black_and_white = cv2.cvtColor(resized_array, cv2.COLOR_RGB2GRAY)
+
+    full_path = path + full_name
+    cv2.imwrite(full_path, black_and_white)
+
+    pictures_taken += 1
+    print(f"Saved image {pictures_taken} to {full_path}")
+
+    print("Shape read back is", cv2.imread(full_path).shape)
+
+
+@click.command()
+@click.option("--identifier", required=True, help="The identifier appended to the end of the image name")
+@click.option("--path", required=True, help="The path where the photos will be saved")
+@click.option("--size", default=128, help="Size to resize the image to", show_default=True)
+@click.option("--n_images", default=float("inf"), help="How many images we should generate", show_default=True)
+def main(identifier: str, path: str, size: int, n_images: int):
+    global key_pressed
+
+    # Thread used to detect the key pressing
+    thread = Thread(target=detect_key_press)
+    thread.start()
+
+    window = sg.Window(
+        'Dataset Generator',
+        [[sg.Image(filename='', key='image')], ],
+        location=(800, 400),
+    )
+
+    cap = cv2.VideoCapture(0)  # Setup the camera as a capture device
+    while True:
+        # get events for the window with 20ms max wait
+        event, _values = window.Read(timeout=20, timeout_key='timeout')
+        if event is None:  # if user closed window, quit
+            break
+
+        _ret, image = cap.read()
+
+        # Update image in window
+        window_image = window.FindElement('image')
+        encoded_image = cv2.imencode('.png', image)[1].tobytes()
+        window_image.Update(data=encoded_image)
+
+        # This is handled in a different thread, responsible for detecting the key press
+        if key_pressed:
+            save_image(image, identifier, path, size)
+            key_pressed = False
+
+        if pictures_taken >= n_images:
+            print(
+                f"Finished generating {pictures_taken} images. Quitting application...")
+            break
+
+
+if __name__ == "__main__":
+    main()
diff --git a/default_config.py b/default_config.py
new file mode 100644
index 0000000..1eab5f9
--- /dev/null
+++ b/default_config.py
@@ -0,0 +1,7 @@
+DEFAULT_LEARNING_RATE = 0.0003
+DEFAULT_WEIGHT_DECAY = 1e-4
+
+DEFAULT_TRAIN_DATASET = "data/fingers/train/*.png"
+DEFAULT_TEST_DATASET = "data/fingers/test/*.png"
+
+DEFAULT_BATCH_SIZE = 8
diff --git a/extract_dataset.sh b/extract_dataset.sh
new file mode 100755
index 0000000..b0f3113
--- /dev/null
+++ b/extract_dataset.sh
@@ -0,0 +1,2 @@
+unzip data/fingers.zip -d data
+unzip data/custom_fingers.zip -d data
diff --git a/images/nn_detection_corr.png b/images/nn_detection_corr.png
new file mode 100644
index 0000000..3185d18
Binary files /dev/null and b/images/nn_detection_corr.png differ
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..df45ae5
--- /dev/null
+++ b/main.py
@@ -0,0 +1,267 @@
+from threading import Thread
+import time
+
+import torch
+from torchvision import transforms
+from torch.utils.data import DataLoader
+import torch.nn.functional as F
+import numpy as np
+import cv2
+import PySimpleGUI as sg
+import click
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+
+
+from model import FINNger
+from dataset import FINNgerDataset
+from calculator import Calculator
+
+
+from default_config import (
+    DEFAULT_BATCH_SIZE,
+    DEFAULT_LEARNING_RATE,
+    DEFAULT_TEST_DATASET,
+    DEFAULT_TRAIN_DATASET,
+    DEFAULT_WEIGHT_DECAY,
+)
+
+
+key_pressed: bool = False
+
+
+def detect_key_press():
+    global key_pressed
+    didnt_press_before = True
+
+    while True:
+        input("Press anything to take a screenshot and use this number in the calculator")
+        key_pressed = True
+        if didnt_press_before:
+            print("Thank you!")
+        didnt_press_before = False
+
+
+def add_text(image, text, position, scale=1):
+    cv2.putText(
+        image,
+        text,
+        position,
+        cv2.FONT_HERSHEY_SIMPLEX,
+        scale,
+        (255, 255, 255),
+        2,
+    )
+
+
+def test(network, test_loader, epoch='', should_save_test_acc=False):
+    network.eval()
+    test_loss = 0
+    correct = 0
+
+    if should_save_test_acc:
+        f = open(f"data/input_{epoch}.txt", "w")
+
+    with torch.no_grad():
+        for data, target in tqdm(test_loader, desc="Test validation"):
+            output = network(data)
+            test_loss += F.nll_loss(output, target, reduction='sum').item()
+            pred = output.data.max(1, keepdim=True)[1]
+            correct += pred.eq(target.data.view_as(pred)).sum()
+
+            if should_save_test_acc:
+                for tgt, prd in zip(target.data.view_as(pred), pred):
+                    f.write(f"{tgt[0]} {prd[0]}\n")
+
+    test_loss /= len(test_loader.dataset)
+    print(
+        f'\nTest set: Avg. loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({(100. * correct / len(test_loader.dataset)):.0f}%)\n')
+
+    # If we opened the file, we close it here
+    if should_save_test_acc:
+        f.close()
+
+    return test_loss
+
+
+def train_epoch(epoch_number, network, train_loader, save_id=''):
+    losses, counter = [], []
+
+    network.train()
+    with tqdm(
+        total=len(train_loader.dataset),
+        desc=f'Train Epoch: 0 | Loss: 0'
+    ) as pbar:
+        for batch_idx, (data, target) in enumerate(train_loader):
+            network.optimizer.zero_grad()
+            output = network(data)
+            loss = F.nll_loss(output, target)
+            loss.backward()
+            network.optimizer.step()
+            pbar.update(len(data))
+
+            pbar.set_description_str(
+                f'Train Epoch: {epoch_number} | Loss: {loss.item():.6f}')
+            losses.append(loss.item())
+            counter.append((batch_idx*DEFAULT_BATCH_SIZE) + ((epoch_number-1)
+                           * len(train_loader.dataset)))
+    network.save(save_id)
+
+    return losses, counter
+
+
+def train(network, n_epochs, train_loader, test_loader, model_id, should_save_test_acc=False):
+    train_losses_counter = []
+    test_losses = []
+    test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]
+
+    test_losses.append(test(network, test_loader, -1, should_save_test_acc))
+    for epoch in tqdm(range(1, n_epochs + 1), desc="Epochs"):
+        train_losses_counter.append(train_epoch(
+            epoch,
+            network,
+            train_loader,
+            save_id=model_id
+        ))
+        test_losses.append(
+            test(network, test_loader, epoch, should_save_test_acc))
+
+    train_losses, train_counter = list(map(
+        lambda l: np.array(l).reshape(-1),
+        zip(*train_losses_counter),
+    ))
+
+    plt.figure()
+    plt.plot(train_counter, train_losses, color='blue')
+    plt.scatter(test_counter, test_losses, color='red')
+    plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
+    plt.xlabel('number of training examples seen')
+    plt.ylabel('negative log likelihood loss')
+    plt.show()
+
+
+@click.command()
+@click.option('--train/--load', 'should_train', default=True, help='Should train or load model', show_default=True)
+@click.option('--save_output', 'should_save_model_output', default=False, help="Should save the output from every validation epoch in a file", show_default=True)
+@click.option("-i", "--model_id", help="The model id it should use to load/save")
+@click.option("-e", "--epochs", default=5, help="How many epochs we will train the model for", show_default=True)
+@click.option("-l", "--learning_rate", default=DEFAULT_LEARNING_RATE, help="The learning rate used to train the model", show_default=True)
+@click.option("-w", "--weight_decay", default=DEFAULT_WEIGHT_DECAY, help="The weight decay rate used to train the model", show_default=True)
+@click.option("--train_dataset", default=DEFAULT_TRAIN_DATASET, help="Glob pattern used to fetch the train dataset", show_default=True)
+@click.option("--test_dataset", default=DEFAULT_TEST_DATASET, help="Glob pattern used to fetch the test dataset", show_default=True)
+def main(
+    model_id,
+    epochs,
+    learning_rate,
+    weight_decay,
+    should_train,
+    train_dataset,
+    test_dataset,
+    should_save_model_output
+):
+    global key_pressed
+
+    if not model_id:
+        model_id = time.strftime("%Y%m%d_%H%M%S")
+    print(f"Model ID is {model_id}")
+    print(f"Learning rate is {learning_rate}")
+
+    model = FINNger(FINNgerDataset.NUM_CLASSES, learning_rate, weight_decay)
+
+    # We guarantee the 128x128 size, even though the images are probably already that size
+    data_transforms = {
+        'train': transforms.Compose([
+            transforms.ToPILImage(),
+            transforms.Resize((128, 128)),
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomCrop((96, 96)),
+            transforms.ToTensor(),
+        ]),
+        'test': transforms.Compose([
+            transforms.ToPILImage(),
+            transforms.Resize((128, 128)),
+            transforms.CenterCrop((96, 96)),
+            transforms.ToTensor(),
+        ]),
+    }
+
+    if should_train:
+        train_dataset = FINNgerDataset(train_dataset, data_transforms['train'])
+        test_dataset = FINNgerDataset(test_dataset, data_transforms['test'])
+
+        train_dataloader = DataLoader(
+            train_dataset,
+            batch_size=DEFAULT_BATCH_SIZE,
+            shuffle=True,
+        )
+        test_dataloader = DataLoader(
+            test_dataset,
+            batch_size=DEFAULT_BATCH_SIZE,
+        )
+
+        print(f"Starting to train for {epochs} epochs...")
+        train(model, epochs, train_dataloader,
+              test_dataloader, model_id, should_save_model_output)
+        print("Finished training!")
+    else:
+        # We can load the model and test it
+        model.load(model_id)
+
+    # Initialize the enter detection thread to be used to display the values in the calculator
+    thread = Thread(target=detect_key_press)
+    thread.start()
+
+    window = sg.Window(
+        'FINNger',
+        [[sg.Image(filename='', key='image')], ],
+        location=(800, 400),
+    )
+
+    calculator = Calculator()
+    cap = cv2.VideoCapture(0)  # Setup the camera as a capture device
+    while True:  # The PSG "Event Loop"
+        # get events for the window with 20ms max wait
+        event, _values = window.Read(timeout=20, timeout_key='timeout')
+        if event is None:  # if user closed window, quit
+            break
+
+        image = cap.read()[1]
+
+        model_image = np.stack(
+            (cv2.cvtColor(image, cv2.COLOR_RGB2GRAY),)*3,
+            axis=-1
+        )
+        model_image = data_transforms['test'](model_image)
+        model_image = model_image.unsqueeze(0)
+
+        model.eval()
+        with torch.no_grad():
+            output = model(model_image)
+            pred = output.data.max(1, keepdim=True)[1]
+
+        # This value is set by a background thread
+        if key_pressed:
+            calculator.add_number(pred[0][0].item())
+            key_pressed = False
+
+        # Update image in window
+        image = np.array(model_image[0]).transpose((1, 2, 0)) * 255
+        image = cv2.resize(image, (640, 480))
+
+        add_text(image, f'Detecting: {pred[0][0]}', (10, 25))
+        add_text(image, str(calculator), (10, 60))
+        add_text(
+            image,
+            "[{:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}]".format(
+                *output.data.tolist()[0]),
+            (10, 460),
+            scale=0.4,
+        )
+
+        window_image = window.FindElement('image')
+        encoded_image = cv2.imencode('.png', image)[1]
+        window_image.Update(data=encoded_image.tobytes())
+
+
+if __name__ == "__main__":
+    main()
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..8be237c
--- /dev/null
+++ b/model.py
@@ -0,0 +1,119 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+
+from functools import reduce
+from operator import __add__
+
+KERNEL_SIZE = (4, 4)
+
+
+# This is used to implement `same` padding like we'd have in TensorFlow
+# For some reason, padding dimensions are reversed with respect to kernel sizes:
+# width comes first, then height, in the 2D case.
+#
+# Based on [this](https://stackoverflow.com/a/63149259/9347193) StackOverflow answer
+conv_padding = reduce(__add__,
+                      [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2) for k in KERNEL_SIZE[::-1]])
+
+
+class FINNger(nn.Module):
+    def __init__(self, num_classes, learning_rate, weight_decay):
+        super(FINNger, self).__init__()
+
+        self.pad1_1 = nn.ZeroPad2d(conv_padding)
+        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=KERNEL_SIZE)
+        self.batchnorm1_1 = nn.BatchNorm2d(64)
+        self.pad1_2 = nn.ZeroPad2d(conv_padding)
+        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=KERNEL_SIZE)
+        self.batchnorm1_2 = nn.BatchNorm2d(64)
+        self.maxpooling1 = nn.MaxPool2d(2)
+        self.dropout1 = nn.Dropout2d(0.2)
+
+        self.pad2_1 = nn.ZeroPad2d(conv_padding)
+        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=KERNEL_SIZE)
+        self.batchnorm2_1 = nn.BatchNorm2d(128)
+        self.pad2_2 = nn.ZeroPad2d(conv_padding)
+        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=KERNEL_SIZE)
+        self.batchnorm2_2 = nn.BatchNorm2d(128)
+        self.maxpooling2 = nn.MaxPool2d(2)
+        self.dropout2 = nn.Dropout2d(0.3)
+
+        self.pad3_1 = nn.ZeroPad2d(conv_padding)
+        self.conv3_1 = nn.Conv2d(128, 128, kernel_size=KERNEL_SIZE)
+        self.batchnorm3_1 = nn.BatchNorm2d(128)
+        self.pad3_2 = nn.ZeroPad2d(conv_padding)
+        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=KERNEL_SIZE)
+        self.batchnorm3_2 = nn.BatchNorm2d(128)
+        self.maxpooling3 = nn.MaxPool2d(2)
+        self.dropout3 = nn.Dropout2d(0.4)
+
+        self.flatten = nn.Flatten()
+
+        # The image starts at 96x96, and we have 3 maxpools with kernel size 2
+        image_size_dense = 96 // 2 // 2 // 2
+
+        # The image is square, and the convolutions output 128 channels
+        # The dense layer outputs 128 features, which the output layer maps to num_classes
+        dense_out = 128
+        self.dense = nn.Linear(
+            image_size_dense * image_size_dense * 128, dense_out)
+        self.out = nn.Linear(dense_out, num_classes)
+
+        self.optimizer = optim.Adam(
+            self.parameters(),
+            lr=learning_rate,
+            weight_decay=weight_decay,
+        )
+
+    def forward(self, x):
+        # Sequence of convolutions with ReLU activations
+        # Shape starts with (BATCH, 3, 96, 96)
+        # out: BATCH, 64, 96, 96
+        x = self.batchnorm1_1(F.relu(self.conv1_1(self.pad1_1(x))))
+        # out: BATCH, 64, 96, 96
+        x = self.batchnorm1_2(F.relu(self.conv1_2(self.pad1_2(x))))
+        x = self.maxpooling1(x)  # out: BATCH, 64, 48, 48
+        x = self.dropout1(x)  # out: BATCH, 64, 48, 48
+
+        # out: BATCH, 128, 48, 48
+        x = self.batchnorm2_1(F.relu(self.conv2_1(self.pad2_1(x))))
+        # out: BATCH, 128, 48, 48
+        x = self.batchnorm2_2(F.relu(self.conv2_2(self.pad2_2(x))))
+        x = self.maxpooling2(x)  # out: BATCH, 128, 24, 24
+        x = self.dropout2(x)  # out: BATCH, 128, 24, 24
+
+        # out: BATCH, 128, 24, 24
+        x = self.batchnorm3_1(F.relu(self.conv3_1(self.pad3_1(x))))
+        # out: BATCH, 128, 24, 24
+        x = self.batchnorm3_2(F.relu(self.conv3_2(self.pad3_2(x))))
+        x = self.maxpooling3(x)  # out: BATCH, 128, 12, 12
+        x = self.dropout3(x)  # out: BATCH, 128, 12, 12
+
+        x = self.flatten(x)  # out: BATCH, 18432
+        x = F.relu(self.dense(x))  # out: BATCH, 128
+        x = self.out(x)  # out: BATCH, NUM_CLASSES
+
+        return F.log_softmax(x, dim=1)
+
+    def save(self, model_id: str = ''):
+        torch.save(self.state_dict(), f'model/model{model_id}.pth')
+        torch.save(self.optimizer.state_dict(),
+                   f'model/optimizer{model_id}.pth')
+
+    def load(self, model_id: str = ''):
+        try:
+            print("Loading from ", f'model/model{model_id}.pth')
+            self.load_state_dict(torch.load(f'model/model{model_id}.pth'))
+            self.optimizer.load_state_dict(
+                torch.load(f'model/optimizer{model_id}.pth'))
+            self.eval()
+            print("Model loaded successfully")
+        except FileNotFoundError as error:
+            error.strerror = f"There is no model located at model/model{model_id}.pth"
+            raise error from None
+
+    @property
+    def parameter_count(self):
+        return sum(p.numel() for p in self.parameters() if p.requires_grad)
diff --git a/model/.gitkeep b/model/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/plot.py b/plot.py
new file mode 100644
index 0000000..3f14749
--- /dev/null
+++ b/plot.py
@@ -0,0 +1,46 @@
+import seaborn as sns
+import matplotlib.pyplot as plt
+import numpy as np
+from collections import defaultdict, Counter
+from tqdm import tqdm
+
+# Accuracy obtained when running 100 epochs
+ACCURACY = [
+    17, 35, 43, 52, 54, 64, 65, 70, 78, 67, 70, 76, 76, 80, 78, 83, 81, 81, 86, 83, 81, 86, 84, 83, 84,
+    89, 87, 90, 89, 89, 89, 87, 90, 91, 90, 91, 89, 87, 94, 90, 94, 93, 92, 91, 92, 92, 91, 91, 92, 93,
+    92, 93, 94, 95, 92, 94, 91, 93, 94, 90, 92, 95, 95, 97, 95, 93, 94, 95, 93, 94, 94, 97, 96, 95, 96,
+    96, 96, 96, 96, 98, 96, 96, 96, 95, 96, 94, 96, 96, 95, 95, 95, 96, 94, 96, 96, 97, 96, 96, 97, 96, 95
+]
+polynomial_parameters = np.polyfit(range(len(ACCURACY)), ACCURACY, 10)
+poly_fit = np.poly1d(polynomial_parameters)
+
+sns.set_palette("husl")
+sns.lineplot(x=range(len(ACCURACY)), y=[
+             poly_fit(x) for x in range(len(ACCURACY))])
+sns.lineplot(x=range(len(ACCURACY)), y=ACCURACY, lw=0.8)
+plt.title("Accuracy over epochs")
+plt.show()
+
+# Correlation plots
+data = defaultdict(list)
+with open("data/input.txt") as f:
+    for line in tqdm(f.readlines()):
+        expected, result = line.split(" ")
+        data[int(expected)].append(int(result))
+
+corr = [[0] * 6 for _ in range(6)]
+for opt in data:
+    outputs = data[opt]
+
+    total = len(outputs)
+    counter = Counter(outputs)
+    for i in range(6):
+        corr[opt][i] = counter[i] / total
+    print(counter, opt)
+
+f, ax = plt.subplots(figsize=(10, 8))
+sns.set_theme()
+sns.heatmap(corr, mask=np.zeros_like(corr, dtype=bool), cmap=sns.diverging_palette(220, 10, as_cmap=True),
+            square=True, ax=ax, annot=True)
+plt.title("Correlation for NN")
+plt.show()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..0842baf
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+click==8.0.0
+kaggle==1.5.12
+matplotlib==3.4.0
+numpy==1.19.5
+opencv-python==4.5.1.48
+PySimpleGUI==4.41.2
+seaborn==0.11.1
+torch==1.8.1+cu111
+torch-tb-profiler==0.1.0
+torchvision==0.9.1+cu111
+tqdm==4.60.0
\ No newline at end of file