{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "view-in-github",
    "colab_type": "text"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/andremicci/Generative_Deep_Learning_2nd_Edition/blob/main/andremicci/Generative_Deep_Learning_2nd_Edition/notebooks/02_deeplearning/02_cnn\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "!git clone https://github.com/andremicci/Generative_Deep_Learning_2nd_Edition.git"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "KOnmreJbjLIz",
    "outputId": "a3a49f71-3b29-41cb-fc11-a031442b32a9"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Cloning into 'Generative_Deep_Learning_2nd_Edition'...\n",
      "remote: Enumerating objects: 653, done.\u001b[K\n",
      "remote: Counting objects: 100% (195/195), done.\u001b[K\n",
      "remote: Compressing objects: 100% (70/70), done.\u001b[K\n",
      "remote: Total 653 (delta 142), reused 125 (delta 125), pack-reused 458 (from 1)\u001b[K\n",
      "Receiving objects: 100% (653/653), 37.09 MiB | 8.67 MiB/s, done.\n",
      "Resolving deltas: 100% (377/377), done.\n"
     ]
    }
   ]
  },
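  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cloned repository is not on Python's import path by default, so the `from notebooks.utils import display` import below would fail with a `ModuleNotFoundError`. A minimal fix, assuming the clone above landed in the current working directory and that the repo keeps the book's `notebooks/utils.py` layout, is to change into the repo root and add it to `sys.path`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Assumption: the clone above created ./Generative_Deep_Learning_2nd_Edition\n",
    "# under the current working directory (/content on Colab).\n",
    "REPO_DIR = \"Generative_Deep_Learning_2nd_Edition\"\n",
    "if os.path.isdir(REPO_DIR):\n",
    "    os.chdir(REPO_DIR)\n",
    "sys.path.append(os.getcwd())  # makes the `notebooks` package importable"
   ]
  },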
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TFy-fuvQjJJp"
   },
   "source": [
    "# 🏞 Convolutional Neural Network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wEIaaUnijJJr"
   },
   "source": [
    "In this notebook, we'll walk through the steps required to train your own convolutional neural network (CNN) on the CIFAR-10 dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 350
    },
    "id": "Ewl7DskxjJJs",
    "outputId": "72f7b930-03ff-4bab-97c7-8f337c0da519"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from tensorflow.keras import layers, models, optimizers, utils, datasets\n",
    "from notebooks.utils import display"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": [],
    "id": "5xHzrj8jjJJu"
   },
   "source": [
    "## 0. Parameters <a name=\"parameters\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4dTmq7jjjJJv"
   },
   "outputs": [],
   "source": [
    "NUM_CLASSES = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MTLZ1IG6jJJw"
   },
   "source": [
    "## 1. Prepare the Data <a name=\"prepare\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DR22w6jsjJJx"
   },
   "outputs": [],
   "source": [
    "(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gkQBN7pojJJx"
   },
   "outputs": [],
   "source": [
    "x_train = x_train.astype(\"float32\") / 255.0\n",
    "x_test = x_test.astype(\"float32\") / 255.0\n",
    "\n",
    "y_train = utils.to_categorical(y_train, NUM_CLASSES)\n",
    "y_test = utils.to_categorical(y_test, NUM_CLASSES)"
   ]
  },
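  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, after preprocessing the arrays should have the shapes below, with pixel values scaled into `[0, 1]` and labels one-hot encoded to length-10 vectors:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected: (50000, 32, 32, 3) (10000, 32, 32, 3) and (50000, 10) (10000, 10)\n",
    "print(x_train.shape, x_test.shape)\n",
    "print(y_train.shape, y_test.shape)\n",
    "print(x_train.min(), x_train.max())  # pixel range after scaling: 0.0 1.0"
   ]
  },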
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "p6D1z7qPjJJz"
   },
   "outputs": [],
   "source": [
    "display(x_train[:10])\n",
    "print(y_train[:10])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uqly1nGYjJJ0"
   },
   "source": [
    "## 2. Build the model <a name=\"build\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rvLwhVf7jJJ1"
   },
   "outputs": [],
   "source": [
    "input_layer = layers.Input((32, 32, 3))\n",
    "\n",
    "x = layers.Conv2D(filters=32, kernel_size=3, strides=1, padding=\"same\")(\n",
    "    input_layer\n",
    ")\n",
    "x = layers.BatchNormalization()(x)\n",
    "x = layers.LeakyReLU()(x)\n",
    "\n",
    "x = layers.Conv2D(filters=32, kernel_size=3, strides=2, padding=\"same\")(x)\n",
    "x = layers.BatchNormalization()(x)\n",
    "x = layers.LeakyReLU()(x)\n",
    "\n",
    "x = layers.Conv2D(filters=64, kernel_size=3, strides=1, padding=\"same\")(x)\n",
    "x = layers.BatchNormalization()(x)\n",
    "x = layers.LeakyReLU()(x)\n",
    "\n",
    "x = layers.Conv2D(filters=64, kernel_size=3, strides=2, padding=\"same\")(x)\n",
    "x = layers.BatchNormalization()(x)\n",
    "x = layers.LeakyReLU()(x)\n",
    "\n",
    "x = layers.Flatten()(x)\n",
    "\n",
    "x = layers.Dense(128)(x)\n",
    "x = layers.BatchNormalization()(x)\n",
    "x = layers.LeakyReLU()(x)\n",
    "x = layers.Dropout(rate=0.5)(x)\n",
    "\n",
    "x = layers.Dense(NUM_CLASSES)(x)\n",
    "output_layer = layers.Activation(\"softmax\")(x)\n",
    "\n",
    "model = models.Model(input_layer, output_layer)\n",
    "\n",
    "model.summary()"
   ]
  },
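  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the shapes: with `padding=\"same\"`, each `strides=2` convolution halves the spatial dimensions, so the feature maps shrink 32×32 → 16×16 → 8×8. `Flatten` therefore feeds 8 × 8 × 64 = 4,096 values into the 128-unit `Dense` layer, which is where most of the model's weights live (4,096 × 128 + 128 = 524,416 parameters), as `model.summary()` confirms."
   ]
  },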
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": [],
    "id": "_0JEpfyWjJJ3"
   },
   "source": [
    "## 3. Train the model <a name=\"train\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RQWD-6i6jJJ4"
   },
   "outputs": [],
   "source": [
    "opt = optimizers.Adam(learning_rate=0.0005)\n",
    "model.compile(\n",
    "    loss=\"categorical_crossentropy\", optimizer=opt, metrics=[\"accuracy\"]\n",
    ")"
   ]
  },
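  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the labels were one-hot encoded with `utils.to_categorical` above, `categorical_crossentropy` is the matching loss; with raw integer labels you would use `sparse_categorical_crossentropy` instead."
   ]
  },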
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [],
    "id": "pajMd0T0jJJ5"
   },
   "outputs": [],
   "source": [
    "history = model.fit(\n",
    "    x_train,\n",
    "    y_train,\n",
    "    batch_size=32,\n",
    "    epochs=10,\n",
    "    shuffle=True,\n",
    "    validation_data=(x_test, y_test),\n",
    ")"
   ]
  },
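  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`model.fit` returns a `History` object, captured as `history` above, so we can plot the per-epoch accuracy that Keras records and check for over- or under-fitting:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# `history.history` holds one value per epoch for each compiled metric.\n",
    "plt.plot(history.history[\"accuracy\"], label=\"train accuracy\")\n",
    "plt.plot(history.history[\"val_accuracy\"], label=\"validation accuracy\")\n",
    "plt.xlabel(\"epoch\")\n",
    "plt.ylabel(\"accuracy\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },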
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": [],
    "id": "M2Jl9j7ljJJ5"
   },
   "source": [
    "## 4. Evaluation <a name=\"evaluate\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "t7GXyVvYjJJ6"
   },
   "outputs": [],
   "source": [
    "model.evaluate(x_test, y_test, batch_size=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ttvkm2g_jJJ6"
   },
   "outputs": [],
   "source": [
    "CLASSES = np.array(\n",
    "    [\n",
    "        \"airplane\",\n",
    "        \"automobile\",\n",
    "        \"bird\",\n",
    "        \"cat\",\n",
    "        \"deer\",\n",
    "        \"dog\",\n",
    "        \"frog\",\n",
    "        \"horse\",\n",
    "        \"ship\",\n",
    "        \"truck\",\n",
    "    ]\n",
    ")\n",
    "\n",
    "preds = model.predict(x_test)\n",
    "preds_single = CLASSES[np.argmax(preds, axis=-1)]\n",
    "actual_single = CLASSES[np.argmax(y_test, axis=-1)]"
   ]
  },
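  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick aggregate check, the fraction of predicted labels that match the actual labels should agree with the test accuracy reported by `model.evaluate` above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of test images whose predicted class matches the true class.\n",
    "print(f\"Test accuracy: {np.mean(preds_single == actual_single):.4f}\")"
   ]
  },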
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5SkFQ0pHjJJ7"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "n_to_show = 10\n",
    "indices = np.random.choice(range(len(x_test)), n_to_show)\n",
    "\n",
    "fig = plt.figure(figsize=(15, 3))\n",
    "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
    "\n",
    "for i, idx in enumerate(indices):\n",
    "    img = x_test[idx]\n",
    "    ax = fig.add_subplot(1, n_to_show, i + 1)\n",
    "    ax.axis(\"off\")\n",
    "    ax.text(\n",
    "        0.5,\n",
    "        -0.35,\n",
    "        \"pred = \" + str(preds_single[idx]),\n",
    "        fontsize=10,\n",
    "        ha=\"center\",\n",
    "        transform=ax.transAxes,\n",
    "    )\n",
    "    ax.text(\n",
    "        0.5,\n",
    "        -0.7,\n",
    "        \"act = \" + str(actual_single[idx]),\n",
    "        fontsize=10,\n",
    "        ha=\"center\",\n",
    "        transform=ax.transAxes,\n",
    "    )\n",
    "    ax.imshow(img)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  },
  "colab": {
   "provenance": [],
   "include_colab_link": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}