forked from aolabsai/MNIST_streamlit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_prep.py
132 lines (102 loc) · 4 KB
/
data_prep.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import numpy as np
import pandas as pd
import sys
import pickle
from os import listdir
from os.path import isfile, join
def unpickle(file_path="./mnist.pkl.gz"):
    """Load and return the object stored in a gzip-compressed pickle.

    The MNIST archive is obtained following the instructions at
    https://stackoverflow.com/a/40693405/4147579.

    Args:
        file_path: Path to the gzip-compressed pickle. Defaults to
            ``./mnist.pkl.gz`` relative to the current working directory.

    Returns:
        Whatever object the pickle contains. ``process_data`` expects a
        ``((train_x, train_y), (test_x, test_y))`` pair.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
    """
    import gzip

    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only ever point this at a trusted archive.
    # The ``with`` block closes the file even if unpickling fails.
    with gzip.open(file_path, "rb") as f:
        # encoding="bytes" lets Python 3 read pickles written by Python 2
        # (e.g. numpy arrays); presumably how this archive was produced.
        return pickle.load(f, encoding="bytes")
def process_labels(labels):
    """Encode integer digit labels 0-9 as 4-bit binary rows, most significant bit first.

    The weightless neural state machine consumes binary outputs, so e.g.
    label 5 becomes ``[0, 1, 0, 1]``.

    Args:
        labels: array-like of integer labels in ``0..9``. It is flattened,
            so a column vector is accepted too.

    Returns:
        float64 array of shape ``(labels.size, 4)``.

    Raises:
        IndexError: if a label is outside the 10-row lookup table.
    """
    # Lookup table: row i holds the 4 bits of i, MSB first (the same
    # ordering np.binary_repr(i, 4) produces).
    label_to_binary = (np.arange(10)[:, np.newaxis] >> np.arange(3, -1, -1)) & 1
    labels = np.asarray(labels).reshape(-1)
    if labels.size == 0:
        # Fancy-indexing with an empty float array would raise; return the
        # empty (0, 4) result explicitly.
        return np.zeros([0, 4])
    # One vectorized gather instead of a Python-level loop over every label.
    return label_to_binary[labels].astype(np.float64)
def process_data():
    """Load MNIST via ``unpickle`` and binary-encode both label sets.

    Returns:
        ``((train_images, train_codes), (test_images, test_codes))``.
        The images are returned exactly as unpickled. Each codes array is
        ``process_labels`` applied to the matching labels (training first,
        then test).
    """
    train_split, test_split = unpickle()
    train_images, train_labels = train_split
    test_images, test_labels = test_split
    train_codes = process_labels(train_labels)
    test_codes = process_labels(test_labels)
    return (train_images, train_codes), (test_images, test_codes)
def random_sample(num_samples, samples, labels):
    """Pick ``num_samples`` distinct rows from ``samples`` and their ``labels``.

    The row indices come from the global NumPy RNG (``np.random.choice``
    without replacement). The same indices select both arrays, so each
    sample stays paired with its label.
    """
    n_rows = samples.shape[0]
    chosen = np.random.choice(n_rows, num_samples, replace=False)
    picked_samples = samples[chosen]
    picked_labels = labels[chosen]
    return picked_samples, picked_labels
def down_sample_item(x, down=200):
    """Binarize ``x``: 1 where an element is ``>= down``, otherwise 0.

    Args:
        x: scalar or array-like of intensities.
        down: threshold; elements at or above it map to 1.

    Returns:
        Integer ndarray with the same shape as ``x`` (0-d for a scalar).
        Empty input returns an empty array instead of raising.
    """
    # Vectorized comparison replaces np.vectorize, which is a Python-level
    # loop and raises on size-0 input.
    return (np.asarray(x) >= down).astype(int)
def down_sample(image, *, verbose=False):
    """Expand every pixel of ``image`` into its 8 bits, most significant bit first.

    Args:
        image: numeric array of pixel intensities, e.g. a single 2-D image
            or a stack of them. Values are cast to uint8, so anything outside
            0-255 does not survive the cast intact. TODO confirm callers
            always pass 0-255 data.
        verbose: when True, print the input and output shapes for debugging.

    Returns:
        uint8 array of shape ``image.shape + (8,)``.
    """
    image_uint8 = np.asarray(image).astype(np.uint8)
    if verbose:
        print("Input image shape:", image_uint8.shape)
    # Add a trailing length-1 axis and unpack it into 8 bits. Working on the
    # last axis treats every input rank the same way.
    bits_image = np.unpackbits(image_uint8[..., np.newaxis], axis=-1)
    if verbose:
        print("Bit-unpacked image shape:", bits_image.shape)
    return bits_image
def get_font_data(filename):
    """Load one font spreadsheet as ten 28x28 digit images plus binary labels.

    Layout implied by the slicing below: ``pd.read_excel`` consumes the first
    sheet row as a header. The digits then sit in two bands of data rows,
    0-27 and 30-57. Each band holds five 28-column digits, each followed by
    one spacer column, so only the first 145 columns are used.

    Args:
        filename: path to a workbook readable by ``pd.read_excel``.

    Returns:
        ``(images, labels)``:
        - ``images`` is uint8 ``(10, 28, 28)``, holding cell values times 255
          with blank cells as 0.
        - ``labels`` is int8 ``(10, 4)``; row i is the 4-bit binary of i,
          MSB first.
    """
    df = pd.read_excel(filename)
    # cut off dataframe at end of valid columns
    arr = df.iloc[:, :145].to_numpy()
    # spreadsheet has two layers of digits; this moves the second band to
    # the right of the first so all ten digits share one row range
    arr = np.append(arr[:28], arr[30:58], axis=1)
    # drop the spacer column that follows each 28-wide digit
    arr = np.delete(arr, [28 + 29 * i for i in range(10)], axis=1)
    # Blank cells arrive as NaN. Casting NaN to uint8 is undefined and warns,
    # so treat blanks as empty (0) pixels before scaling.
    pixels = np.nan_to_num(arr.astype(np.float64), nan=0.0)
    # Columns now hold digits 0..9 side by side, 28 columns each:
    # reshape to (row, digit, col), then move the digit axis first.
    digits = pixels.reshape(28, 10, 28).transpose(1, 0, 2)
    # Multiplying by 255 so it downsamples correctly
    ret_arr = (255 * digits).astype(np.uint8)
    # Convert labels: row i holds the 4 bits of i, most significant bit first
    label_to_binary = (
        (np.arange(10)[:, np.newaxis] >> np.arange(3, -1, -1)) & 1
    ).astype(np.int8)
    return (ret_arr, label_to_binary)
def get_all_fonts(folder="./Fonts/"):
    """Load every ``.xlsx`` font sheet in ``folder``.

    Args:
        folder: directory to scan (not recursively). Defaults to
            ``./Fonts/`` relative to the current working directory.

    Returns:
        dict mapping font name (file name without ``.xlsx``) to the
        ``(images, labels)`` tuple from ``get_font_data``. Entries are
        inserted in sorted file-name order, so iteration order is
        deterministic; ``listdir`` order is arbitrary.
    """
    fonts = {}
    for item in sorted(listdir(folder)):
        file = join(folder, item)
        # Skip sub-directories, non-workbook files (e.g. .DS_Store) and
        # Excel "~$" lock files, any of which would make get_font_data fail.
        if not isfile(file) or not item.endswith(".xlsx") or item.startswith("~$"):
            continue
        fonts[item.removesuffix(".xlsx")] = get_font_data(file)
    return fonts
def select_training_fonts(fonts):
    """Stack the images and label codes of the named fonts for training.

    Args:
        fonts: iterable of font names. Only the names are used; the data is
            looked up in the module-level ``FONTS`` dict, so values are
            ignored if a dict is passed. The ``"MNIST"`` entry is skipped.

    Returns:
        ``(inputs, outputs)`` as numpy arrays. Rows are concatenated in the
        order the names are given. With nothing selected both arrays are
        empty, with shape ``(0,)``.

    Raises:
        KeyError: if a name is not a key of ``FONTS``.
    """
    chosen = [name for name in fonts if name != "MNIST"]
    images = [img for name in chosen for img in FONTS[name][0]]
    codes = [code for name in chosen for code in FONTS[name][1]]
    return np.array(images), np.array(codes)
# Module-level data loading: both statements run once, at import time.
# process_data() reads ./mnist.pkl.gz (through unpickle) and get_all_fonts()
# reads the font spreadsheets under ./Fonts/. Both paths are relative, so
# importing this module requires them to exist relative to the working directory.
(MN_TRAIN, MN_TRAIN_Z), (MN_TEST, MN_TEST_Z) = process_data()
# Name -> (images, labels) table; select_training_fonts() looks fonts up here.
FONTS = get_all_fonts()