From a14b3c5553799d87bca0c9e15430c243447b500c Mon Sep 17 00:00:00 2001
From: Brian Nord <184985+bnord@users.noreply.github.com>
Date: Mon, 18 Nov 2024 11:59:50 -0600
Subject: [PATCH 1/5] upload notebook tutorial
---
...geClassificationWithTensorflow_Draft.ipynb | 1870 +++++++++++++++++
1 file changed, 1870 insertions(+)
create mode 100644 AI0_Intro_AI_ImageClassificationWithTensorflow_Draft.ipynb
diff --git a/AI0_Intro_AI_ImageClassificationWithTensorflow_Draft.ipynb b/AI0_Intro_AI_ImageClassificationWithTensorflow_Draft.ipynb
new file mode 100644
index 0000000..9939bcd
--- /dev/null
+++ b/AI0_Intro_AI_ImageClassificationWithTensorflow_Draft.ipynb
@@ -0,0 +1,1870 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ " \n",
+ " AI0: Introduction to AI-based Image Classification with TensorFlow \n",
+ "Contact author: Brian Nord \n",
+ "Last verified to run: YYYY-MM-DD \n",
+ "LSST Science Pipelines version: ?? \n",
+ "Container size: medium \n",
+ "Targeted learning level: beginner "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jce50kKEfHC1"
+ },
+ "source": [
+ "**Description:** An introduction to the classification of images with AI-based classification algorithms."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jce50kKEfHC1"
+ },
+ "source": [
+ "**Skills:** Examine AI training data, prepare it for a classification task, perform classification with a neural network, and examine the diagnostics of the classification task."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**LSST Data Products:** None; MNIST data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Packages:** numpy, matplotlib, sklearn, tensorflow"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Credits and Acknowledgments:** None"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Get Support:**\n",
+ "Find DP0-related documentation and resources at dp0.lsst.io. Questions are welcome as new topics in the Support - Data Preview 0 Category of the Rubin Community Forum. Rubin staff will respond to all questions posted there."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Introduction\n",
+ "\n",
+ "This Jupyter Notebook introduces artificial intelligence (AI)-based image classification. It demonstrates how to perform a few key steps:\n",
+ "1. examine and prepare data for classification;\n",
+ "2. train an AI algorithm;\n",
+ "3. plot diagnostics of the training performance;\n",
+ "4. initially assess those diagnostics. \n",
+ "\n",
+ "AI is a class of algorithms for building statistical models. These algorithms primarily use data for training, as opposed to models that use analytic formulae or models that are based on physical reasoning. Machine learning is a subclass of AI algorithms -- e.g., random forests. Deep learning is a subclass of machine learning algorithms -- e.g., neural networks. \n",
+ "\n",
+ "This notebook uses `tensorflow`, one of the two most commonly used Python libraries for deep learning. `tensorflow` is often easier to use because of how it handles data sets and the logic it uses for model building. However, it is typically also more difficult to develop custom network models with it. We use `tensorflow` first in this series of tutorials so that users who are new to deep learning can focus on learning AI. In later tutorials, we will use `pytorch` because it is more flexible and more commonly used in science applications. \n",
+ "\n",
+ "This notebook uses [MNIST AI benchmarking data](https://en.wikipedia.org/wiki/MNIST_database). In a future notebook, we will use stars and galaxies drawn from DP0 data.\n",
+ "\n",
+ "The use of data in this notebook requires a medium-sized RAM allocation (8Gi).\n",
+ "\n",
+ "The end of this notebook contains a Glossary of Terms and a comment regarding usage of terms in AI contexts."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%reload_ext pycodestyle_magic\n",
+ "%flake8_on\n",
+ "import logging\n",
+ "logging.getLogger(\"flake8\").setLevel(logging.FATAL)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "V3xHhKu6c5-e"
+ },
+ "source": [
+ "### 1.1. Import Packages\n",
+ "\n",
+ "[`numpy`](https://numpy.org/) is a widely used Python library for computations and mathematical operations on multi-dimensional arrays.\n",
+ "\n",
+ "[`matplotlib`](https://matplotlib.org/) is a widely used Python plotting library. \n",
+ "\n",
+ "[`tensorflow`](https://www.tensorflow.org) is a widely used library from Google for fast tensor operations --- often used for building neural network models. \n",
+ "\n",
+ "[`sklearn`](https://scikit-learn.org/stable/) is a library for machine learning."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "puW54XTfdo1C"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import os\n",
+ "import datetime\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "from matplotlib.pyplot import cm\n",
+ "\n",
+ "import tensorflow as tf\n",
+ "\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix\n",
+ "from sklearn.metrics import roc_curve, roc_auc_score, auc, RocCurveDisplay\n",
+ "from sklearn.preprocessing import LabelBinarizer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.2. Define Functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def normalizeInputs(x_temp, input_minimum, input_maximum):\n",
+ "    \"\"\"Normalize a datum that is an input to the neural network\n",
+ "\n",
+ "    Parameters\n",
+ "    ----------\n",
+ "    x_temp: `numpy.ndarray`\n",
+ "        image data\n",
+ "    input_minimum: `float`\n",
+ "        minimum value for normalization\n",
+ "    input_maximum: `float`\n",
+ "        maximum value for normalization\n",
+ "\n",
+ "    Returns\n",
+ "    -------\n",
+ "    x_temp_norm: `numpy.ndarray`\n",
+ "        image data rescaled to the range [0, 1]\n",
+ "    \"\"\"\n",
+ "    x_temp_norm = (x_temp - input_minimum)/(input_maximum - input_minimum)\n",
+ "    return x_temp_norm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def createFileUidTimestamp():\n",
+ "    \"\"\"Create a timestamp for a filename.\n",
+ "\n",
+ "    Parameters\n",
+ "    ----------\n",
+ "    None\n",
+ "\n",
+ "    Returns\n",
+ "    -------\n",
+ "    file_uid_timestamp : `string`\n",
+ "        String from date and time.\n",
+ "    \"\"\"\n",
+ "    file_uid_timestamp = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
+ "    return file_uid_timestamp\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def createFileName(file_prefix=\"\", file_location=\"Data/Sandbox/\",\n",
+ "                   file_suffix=\"\", useuid=True, verbose=True):\n",
+ "    \"\"\"Create a file name.\n",
+ "\n",
+ "    Parameters\n",
+ "    ----------\n",
+ "    file_prefix: `string`\n",
+ "        prefix of file name\n",
+ "    file_location: `string`\n",
+ "        path to file\n",
+ "    file_suffix: `string`\n",
+ "        suffix/extension of file name\n",
+ "    useuid: `bool`\n",
+ "        choose to use a unique id\n",
+ "    verbose: `bool`\n",
+ "        choose to print the file name\n",
+ "\n",
+ "    Returns\n",
+ "    -------\n",
+ "    file_final: `string`\n",
+ "        filename used for saving\n",
+ "    \"\"\"\n",
+ "    if useuid:\n",
+ "        file_uid = createFileUidTimestamp()\n",
+ "    else:\n",
+ "        file_uid = \"\"\n",
+ "\n",
+ "    file_final = file_location + file_prefix + \"_\" + file_uid + file_suffix\n",
+ "\n",
+ "    if verbose:\n",
+ "        print(file_final)\n",
+ "\n",
+ "    return file_final\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotArrayImageExamples(x_tra, y_tra, num=10,\n",
+ "                           save_file=False,\n",
+ "                           file_prefix=\"prediction_histogram\",\n",
+ "                           file_location=\"./\",\n",
+ "                           file_suffix=\".png\"):\n",
+ "    \"\"\"Plot an array of examples of images and labels\n",
+ "\n",
+ "    Parameters\n",
+ "    ----------\n",
+ "    x_tra: `numpy.ndarray`\n",
+ "        training data images\n",
+ "    y_tra: `numpy.ndarray`\n",
+ "        training data labels\n",
+ "    num: `int`, optional\n",
+ "        number of examples to plot (at most 10, the size of the 2x5 grid)\n",
+ "    file_prefix: `string`, optional\n",
+ "        prefix of file name\n",
+ "    file_location: `string`, optional\n",
+ "        path to file\n",
+ "    file_suffix: `string`, optional\n",
+ "        suffix/extension of file name\n",
+ "\n",
+ "    Returns\n",
+ "    -------\n",
+ "    None\n",
+ "    \"\"\"\n",
+ "    num_row = 2\n",
+ "    num_col = 5\n",
+ "    images = x_tra[:num]\n",
+ "    labels = y_tra[:num]\n",
+ "\n",
+ "    fig, axes = plt.subplots(num_row, num_col,\n",
+ "                             figsize=(1.5*num_col, 2*num_row))\n",
+ "    for i in range(num):\n",
+ "        ax = axes[i//num_col, i%num_col]\n",
+ "        ax.imshow(images[i], cmap='gray')\n",
+ "        ax.set_title('Label: {}'.format(labels[i]))\n",
+ "\n",
+ "    plt.tight_layout()\n",
+ "\n",
+ "    if save_file:\n",
+ "        file_final = createFileName(file_prefix=file_prefix,\n",
+ "                                    file_location=file_location,\n",
+ "                                    file_suffix=file_suffix,\n",
+ "                                    useuid=True)\n",
+ "        plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ "    plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotROCMulticlassOnevsrest(y_tra, y_tes, y_pred_tes, label_target_list,\n",
+ "                               color_list,\n",
+ "                               save_file=False,\n",
+ "                               file_prefix=\"prediction_histogram\",\n",
+ "                               file_location=\"./\",\n",
+ "                               file_suffix=\".png\"):\n",
+ "    \"\"\"Plot Receiver Operating Characteristic curves, one-vs-rest\n",
+ "\n",
+ "    Parameters\n",
+ "    ----------\n",
+ "    y_tra: `numpy.ndarray`\n",
+ "        training data true labels\n",
+ "    y_tes: `numpy.ndarray`\n",
+ "        test data true labels\n",
+ "    y_pred_tes: `numpy.ndarray`\n",
+ "        test data predicted class probabilities\n",
+ "    label_target_list: `list`\n",
+ "    color_list: `list`\n",
+ "\n",
+ "    Returns\n",
+ "    -------\n",
+ "    None\n",
+ "    \"\"\"\n",
+ "    fig, ax = plt.subplots(figsize=(6, 6))\n",
+ "\n",
+ "    for label_target, color in zip(label_target_list, color_list):\n",
+ "\n",
+ "        label_binarizer = LabelBinarizer().fit(y_tra)\n",
+ "        y_onehot_tes = label_binarizer.transform(y_tes)\n",
+ "\n",
+ "        class_id = np.flatnonzero(label_binarizer.classes_ == label_target)[0]\n",
+ "\n",
+ "        display = RocCurveDisplay.from_predictions(\n",
+ "            y_onehot_tes[:, class_id],\n",
+ "            y_pred_tes[:, class_id],\n",
+ "            name=f\"{label_target} vs the rest\",\n",
+ "            color=color,\n",
+ "            ax=ax,\n",
+ "            plot_chance_level=(class_id == 0)\n",
+ "        )\n",
+ "\n",
+ "        _ = display.ax_.set(\n",
+ "            xlabel=\"False Positive Rate\",\n",
+ "            ylabel=\"True Positive Rate\",\n",
+ "            title=\"ROC: One-vs-Rest\",\n",
+ "        )\n",
+ "\n",
+ "    if save_file:\n",
+ "        file_final = createFileName(file_prefix=file_prefix,\n",
+ "                                    file_location=file_location,\n",
+ "                                    file_suffix=file_suffix)\n",
+ "        plt.savefig(file_final, bbox_inches='tight')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotROCMulticlassOnevsone(y_tra, y_tes, y_pred_tes, label_target_list,\n",
+ "                              color_list, save_file=False,\n",
+ "                              file_prefix=\"prediction_histogram\",\n",
+ "                              file_location=\"./\",\n",
+ "                              file_suffix=\".png\"):\n",
+ "    \"\"\"Plot ROC curves (as written, each target class vs the rest)\n",
+ "\n",
+ "    Parameters\n",
+ "    ----------\n",
+ "    y_tra: `numpy.ndarray`\n",
+ "        training data true labels\n",
+ "    y_tes: `numpy.ndarray`\n",
+ "        test data true labels\n",
+ "    y_pred_tes: `numpy.ndarray`\n",
+ "        test data predicted class probabilities\n",
+ "    label_target_list: `list`\n",
+ "    color_list: `list`\n",
+ "    file_prefix: `string`, optional\n",
+ "        prefix of file name\n",
+ "    file_location: `string`, optional\n",
+ "        path to file\n",
+ "    file_suffix: `string`, optional\n",
+ "        suffix/extension of file name\n",
+ "\n",
+ "    Returns\n",
+ "    -------\n",
+ "    None\n",
+ "    \"\"\"\n",
+ "    fig, ax = plt.subplots(figsize=(6, 6))\n",
+ "\n",
+ "    for label_target, color in zip(label_target_list, color_list):\n",
+ "\n",
+ "        label_binarizer = LabelBinarizer().fit(y_tra)\n",
+ "        y_onehot_tes = label_binarizer.transform(y_tes)\n",
+ "\n",
+ "        class_id = np.flatnonzero(label_binarizer.classes_ == label_target)[0]\n",
+ "\n",
+ "        display = RocCurveDisplay.from_predictions(\n",
+ "            y_onehot_tes[:, class_id],\n",
+ "            y_pred_tes[:, class_id],\n",
+ "            name=f\"{label_target} vs the rest\",\n",
+ "            color=color,\n",
+ "            ax=ax,\n",
+ "            plot_chance_level=(class_id == 0)\n",
+ "        )\n",
+ "\n",
+ "        _ = display.ax_.set(\n",
+ "            xlabel=\"False Positive Rate\",\n",
+ "            ylabel=\"True Positive Rate\",\n",
+ "            title=\"ROC: One-vs-Rest\",\n",
+ "        )\n",
+ "\n",
+ "    if save_file:\n",
+ "        file_final = createFileName(file_prefix=file_prefix,\n",
+ "                                    file_location=file_location,\n",
+ "                                    file_suffix=file_suffix,\n",
+ "                                    useuid=True)\n",
+ "\n",
+ "        plt.savefig(file_final, bbox_inches='tight')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotArrayHistogramExamples(x_tra, y_tra, num=10,\n",
+ "                               save_file=False,\n",
+ "                               file_prefix=\"prediction_histogram\",\n",
+ "                               file_location=\"./\",\n",
+ "                               file_suffix=\".png\"):\n",
+ "    \"\"\"Plot histograms of image pixel values\n",
+ "\n",
+ "    Parameters\n",
+ "    ----------\n",
+ "    x_tra: `numpy.ndarray`\n",
+ "        training image data\n",
+ "    y_tra: `numpy.ndarray`\n",
+ "        training label data\n",
+ "    num: `int`, optional\n",
+ "        number of examples to show\n",
+ "    file_prefix: `string`, optional\n",
+ "        prefix of file name\n",
+ "    file_location: `string`, optional\n",
+ "        path to file\n",
+ "    file_suffix: `string`, optional\n",
+ "        suffix/extension of file name\n",
+ "\n",
+ "    Returns\n",
+ "    -------\n",
+ "    None\n",
+ "    \"\"\"\n",
+ "    n_bins = 10\n",
+ "    num_row = 2\n",
+ "    num_col = 5\n",
+ "    num = min(num, num_row*num_col)\n",
+ "    images = x_tra[:num]\n",
+ "    labels = y_tra[:num]\n",
+ "\n",
+ "    fig, axes = plt.subplots(num_row, num_col,\n",
+ "                             figsize=(1.5*num_col, 2*num_row))\n",
+ "\n",
+ "    for i in range(num):\n",
+ "        ax = axes[i//num_col, i%num_col]\n",
+ "        ax.hist(images[i].ravel(), bins=n_bins)\n",
+ "        ax.set_title('Label: {}'.format(labels[i]))\n",
+ "\n",
+ "    plt.tight_layout()\n",
+ "\n",
+ "    if save_file:\n",
+ "        file_final = createFileName(file_prefix=file_prefix,\n",
+ "                                    file_location=file_location,\n",
+ "                                    file_suffix=file_suffix,\n",
+ "                                    useuid=True)\n",
+ "\n",
+ "        plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ "    plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotPredictionHistogram(y_prediction_a, y_prediction_b=None,\n",
+ "                            y_prediction_c=None, n_classes=None,\n",
+ "                            n_objects_a=None, n_colors=None,\n",
+ "                            title_a=None, title_b=None,\n",
+ "                            title_c=None, label_a=None,\n",
+ "                            label_b=None, label_c=None,\n",
+ "                            alpha=0.5, figsize=(12, 5),\n",
+ "                            save_file=False,\n",
+ "                            file_prefix=\"prediction_histogram\",\n",
+ "                            file_location=\"./\",\n",
+ "                            file_suffix=\".png\"):\n",
+ "    \"\"\"Plot histogram of predicted labels\n",
+ "\n",
+ "    Parameters\n",
+ "    ----------\n",
+ "    y_prediction_a: `numpy.ndarray`\n",
+ "    y_prediction_b: `numpy.ndarray`, optional\n",
+ "    y_prediction_c: `numpy.ndarray`, optional\n",
+ "    n_classes: `int`, optional\n",
+ "    n_objects_a: `int`, optional\n",
+ "    n_colors: `int`, optional\n",
+ "    title_a: `string`, optional\n",
+ "    title_b: `string`, optional\n",
+ "    title_c: `string`, optional\n",
+ "    label_a: `string`, optional\n",
+ "    label_b: `string`, optional\n",
+ "    label_c: `string`, optional\n",
+ "    alpha: `float`, optional\n",
+ "        transparency\n",
+ "    figsize: `tuple`, optional\n",
+ "        figure size\n",
+ "    file_prefix: `string`, optional\n",
+ "        prefix of file name\n",
+ "    file_location: `string`, optional\n",
+ "        path to file\n",
+ "    file_suffix: `string`, optional\n",
+ "        suffix/extension of file name\n",
+ "\n",
+ "    Returns\n",
+ "    -------\n",
+ "    None\n",
+ "    \"\"\"\n",
+ "    ndim = y_prediction_a.ndim\n",
+ "\n",
+ "    if ndim == 2:\n",
+ "        fig, (axa, axb, axc) = plt.subplots(1, 3, figsize=figsize)\n",
+ "        fig.subplots_adjust(wspace=0.35)\n",
+ "    elif ndim == 1:\n",
+ "        fig, ax = plt.subplots(figsize=figsize)\n",
+ "\n",
+ "    shape_a = np.shape(y_prediction_a)\n",
+ "\n",
+ "    if n_objects_a is None:\n",
+ "        n_objects_a = shape_a[0]\n",
+ "\n",
+ "    if ndim == 2:\n",
+ "        if n_classes is None:\n",
+ "            n_classes = shape_a[1]\n",
+ "        if n_colors is None:\n",
+ "            n_colors = n_classes\n",
+ "    elif ndim == 1:\n",
+ "        if n_colors is None:\n",
+ "            n_colors = 1\n",
+ "\n",
+ "    if ndim == 2:\n",
+ "        colors = cm.Purples(np.linspace(0, 1, n_colors))\n",
+ "        xlabel = \"Probability for Each Class\"\n",
+ "\n",
+ "        axa.set_ylim(0, n_objects_a)\n",
+ "        axa.set_xlabel(xlabel)\n",
+ "        axa.set_title(title_a)\n",
+ "\n",
+ "        for i in np.arange(n_classes):\n",
+ "            axa.hist(y_prediction_a[:, i], alpha=alpha,\n",
+ "                     color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ "        if y_prediction_b is not None:\n",
+ "            shape_b = np.shape(y_prediction_b)\n",
+ "            axb.set_ylim(0, shape_b[0])\n",
+ "            axb.set_xlabel(xlabel)\n",
+ "            axb.set_title(title_b)\n",
+ "\n",
+ "            for i in np.arange(n_classes):\n",
+ "                axb.hist(y_prediction_b[:, i], alpha=alpha,\n",
+ "                         color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ "        if y_prediction_c is not None:\n",
+ "            shape_c = np.shape(y_prediction_c)\n",
+ "            axc.set_ylim(0, shape_c[0])\n",
+ "            axc.set_xlabel(xlabel)\n",
+ "            axc.set_title(title_c)\n",
+ "\n",
+ "            for i in np.arange(n_classes):\n",
+ "                axc.hist(y_prediction_c[:, i], alpha=alpha,\n",
+ "                         color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ "    elif ndim == 1:\n",
+ "        ya, xa, _ = plt.hist(y_prediction_a, alpha=alpha, color='purple',\n",
+ "                             label=label_a)\n",
+ "        y_max_list = [max(ya)]\n",
+ "\n",
+ "        if y_prediction_b is not None:\n",
+ "            yb, xb, _ = plt.hist(y_prediction_b, alpha=alpha, color='blue',\n",
+ "                                 label=label_b)\n",
+ "            y_max_list.append(max(yb))\n",
+ "\n",
+ "        if y_prediction_c is not None:\n",
+ "            yc, xc, _ = plt.hist(y_prediction_c, alpha=alpha, color='green',\n",
+ "                                 label=label_c)\n",
+ "            y_max_list.append(max(yc))\n",
+ "\n",
+ "        plt.ylim(0, np.max(y_max_list)*1.1)\n",
+ "        plt.xlabel(\"Top Choice-Class\")\n",
+ "\n",
+ "    plt.legend(loc='upper right')\n",
+ "\n",
+ "    if save_file:\n",
+ "        file_final = createFileName(file_prefix=file_prefix,\n",
+ "                                    file_location=file_location,\n",
+ "                                    file_suffix=file_suffix,\n",
+ "                                    useuid=True)\n",
+ "        plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ "    plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotLossHistory(history, figsize=(8, 5),\n",
+ "                    save_file=False,\n",
+ "                    file_prefix=\"prediction_histogram\",\n",
+ "                    file_location=\"./\",\n",
+ "                    file_suffix=\".png\"):\n",
+ "    \"\"\"Plot loss history of the model as function of epoch\n",
+ "\n",
+ "    Parameters\n",
+ "    ----------\n",
+ "    history: `keras.src.callbacks.history.History`\n",
+ "        keras callback history object containing the losses at each epoch\n",
+ "    figsize: `tuple`, optional\n",
+ "        figure size\n",
+ "    file_prefix: `string`, optional\n",
+ "        prefix of file name\n",
+ "    file_location: `string`, optional\n",
+ "        path to file\n",
+ "    file_suffix: `string`, optional\n",
+ "        suffix/extension of file name\n",
+ "\n",
+ "    Returns\n",
+ "    -------\n",
+ "    None\n",
+ "    \"\"\"\n",
+ "    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=figsize)\n",
+ "\n",
+ "    loss_tra = np.array(history.history['loss'])\n",
+ "    loss_val = np.array(history.history['val_loss'])\n",
+ "    loss_dif = loss_val - loss_tra\n",
+ "\n",
+ "    ax1.plot(loss_tra, label='Training')\n",
+ "    ax1.plot(loss_val, label='Validation')\n",
+ "    ax1.legend()\n",
+ "\n",
+ "    ax2.plot(loss_dif, color='red', label='residual')\n",
+ "    ax2.axhline(y=0, color='grey', linestyle='dashed', label='zero bias')\n",
+ "    ax2.sharex(ax1)\n",
+ "    ax2.legend()\n",
+ "\n",
+ "    ax1.set_title('Loss History')\n",
+ "    ax1.set_ylabel('Loss')\n",
+ "    ax2.set_ylabel('Loss Residual')\n",
+ "    ax2.set_xlabel('Epoch')\n",
+ "    plt.tight_layout()\n",
+ "\n",
+ "    if save_file:\n",
+ "        file_final = createFileName(file_prefix=file_prefix,\n",
+ "                                    file_location=file_location,\n",
+ "                                    file_suffix=file_suffix,\n",
+ "                                    useuid=True)\n",
+ "\n",
+ "        plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ "    plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotConfusionMatrix(cm_tra, cm_val, cm_tes, save_file=False,\n",
+ "                        file_prefix=\"prediction_histogram\",\n",
+ "                        file_location=\"./\",\n",
+ "                        file_suffix=\".png\"):\n",
+ "    \"\"\"Plot the confusion matrix of predictions.\n",
+ "\n",
+ "    Parameters\n",
+ "    ----------\n",
+ "    cm_tra: `numpy.ndarray`\n",
+ "        confusion matrix for the training data\n",
+ "    cm_val: `numpy.ndarray`\n",
+ "        confusion matrix for the validation data\n",
+ "    cm_tes: `numpy.ndarray`\n",
+ "        confusion matrix for the test data\n",
+ "    file_prefix: `string`, optional\n",
+ "        prefix of file name\n",
+ "    file_location: `string`, optional\n",
+ "        path to file\n",
+ "    file_suffix: `string`, optional\n",
+ "        suffix/extension of file name\n",
+ "\n",
+ "    Returns\n",
+ "    -------\n",
+ "    None\n",
+ "    \"\"\"\n",
+ "\n",
+ "    cm_display_tra = ConfusionMatrixDisplay(confusion_matrix=cm_tra)\n",
+ "    cm_display_val = ConfusionMatrixDisplay(confusion_matrix=cm_val)\n",
+ "    cm_display_tes = ConfusionMatrixDisplay(confusion_matrix=cm_tes)\n",
+ "\n",
+ "    fig, (axa, axb, axc) = plt.subplots(1, 3, figsize=(22, 5))\n",
+ "\n",
+ "    cm_display_tra.plot(ax=axa)\n",
+ "    cm_display_val.plot(ax=axb)\n",
+ "    cm_display_tes.plot(ax=axc)\n",
+ "\n",
+ "    axa.set_title(\"Training\")\n",
+ "    axb.set_title(\"Validation\")\n",
+ "    axc.set_title(\"Testing\")\n",
+ "\n",
+ "    if save_file:\n",
+ "        file_final = createFileName(file_prefix=file_prefix,\n",
+ "                                    file_location=file_location,\n",
+ "                                    file_suffix=file_suffix,\n",
+ "                                    useuid=True)\n",
+ "\n",
+ "        plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ "    plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotArrayImageConfusion(x_tra, y_tra, y_pred_tra_topchoice,\n",
+ "                            title_main=None, num=10,\n",
+ "                            save_file=False,\n",
+ "                            file_prefix=\"prediction_histogram\",\n",
+ "                            file_location=\"./\",\n",
+ "                            file_suffix=\".png\"):\n",
+ "    \"\"\"Plot images of example objects that are misclassified.\n",
+ "\n",
+ "    Parameters\n",
+ "    ----------\n",
+ "    x_tra: `numpy.ndarray`\n",
+ "        training image data\n",
+ "    y_tra: `numpy.ndarray`\n",
+ "        training label data\n",
+ "    y_pred_tra_topchoice: `numpy.ndarray`\n",
+ "        top choice of the predicted labels\n",
+ "    title_main: `string`, optional\n",
+ "        title for the plot\n",
+ "    num: `int`, optional\n",
+ "        number of examples\n",
+ "    file_prefix: `string`, optional\n",
+ "        prefix of file name\n",
+ "    file_location: `string`, optional\n",
+ "        path to file\n",
+ "    file_suffix: `string`, optional\n",
+ "        suffix/extension of file name\n",
+ "\n",
+ "    Returns\n",
+ "    -------\n",
+ "    None\n",
+ "    \"\"\"\n",
+ "    num_row = 2\n",
+ "    num_col = 5\n",
+ "    images = x_tra[:num]\n",
+ "    labels_true = y_tra[:num]\n",
+ "    labels_pred = y_pred_tra_topchoice[:num]\n",
+ "\n",
+ "    fig, axes = plt.subplots(num_row, num_col,\n",
+ "                             figsize=(1.5*num_col, 2*num_row))\n",
+ "\n",
+ "    fig.patch.set_linewidth(5)\n",
+ "    fig.patch.set_edgecolor('cornflowerblue')\n",
+ "\n",
+ "    for i in range(num):\n",
+ "        ax = axes[i//num_col, i%num_col]\n",
+ "        ax.imshow(images[i], cmap='gray')\n",
+ "        ax.set_title(r'True: {}'.format(labels_true[i]) + '\\n'\n",
+ "                     + 'Pred: {}'.format(labels_pred[i]))\n",
+ "\n",
+ "    fig.suptitle(title_main)\n",
+ "    plt.tight_layout()\n",
+ "\n",
+ "    if save_file:\n",
+ "        file_final = createFileName(file_prefix=file_prefix,\n",
+ "                                    file_location=file_location,\n",
+ "                                    file_suffix=file_suffix,\n",
+ "                                    useuid=True)\n",
+ "\n",
+ "        plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ "    plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotArrayHistogramConfusion(x_tra, y_tra, y_pred_tra_topchoice,\n",
+ "                                title_main=None, num=10,\n",
+ "                                save_file=False,\n",
+ "                                file_prefix=\"prediction_histogram\",\n",
+ "                                file_location=\"./\",\n",
+ "                                file_suffix=\".png\"):\n",
+ "    \"\"\"Plot histograms of pixel values for images that are misclassified.\n",
+ "\n",
+ "    Parameters\n",
+ "    ----------\n",
+ "    x_tra: `numpy.ndarray`\n",
+ "        training image data\n",
+ "    y_tra: `numpy.ndarray`\n",
+ "        training label data\n",
+ "    y_pred_tra_topchoice: `numpy.ndarray`\n",
+ "        top choice of the predicted labels\n",
+ "    title_main: `string`, optional\n",
+ "        title of plot\n",
+ "    num: `int`, optional\n",
+ "        number of examples\n",
+ "    file_prefix: `string`, optional\n",
+ "        prefix of file name\n",
+ "    file_location: `string`, optional\n",
+ "        path to file\n",
+ "    file_suffix: `string`, optional\n",
+ "        suffix/extension of file name\n",
+ "\n",
+ "    Returns\n",
+ "    -------\n",
+ "    None\n",
+ "    \"\"\"\n",
+ "    n_bins = 10\n",
+ "    num_row = 2\n",
+ "    num_col = 5\n",
+ "    images = x_tra[:num]\n",
+ "    labels_true = y_tra[:num]\n",
+ "    labels_pred = y_pred_tra_topchoice[:num]\n",
+ "\n",
+ "    fig, axes = plt.subplots(num_row, num_col,\n",
+ "                             figsize=(1.5*num_col, 2*num_row))\n",
+ "\n",
+ "    fig.patch.set_linewidth(5)\n",
+ "    fig.patch.set_edgecolor('cornflowerblue')\n",
+ "\n",
+ "    for i in range(num):\n",
+ "        ax = axes[i//num_col, i%num_col]\n",
+ "        ax.hist(images[i].ravel(), bins=n_bins)\n",
+ "        ax.set_title(r'True: {}'.format(labels_true[i]) + '\\n'\n",
+ "                     + 'Pred: {}'.format(labels_pred[i]))\n",
+ "\n",
+ "    fig.suptitle(title_main)\n",
+ "    plt.tight_layout()\n",
+ "\n",
+ "    if save_file:\n",
+ "        file_final = createFileName(file_prefix=file_prefix,\n",
+ "                                    file_location=file_location,\n",
+ "                                    file_suffix=file_suffix,\n",
+ "                                    useuid=True)\n",
+ "\n",
+ "        plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ "    plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.3. Define Paths for Data and Plots"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Neural network training (i.e., model fitting) typically requires many numerical experiments to achieve an ideal model. To facilitate the comparison of these experiments/models, it is helpful to organize data carefully. We set paths for the model weight parameters and diagnostic figures, and we set the variable `run_label` to identify each training run. We save these settings in a dictionary to make it easier to pass them to plotting functions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "run_label = \"Run000\"\n",
+ "\n",
+ "path_dict = {'run_label': run_label,\n",
+ "             'dir_data_model': \"Data/Models/\",\n",
+ "             'dir_data_figures': \"Data/Figures/\",\n",
+ "             'file_model_prefix': \"Model\",\n",
+ "             'file_figure_prefix': \"Figure\",\n",
+ "             'file_figure_suffix': \".png\",\n",
+ "             'file_model_suffix': \".keras\"\n",
+ "             }\n",
+ "\n",
+ "if not os.path.exists(path_dict['dir_data_model']):\n",
+ "    os.makedirs(path_dict['dir_data_model'])\n",
+ "\n",
+ "if not os.path.exists(path_dict['dir_data_figures']):\n",
+ "    os.makedirs(path_dict['dir_data_figures'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jce50kKEfHC1"
+ },
+ "source": [
+ "## 2. Load and Prepare data: MNIST Handwritten Digits"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-dHXIDGEfLmO"
+ },
+ "source": [
+ "### 2.1. Download Dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-dHXIDGEfLmO"
+ },
+ "source": [
+ "The [MNIST handwritten digits dataset](https://ieeexplore.ieee.org/document/6296535) comprises 10 classes --- one for each digit. This is a useful dataset for learning the basics of neural networks and other AI algorithms. MNIST is one of a few canonical AI benchmark data sets for image classification. `tensorflow` provides a simple function, `tf.keras.datasets.mnist.load_data()`, for downloading the MNIST data to your local server for free. It automatically splits the downloaded data into training and test sets.\n",
+ "\n",
+ "The **input** data are held in variables that begin with `x_`, while the **output** (i.e., label) data are held in variables that begin with `y_`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "zDRvT2QkfISn",
+ "outputId": "68276f81-5e32-443a-d9c3-5291ef61715c"
+ },
+ "outputs": [],
+ "source": [
+ "mnist = tf.keras.datasets.mnist\n",
+ "train_temp, test_temp = mnist.load_data()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "O4FPxkLKiJKe"
+ },
+ "source": [
+ "### 2.2. Split Data into Train/Validation/Test"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "O4FPxkLKiJKe"
+ },
+ "source": [
+ "It is essential to split the data into separate sets for a proper 'blind' analysis and optimization of an AI model.\n",
+ "\n",
+ "There are three primary data sets used in model optimization:\n",
+ "\n",
+ "* **Training** (`_tra`) data is used directly by the algorithm to update the parameters of the AI model -- e.g., the weights on the edges between computational neurons in neural networks.\n",
+ "* **Validation** (`_val`) data is used indirectly to update the hyperparameters of the AI model -- e.g., the batch size, the learning rate, or the layers in the architecture of a neural network. Each time the neural network has completed training with the training data, the user compares the diagnostics computed on the training and the validation data and uses them to adjust the hyperparameters.\n",
+ "* **Test(ing)** (`_tes`) data is only used when the model is trained and validated and will no longer be updated or further trained. \n",
+ "\n",
+ "The `tensorflow` loader automatically splits the downloaded data into training and test data sets. Therefore, we use the `sklearn` `train_test_split()` function to further split the training set into training and validation data sets.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "RgjGyErDfNrg"
+ },
+ "outputs": [],
+ "source": [
+ "fraction_validation = 0.25"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# set the test data sets from the temp data at read-in\n",
+ "x_tes, y_tes = test_temp[0], test_temp[1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# set the training and validation data sets from the temp data at read-in\n",
+ "# use the sklearn train_test_split function\n",
+ "x_tra, x_val, y_tra, y_val = train_test_split(train_temp[0], train_temp[1],\n",
+ "                                              test_size=fraction_validation,\n",
+ "                                              random_state=1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "e8Ab7Yx7v1CY"
+ },
+ "source": [
+ "### 2.3. Normalize Data\n",
+ "\n",
+ "First, we convert the input data to floats, so that computations on the inputs use real-valued rather than integer arithmetic.\n",
+ "\n",
+ "Second, we normalize the data using the minimum and maximum pixel values across all the data sets, so that the inputs all lie in the range [0, 1]. This improves the stability of the training."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ci1q1O8hv8U6"
+ },
+ "outputs": [],
+ "source": [
+ "# set to floats\n",
+ "x_tra = x_tra.astype('float32')\n",
+ "x_val = x_val.astype('float32')\n",
+ "x_tes = x_tes.astype('float32')\n",
+ "\n",
+ "# calculate min and max across all input images\n",
+ "input_minimum = np.min([np.min(x_tra), np.min(x_val), np.min(x_tes)])\n",
+ "input_maximum = np.max([np.max(x_tra), np.max(x_val), np.max(x_tes)])\n",
+ "\n",
+ "print(\"Before\")\n",
+ "print(\"min/max\", np.min(x_tra), np.max(x_tra))\n",
+ "\n",
+ "x_tra = normalizeInputs(x_tra, input_minimum, input_maximum)\n",
+ "x_val = normalizeInputs(x_val, input_minimum, input_maximum)\n",
+ "x_tes = normalizeInputs(x_tes, input_minimum, input_maximum)\n",
+ "\n",
+ "print(\"After\")\n",
+ "print(\"min/max\", np.min(x_tra), np.max(x_tra))\n",
+ "\n",
+ "# get shapes\n",
+ "image_shape = x_tra[0, :, :].shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CdwlTbFafOYc"
+ },
+ "source": [
+ "### 2.4. Examine Raw Data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CdwlTbFafOYc"
+ },
+ "source": [
+ "Review data shapes. \n",
+ "\n",
+ "The zeroth elements of the `x` and `y` shapes should match: this is the number of objects in each data set. The first and second elements of each `x` shape should be equal: these are the dimensions of the images. The image size determines, in part, the depth of the neural network that can be created."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Print the data shapes to make sure you understand how many objects there are and what the number of pixels is for each image."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 484
+ },
+ "id": "u3nrxn00fRUE",
+ "outputId": "5bac7964-621b-4984-be23-8e711bd00dc4"
+ },
+ "outputs": [],
+ "source": [
+ "print('check data shapes')\n",
+ "print('x_train:', x_tra.shape)\n",
+ "print('y_train:', y_tra.shape)\n",
+ "print('x_valid:', x_val.shape)\n",
+ "print('y_valid:', y_val.shape)\n",
+ "print('x_test:', x_tes.shape)\n",
+ "print('y_test:', y_tes.shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot examples to gain visual familiarity. Do these all look like hand-written digits?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 484
+ },
+ "id": "u3nrxn00fRUE",
+ "outputId": "5bac7964-621b-4984-be23-8e711bd00dc4"
+ },
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Example_Image_Array\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "plotArrayImageExamples(x_tra, y_tra,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot pixel distributions to further understand the data. Are the data normalized? Do the distributions of the pixel values make sense according to what you see in the related images above? \n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Example_Histogram_Array\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotArrayHistogramExamples(x_tra, y_tra,\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rmAXXIHOfUnD"
+ },
+ "source": [
+ "## 3. Train Model: Dense Neural Network"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2024-07-14T14:21:02.040978Z",
+ "iopub.status.busy": "2024-07-14T14:21:02.040269Z",
+ "iopub.status.idle": "2024-07-14T14:21:02.043255Z",
+ "shell.execute_reply": "2024-07-14T14:21:02.042798Z",
+ "shell.execute_reply.started": "2024-07-14T14:21:02.040956Z"
+ },
+ "id": "bKRmx2k2wNtE"
+ },
+ "source": [
+ "### 3.1. Define Model Training Parameters"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2024-07-14T14:21:02.040978Z",
+ "iopub.status.busy": "2024-07-14T14:21:02.040269Z",
+ "iopub.status.idle": "2024-07-14T14:21:02.043255Z",
+ "shell.execute_reply": "2024-07-14T14:21:02.042798Z",
+ "shell.execute_reply.started": "2024-07-14T14:21:02.040956Z"
+ },
+ "id": "bKRmx2k2wNtE"
+ },
+ "source": [
+ "* Define optimizer\n",
+ "* Define loss\n",
+ "* Define accuracy\n",
+ "* Define batch_size\n",
+ "* Define epochs\n",
+ "* Define metrics"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "tlgJk3oLwRnh"
+ },
+ "outputs": [],
+ "source": [
+ "epochs = 10\n",
+ "batch_size = 32\n",
+ "verbose = True\n",
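+ "learning_rate = 0.01\n",
+ "momentum = 0.9\n",
+ "# build the SGD optimizer explicitly so that the learning rate\n",
+ "# and momentum defined above are actually applied\n",
+ "optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,\n",
+ "                                    momentum=momentum)\n",
+ "loss = tf.keras.losses.SparseCategoricalCrossentropy()\n",
+ "metrics = ['accuracy']\n",
+ "dropout_rate = 0.3\n",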
+ "seed = 1000"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Set the random seed so that the neural network weight initialization is reproducible."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tf.keras.utils.set_random_seed(seed)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rmAXXIHOfUnD"
+ },
+ "source": [
+ "### 3.2. Define Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rmAXXIHOfUnD"
+ },
+ "source": [
+ "* Define Sequential Model\n",
+ "* Define layers\n",
+ "* Define flat layer\n",
+ "* Define dense layers\n",
+ "* Define activation function; define types of activation functions -- sigmoid and relu\n",
+ "* Define weights and biases"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "U6vymJn1wJBu"
+ },
+ "outputs": [],
+ "source": [
+ "model_layers = [tf.keras.layers.Input(shape=image_shape),\n",
+ " tf.keras.layers.Flatten(),\n",
+ " tf.keras.layers.Dense(256, activation='sigmoid'),\n",
+ " tf.keras.layers.Dense(64, activation='sigmoid'),\n",
+ " tf.keras.layers.Dropout(dropout_rate),\n",
+ " tf.keras.layers.Dense(10, activation='softmax')]\n",
+ "\n",
+ "model = tf.keras.models.Sequential(model_layers)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "U6vymJn1wJBu"
+ },
+ "source": [
+ "View a summary of the network architecture. Examine the shapes of the layers and the numbers of parameters. Too few parameters may prevent the model from being flexible enough to model the data. Too many parameters could lead to overfitting of the model and a high computational cost."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "U6vymJn1wJBu"
+ },
+ "outputs": [],
+ "source": [
+ "model.summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2024-07-14T14:21:02.040978Z",
+ "iopub.status.busy": "2024-07-14T14:21:02.040269Z",
+ "iopub.status.idle": "2024-07-14T14:21:02.043255Z",
+ "shell.execute_reply": "2024-07-14T14:21:02.042798Z",
+ "shell.execute_reply.started": "2024-07-14T14:21:02.040956Z"
+ },
+ "id": "bKRmx2k2wNtE"
+ },
+ "source": [
+ "### 3.3. Compile and Train Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compile the model with the model settings created earlier."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "YwLTHCccwTeo",
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "model.compile(optimizer=optimizer, loss=loss, metrics=metrics)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Train (fit) the model. The output `history` contains the loss and accuracy values of the training data and the validation data for each epoch."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "history = model.fit(x_tra, y_tra,\n",
+ " batch_size=batch_size,\n",
+ " epochs=epochs,\n",
+ " validation_data=(x_val, y_val),\n",
+ " verbose=verbose)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Save the model as a `.keras` zip archive so that it can be used later -- e.g., for comparison to other models."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_model_prefix'] + \"_\" + path_dict['run_label']\n",
+ "file_name_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_model'],\n",
+ " file_suffix=path_dict['file_model_suffix'],\n",
+ " useuid=True,\n",
+ " verbose=True)\n",
+ "\n",
+ "model.save(file_name_final)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "suHSArn6wb27"
+ },
+ "source": [
+ "## 4. Diagnosing the Results of the Classification Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.1. Key Terms for Diagnostic Metrics\n",
+ "\n",
+ "We use the following diagnostics to assess the status of the network optimization and efficacy. For more detail, see the `scikit-learn` model evaluation guide: \n",
+ "https://scikit-learn.org/stable/modules/model_evaluation.html\n",
+ "\n",
+ "\n",
+ "* **Metrics**\n",
+ "    * Loss: The value of the loss function (here, sparse categorical cross-entropy) that is minimized during training.\n",
+ "    * Accuracy: Use as a rough indicator of model training progress/convergence for balanced datasets. For model performance, use only in combination with other metrics. Avoid for imbalanced datasets; consider using another metric instead.\n",
+ "    * True positive rate (tpr; recall): Use when false negatives are more expensive than false positives.\n",
+ "    * False positive rate (fpr): Use when false positives are more expensive than false negatives.\n",
+ "    * Precision: Use when it's very important for positive predictions to be accurate.\n",
+ "* **Generalization Error**: The Generalization Error (GE) is the difference in loss when the model is applied to training data versus when it is applied to validation data and test data.\n",
+ "* **Confusion Matrix**: A table that counts the objects in each combination of true class and predicted class.\n",
+ "* **Receiver Operating Characteristic (ROC) Curve**: The true positive rate plotted against the false positive rate as the classification threshold is varied.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.2. Classification Predictions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Predict classification probabilities on the training, validation and test sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "y_pred_tra = model.predict(x_tra, verbose=True)\n",
+ "y_pred_val = model.predict(x_val, verbose=True)\n",
+ "y_pred_tes = model.predict(x_tes, verbose=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Identify what the top-choice class is for each object in the training, validation and test sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "y_pred_tra_topchoice = y_pred_tra.argmax(axis=1)\n",
+ "y_pred_val_topchoice = y_pred_val.argmax(axis=1)\n",
+ "y_pred_tes_topchoice = y_pred_tes.argmax(axis=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"10 probabilities for each object:\", np.shape(y_pred_tra))\n",
+ "print(\"Top choice for each object:\", np.shape(y_pred_tra_topchoice))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot histograms of the prediction distributions by class."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Histograms_top_choice\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "plotPredictionHistogram(y_pred_tra_topchoice,\n",
+ " y_prediction_b=y_pred_val_topchoice,\n",
+ " y_prediction_c=y_pred_tes_topchoice,\n",
+ " label_a=\"Training Set\",\n",
+ " label_b=\"Validation Set\",\n",
+ " label_c=\"Testing Set\",\n",
+ " figsize=(12, 5),\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Histograms_class_probabilities\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "plotPredictionHistogram(y_pred_tra,\n",
+ " y_prediction_b=y_pred_val,\n",
+ " y_prediction_c=y_pred_tes,\n",
+ " title_a='Training Set',\n",
+ " title_b='Validation Set',\n",
+ " title_c='Testing Set',\n",
+ " figsize=(15, 4),\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
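+ "Observations about these histograms:\n",
+ "\n",
+ "1. The shapes are very similar across the training, validation, and test sets. This is a good sign that the model behaves consistently on data it was not trained on."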
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "jp-MarkdownHeadingCollapsed": true
+ },
+ "source": [
+ "### 4.3. Generalization Error\n",
+ "\n",
+ "The primary task in optimizing a network is to minimize the Generalization Error. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.3.1. Loss History: History of Loss and Accuracy during Training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot the loss history for the validation and training sets. We reserve the test set for a 'blind' analysis."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "EIiJTdK-weWf"
+ },
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"LossHistory\"\\\n",
+ " + \"_\"\\\n",
+ " + path_dict['run_label']\n",
+ "plotLossHistory(history,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.3.2. Confusion Matrix: Bias in Trained Model?"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
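+ "Compute the confusion matrices for the training, validation, and test sets. Note that `confusion_matrix` takes the true labels as its first argument by convention; because the predictions are passed first here, the rows of each matrix correspond to predicted classes and the columns to true classes."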
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cm_tra = confusion_matrix(y_pred_tra_topchoice, y_tra)\n",
+ "cm_val = confusion_matrix(y_pred_val_topchoice, y_val)\n",
+ "cm_tes = confusion_matrix(y_pred_tes_topchoice, y_tes)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot the confusion matrices for the training, validation, and test samples (left, middle, right)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plotConfusionMatrix(cm_tra, cm_val, cm_tes)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.3.3. Investigating Errant Classifications: Look at the examples\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Choose a digit/class (human option/choice) for examination."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class_value = 2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Find all objects that have that class value. \n",
+ "Obtain indices for the true positives (tp's), false positives (fp's), true negatives (tn's), and false negatives (fn's)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ind_class_tp_tra = np.where((y_tra == class_value)\n",
+ " & (y_pred_tra_topchoice == class_value))[0]\n",
+ "\n",
+ "ind_class_fp_tra = np.where((y_tra != class_value)\n",
+ " & (y_pred_tra_topchoice == class_value))[0]\n",
+ "\n",
+ "ind_class_tn_tra = np.where((y_tra != class_value)\n",
+ " & (y_pred_tra_topchoice != class_value))[0]\n",
+ "\n",
+ "ind_class_fn_tra = np.where((y_tra == class_value)\n",
+ " & (y_pred_tra_topchoice != class_value))[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot example images of true positives, false positives, true negatives, and false negatives."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_TruePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ "plotArrayImageConfusion(x_tra[ind_class_tp_tra],\n",
+ " y_tra[ind_class_tp_tra],\n",
+ " y_pred_tra_topchoice[ind_class_tp_tra],\n",
+ " title_main=\"True Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_FalsePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ "plotArrayImageConfusion(x_tra[ind_class_fp_tra],\n",
+ " y_tra[ind_class_fp_tra],\n",
+ " y_pred_tra_topchoice[ind_class_fp_tra],\n",
+ " title_main=\"False Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_TrueNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ "plotArrayImageConfusion(x_tra[ind_class_tn_tra],\n",
+ " y_tra[ind_class_tn_tra],\n",
+ " y_pred_tra_topchoice[ind_class_tn_tra],\n",
+ " title_main=\"True Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_FalseNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ "plotArrayImageConfusion(x_tra[ind_class_fn_tra],\n",
+ " y_tra[ind_class_fn_tra],\n",
+ " y_pred_tra_topchoice[ind_class_fn_tra],\n",
+ " title_main=\"False Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot histograms of the image pixels of true positives, false positives, true negatives, and false negatives."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleHistograms_TruePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ "plotArrayHistogramConfusion(x_tra[ind_class_tp_tra],\n",
+ " y_tra[ind_class_tp_tra],\n",
+ " y_pred_tra_topchoice[ind_class_tp_tra],\n",
+ " title_main=\"True Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleHistograms_FalsePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ "plotArrayHistogramConfusion(x_tra[ind_class_fp_tra],\n",
+ " y_tra[ind_class_fp_tra],\n",
+ " y_pred_tra_topchoice[ind_class_fp_tra],\n",
+ " title_main=\"False Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleHistograms_TrueNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ "plotArrayHistogramConfusion(x_tra[ind_class_tn_tra],\n",
+ " y_tra[ind_class_tn_tra],\n",
+ " y_pred_tra_topchoice[ind_class_tn_tra],\n",
+ " title_main=\"True Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleHistograms_FalseNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ "plotArrayHistogramConfusion(x_tra[ind_class_fn_tra],\n",
+ " y_tra[ind_class_fn_tra],\n",
+ " y_pred_tra_topchoice[ind_class_fn_tra],\n",
+ " title_main=\"False Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Exercises for the Learner"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Each time you train a new model, re-run all the diagnostic plots.\n",
+ "\n",
+ "1. How do the loss and accuracy histories change when batch size is small or large? Why?\n",
+ "2. Does the NN take more or less time (more or fewer epochs) to converge if the input image data are normalized or not normalized? Why?\n",
+ "3. How does the size of the training set affect the model's accuracy and loss -- keeping the number of epochs the same? Why?\n",
+ "4. How does the random seed for the weight initialization affect the model's accuracy and loss -- keeping the number of epochs the same?\n",
+ "5. Use the `time` module to estimate the time for the model fitting. Record that time. Increase and then decrease the number of weights in the NN by an order of magnitude. Train the NN for each of those models and record the times. How does the number of weights in the neural network affect the training time and model loss and accuracy?\n",
+ "6. Use the `time` module to estimate the time for the model fitting. Record that time. Increase and then decrease the number of layers in the NN. Train the NN for each of those models and record the times. How does the number of layers in the neural network affect the training time and model loss and accuracy?\n",
+ "7. Use the `time` module to estimate the time for the model fitting. Record that time. Add a convolutional layer to the NN. Train the new model and record the time. How does the convolutional layer affect the training time and model loss and accuracy?"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 6. Glossary of neural network terms\n",
+ "\n",
+ "1. network weights\n",
+ "2. deep learning\n",
+ "3. machine learning\n",
+ "4. learning\n",
+ "5. activation function\n",
+ "6. pool(ing)\n",
+ "7. convolution\n",
+ "8. layer\n",
+ "9. loss function\n",
+ "10. confusion matrix\n",
+ "11. epoch\n",
+ "12. batch size\n",
+ "13. learning rate\n",
+ "14. momentum\n",
+ "15. stochastic gradient descent\n",
+ "16. optimizer\n",
+ "17. receiver operating characteristic (ROC)\n",
+ "18. area under the curve (AUC)\n",
+ "19. training\n",
+ "20. validation\n",
+ "21. testing\n",
+ "22. class\n",
+ "23. hyperparameter (vs. parameter)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 7. AI is math, not magic.\n",
+ "\n",
+ "AI is firmly based in math, computer science, and statistics. Additionally, some of the approaches are inspired by concepts or notions in biology (e.g., the computational neuron) and in physics (e.g., the restricted Boltzmann machine). \n",
+ "\n",
+ "Much of the jargon in AI is anthropomorphic, which can make it appear that something other than math is happening. For example, consider the following list of terms that are very often used in AI -- and what these terms actually mean mathematically.\n",
+ "\n",
+ "1. learn $\\rightarrow$ fit\n",
+ "2. hallucinate/lie $\\rightarrow$ predict incorrectly\n",
+ "3. understand $\\rightarrow$ model has converged\n",
+ "4. cheat $\\rightarrow$ more efficiently guesses the best weight parameters of the model\n",
+ "5. believe $\\rightarrow$ predict/infer based on statistical priors\n",
+ "\n",
+ "When we over-anthropomorphize this mathematical tool, we obfuscate how it actually works, and that makes it harder to build and refine models. That is, AI models are not 'learning' or 'understanding'; they are large-parameter models that are being fit to data. The only learning that's happening is what we do with these models."
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "LSST",
+ "language": "python",
+ "name": "lsst"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
From 6e3d13758ec4f4404ff154b3301cc7539b2f792b Mon Sep 17 00:00:00 2001
From: MelissaGraham
Date: Mon, 25 Nov 2024 20:17:48 +0000
Subject: [PATCH 2/5] add 16a
---
README.md | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 8ec5f12..9307751 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,7 @@ Tutorial titles in **bold** have Spanish-language versions.
| 13a. Using The Image Cutout Tool With DP0.2 | Demonstration of the use of the image cutout tool with a few science applications. |
| 14. Injecting Synthetic Sources Into Single-Visit Images | Inject artificial stars and galaxies into images. |
| 15. Survey Property Maps | Use the tools to visualize full-area survey property maps. |
+| 16a. Introduction to Tensorflow | Learn to classify images with AI-based classification algorithms. |
## DP0.3 Tutorials
@@ -119,7 +120,7 @@ The *content* of these notebooks are licensed under the Apache 2.0 License. Tha
| **13a. Using The Image Cutout Tool With DP0.2** **Level:** Beginner **Description:** This notebook demonstrates how to use the Rubin Image Cutout Service. **Skills:** Run the Rubin Image Cutout Service for visual inspection of small cutouts of LSST images. **Data Products:** Images (deepCoadd, calexp), catalogs (objectTable, diaObject, truthTables, ivoa.ObsCore). **Packages:** PyVO, lsst.rsp.get_tap_service, lsst.pipe.tasks.registerImage, lsst.afw.display |
| **14. Injecting Synthetic Sources Into Single-Visit Images** **Level:** Advanced **Description:** This tutorial demonstrates a method to inject artificial sources (stars and galaxies) into calexp images using the measured point-spread function of the given calexp image. Confirmation that the synthetic sources were correctly injected into the image is done by running a difference imaging task from the pipelines. **Skills:** Use the `source_injection` tools to inject synthetic sources into images. Create a difference image from a `calexp` with injected sources. **Data Products:** Butler calexp images and corresponding src catalogs, goodSeeingDiff_templateExp images, and injection_catalogs. **Packages:** lsst.source.injection |
| **15. Survey Property Maps** **Level:** Intermediate **Description:** Use the tools to visualize full-area survey property maps. **Skills:** Load and visualize survey property maps using healsparse and skyproj. **Data Products:** Survey property maps. **Packages:** healsparse, skyproj, lsst.daf.butler |
-
+| **16a. Introduction to Tensorflow** **Level:** Beginner **Description:** An introduction to the classification of images with AI-based classification algorithms. **Skills:** Examine AI training data, prepare it for a classification task, perform classification with a neural network, and examine the diagnostics of the classification task. **Data Products:** MNIST data. **Packages:** sklearn, tensorflow |
| Skills in **DP0.3** Tutorial Notebooks |
|---|
From fab6265ebcbf5b8c9b20b759190b2dff59efea1d Mon Sep 17 00:00:00 2001
From: MelissaGraham
Date: Tue, 26 Nov 2024 00:28:51 +0000
Subject: [PATCH 3/5] MLG edits
---
...geClassificationWithTensorflow_Draft.ipynb | 293 ++++++++++++------
1 file changed, 192 insertions(+), 101 deletions(-)
diff --git a/AI0_Intro_AI_ImageClassificationWithTensorflow_Draft.ipynb b/AI0_Intro_AI_ImageClassificationWithTensorflow_Draft.ipynb
index 9939bcd..e7aafd0 100644
--- a/AI0_Intro_AI_ImageClassificationWithTensorflow_Draft.ipynb
+++ b/AI0_Intro_AI_ImageClassificationWithTensorflow_Draft.ipynb
@@ -7,10 +7,10 @@
},
"source": [
" \n",
- " AI0: Introduction to AI-based Image Classification with Tensorflow \n",
+ " Introduction to AI-based Image Classification with Tensorflow \n",
"Contact author: Brian Nord \n",
- "Last verified to run: YYYY-MM-DD \n",
- "LSST Science Pipelines version: ?? \n",
+ "Last verified to run: 2024-11-25 \n",
+ "LSST Science Pipelines version: Weekly 2024_42 \n",
"Container size: medium \n",
"Targeted learning level: beginner "
]
@@ -76,13 +76,25 @@
"\n",
"AI is a class of algorithms for building statistical models. These algorithms primarily use data for training, as opposed to models that use analytic formulae or models that are based on physical reasoning. Machine learning is a subclass of algorithms -- e.g., random forests. Deep learning is a subclass of algorithms -- e.g., neural networks. \n",
"\n",
- "This notebook uses `tensorflow`, one of the two most commonly used `python` libraries for deep learning. `Tensorflow` is often easier to use because of how it handles data sets and the logic used for model building. However, it is typically also difficult to develop network models creatively. We use `tensorflow` first in this series of tutorials so that users who are new to deep learning can focus on learning AI. In later tutorials, we will use `pytorch` because it is more flexible and more commonly used in science applications. \n",
+ "This notebook uses `tensorflow`, one of the two most commonly used `python` libraries for deep learning. `Tensorflow` is often easier to use because of how it handles data sets and the logic used for model building. However, it is typically more difficult to develop network models creatively in `tensorflow`. This tutorial is the first in a series, and uses `tensorflow` so that users who are new to deep learning can focus on learning AI. In later tutorials, `pytorch` will be used because it is more flexible and more commonly used in science applications. \n",
"\n",
- "This notebook uses [MNIST AI benchmarking data](https://en.wikipedia.org/wiki/MNIST_database). In a future notebook, we will we'll use stars and galaxies drawn from DP0 data.\n",
+ "Instead of using DP0 data, this tutorial uses [MNIST AI benchmarking data](https://en.wikipedia.org/wiki/MNIST_database), a large database of handwritten digits that is commonly used for training and testing machine learning algorithms. It is simple to understand, so that users who are new to deep learning can focus on learning AI. Later tutorials in this series will use stars and galaxies drawn from DP0 data.\n",
"\n",
- "The use of data in this notebook requires a medium-sized ram allocation (8Gi).\n",
+ "### 1.1. AI is math, not magic.\n",
"\n",
- "The end of this notebook contains a Glossary of Terms and a comment regarding usage of terms in AI contexts."
+ "AI is firmly based in math, computer science, and statistics. Additionally, some of the approaches are inspired by concepts or notions in biology (e.g., the computational neuron) and in physics (e.g., the restricted Boltzmann machine). \n",
+ "\n",
+ "Much of the jargon in AI is anthropomorphic, which can make it appear that something other than math is happening. For example, consider the following list of terms that are very often used in AI -- and what these terms actually mean mathematically.\n",
+ "\n",
+ "1. learn $\\rightarrow$ fit\n",
+ "2. hallucinate/lie $\\rightarrow$ predict incorrectly\n",
+ "3. understand $\\rightarrow$ model has converged\n",
+ "4. cheat $\\rightarrow$ more efficiently guesses the best weight parameters of the model\n",
+ "5. believe $\\rightarrow$ predict/infer based on statistical priors\n",
+ "\n",
+ "When we over-anthropomorphize this mathematical tool, we obfuscate how it actually works, and that makes it harder to build and refine models. That is, AI models are not 'learning' or 'understanding'; they are large-parameter models that are being fit to data. The only learning that's happening is what we do with these models.\n",
+ "\n",
+ "The end of this notebook contains a glossary of AI-related terms."
]
},
{
@@ -103,7 +115,7 @@
"id": "V3xHhKu6c5-e"
},
"source": [
- "### 1.1. Import Packages\n",
+ "### 1.2. Import packages\n",
"\n",
"[`numpy`](https://numpy.org/) is a widely used Python library for computations and mathematical operations on multi-dimensional arrays.\n",
"\n",
@@ -141,7 +153,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### 1.2 Define Functions"
+ "### 1.3. Define functions\n",
+ "\n",
+ "The following functions are defined and used throughout this notebook.\n",
+ "\n",
+ "It is not necessary to understand exactly what every function does in order to proceed with this tutorial.\n",
+ "\n",
+ "Execute all cells and move on to Section 2."
]
},
{
@@ -539,7 +557,7 @@
" n_objects_a = shape_a[0]\n",
"\n",
" if ndim == 2:\n",
- " if n_classes == None:\n",
+ " if n_classes is None:\n",
" n_classes = shape_a[1]\n",
" if n_colors is None:\n",
" n_colors = n_classes\n",
@@ -871,14 +889,19 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### 1.3 Define Paths for Data and Plots"
+ "### 1.4. Define paths for data and plots"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Neural network training (i.e., model fitting) typically requires many numerical experiments to achieve an ideal model. To facilitate the comparison of these experiments/models, it is helpful to organize data carefully. We set paths for the model weight parameters and diagnostic figures. We also set the variable `run_label` for each training run. We also save these paths in a dictionary to facilitate passing information to plotting functions."
+ "Neural network training (i.e., model fitting) typically requires many numerical experiments to achieve an ideal model. To facilitate the comparison of these experiments/models, it is helpful to organize data carefully. \n",
+ "\n",
+ "Set the variable `run_label` for each training run. \n",
+ "Set paths for the model weight parameters and diagnostic figures, and\n",
+ "save these paths in the dictionary `path_dict` to facilitate passing information to plotting functions.\n",
+ "Check whether the paths exist, and if not, create them."
]
},
{
@@ -889,14 +912,16 @@
"source": [
"run_label = \"Run000\"\n",
"\n",
+ "temppath = os.getenv(\"HOME\") + '/dp02_16a_temp/'\n",
"path_dict = {'run_label': run_label,\n",
- " 'dir_data_model': \"Data/Models/\",\n",
- " 'dir_data_figures': \"Data/Figures/\",\n",
+ " 'dir_data_model': temppath + \"Data/Models/\",\n",
+ " 'dir_data_figures': temppath + \"Data/Figures/\",\n",
" 'file_model_prefix': \"Model\",\n",
" 'file_figure_prefix': \"Figure\",\n",
" 'file_figure_suffix': \".png\",\n",
" 'file_model_suffix': \".keras\"\n",
" }\n",
+ "del temppath\n",
"\n",
"if not os.path.exists(path_dict['dir_data_model']):\n",
" os.makedirs(path_dict['dir_data_model'])\n",
@@ -911,7 +936,7 @@
"id": "jce50kKEfHC1"
},
"source": [
- "## 2. Load and Prepare data: MNIST Handwritten Digits"
+ "## 2. Load and prepare data: MNIST Handwritten Digits"
]
},
{
@@ -920,7 +945,7 @@
"id": "-dHXIDGEfLmO"
},
"source": [
- "### 2.1. Download Dataset"
+ "### 2.1. Download the dataset"
]
},
{
@@ -929,9 +954,11 @@
"id": "-dHXIDGEfLmO"
},
"source": [
- "The [`MNIST handwritten digits dataset`](https://ieeexplore.ieee.org/document/6296535) comprises 10 classes --- one for each digit. This is a useful dataset for learning the basics of neural networks and other AI algorithms. MNIST is one of a few canonical AI benchmark data sets for image classification. `tensorflow` has a simple function easily downloading the MNIST data to your local server for free. It automatically downloads the data into.\n",
+ "The [MNIST handwritten digits dataset](https://ieeexplore.ieee.org/document/6296535) comprises 10 classes --- one for each digit. This is a useful dataset for learning the basics of neural networks and other AI algorithms. MNIST is one of a few canonical AI benchmark data sets for image classification. `tensorflow` provides a simple function for downloading the MNIST data to your local server for free.\n",
"\n",
- "The **input** data are held in `x_`, while the **output** (aka, label) data are held in `y_`."
+ "Automatically download the data using `tf.keras.datasets.mnist`.\n",
+ "The `tf` class automatically downloads data into training and test data sets;\n",
+ "load them into variables `train_temp` and `test_temp`, respectively."
]
},
{
@@ -956,7 +983,7 @@
"id": "O4FPxkLKiJKe"
},
"source": [
- "### 2.2. Split Data into Train/Validation/Test"
+ "### 2.2. Split data into training, validation, and testing"
]
},
{
@@ -973,7 +1000,7 @@
"* **Validation** (`_val`) data is used indirectly to update the hyperparameters of the AI model -- e.g., the batchsize, the learning rate, or the layers in the architecture of a neural network. Each time the neural network has completed training with the training data, the human looks at those diagnostics when run on the training and the validation data.\n",
"* **Test(ing)** (`_tes`) data is only used when the model is trained and validated and will no longer be update or further trained. \n",
"\n",
- "The `TF` class automatically downloads data into training and test data sets. Therefore, we use the `sklearn` `train_test_split()` function to further split the training set into training and validation data sets. We then \n"
+ "Set the fraction of the training data to hold out for validation to 25%."
]
},
{
@@ -987,24 +1014,37 @@
"fraction_validation = 0.25"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The **input** data are held in `x_`, while the **output** (aka, label) data are held in `y_`.\n",
+ "\n",
+ "Set the test data sets from `test_temp`, which was loaded above."
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "# set the test data sets from the temp data at read-in\n",
"x_tes, y_tes = test_temp[0], test_temp[1]"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use the `sklearn` `train_test_split()` function to further split the training set into training and validation data sets."
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "# set the training and validata data sets from the temp data at read-in\n",
- "# use the sklearn train_test_split function\n",
"x_tra, x_val, y_tra, y_val = train_test_split(train_temp[0], train_temp[1],\n",
" test_size=fraction_validation,\n",
" random_state=1)"
@@ -1018,9 +1058,10 @@
"source": [
"### 2.3. Normalize data\n",
"\n",
- "First, we make sure that the input data are floats. This allows us to perform computations on the real number line for the inputs.\n",
+ "Make sure that the input data are of type float, to enable computations on the real number line for the inputs.\n",
"\n",
- "Second, we normalize the data according to the maximum value in all the data sets. The inputs will all exist on a smaller range. This improves the stability of the training."
+ "Calculate the minimum and maximum values across all data sub-sets.\n",
+ "Use the `normalizeInputs` function to normalize the data according to the minimum and maximum values across all the data sets, so that the inputs all lie on a smaller range. This improves the stability of the training."
]
},
{
@@ -1031,27 +1072,20 @@
},
"outputs": [],
"source": [
- "# set to floats\n",
"x_tra = x_tra.astype('float32')\n",
"x_val = x_val.astype('float32')\n",
"x_tes = x_tes.astype('float32')\n",
"\n",
- "# calculate min and max across all input images\n",
"input_minimum = np.min([np.min(x_tra), np.min(x_val), np.min(x_tes)])\n",
"input_maximum = np.max([np.max(x_tra), np.max(x_val), np.max(x_tes)])\n",
"\n",
- "print(\"Before\")\n",
- "print(\"min/max\", np.min(x_tra), np.max(x_tra))\n",
+ "print(\"Before normalization, min and max: \", np.min(x_tra), np.max(x_tra))\n",
"\n",
"x_tra = normalizeInputs(x_tra, input_minimum, input_maximum)\n",
"x_val = normalizeInputs(x_val, input_minimum, input_maximum)\n",
"x_tes = normalizeInputs(x_tes, input_minimum, input_maximum)\n",
"\n",
- "print(\"After\")\n",
- "print(\"min/max\", np.min(x_tra), np.max(x_tra))\n",
- "\n",
- "# get shapes\n",
- "image_shape = x_tra[0, :, :].shape"
+ "print(\"After normalization, min and max: \", np.min(x_tra), np.max(x_tra))"
]
},
{
@@ -1060,7 +1094,7 @@
"id": "CdwlTbFafOYc"
},
"source": [
- "### 2.4. Examine Raw Data"
+ "### 2.4. Examine raw data"
]
},
{
@@ -1071,14 +1105,27 @@
"source": [
"Review data shapes. \n",
"\n",
- "The zeroth elements of the `x` and `y` shapes should match. The first and second elements of `x` should be equal: these are the dimensions of the images. The image size, in part determines the depth of the neural network that can be created."
+ "The variable `image_shape` will be used again in Section 3.2."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"The shape of each input image is:\")\n",
+ "image_shape = x_tra[0, :, :].shape\n",
+ "print(image_shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Print the data shapes to make sure you understand how many objects there are and what the number of pixels is for each image."
+ "The zeroth elements of the `x` and `y` shapes should match. The first and second elements of the `x` shape should be equal: these are the dimensions of the images. The image size in part determines the depth of the neural network that can be created.\n",
+ "\n",
+ "Print the data shapes to understand how many objects there are and what the number of pixels is for each image."
]
},
{
@@ -1094,7 +1141,6 @@
},
"outputs": [],
"source": [
- "print('check data shapes')\n",
"print('x_train:', x_tra.shape)\n",
"print('y_train:', y_tra.shape)\n",
"print('x_valid:', x_val.shape)\n",
@@ -1107,7 +1153,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Plot examples to gain visual familiarity. Do these all look like hand-written digits?"
+ "Plot examples to gain visual familiarity. All images should look like hand-written digits."
]
},
{
@@ -1136,7 +1182,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Plot pixel distributions to further understand data. Is it normalized? Do the disributions of the pixel values make sense according to what you see in the related images above? \n"
+ "> Figure 1: Two rows of five images, each a handwritten number in white on a black background.\n",
+ "\n",
+ "Plot the distributions of pixel values to further understand the data.\n",
+ "\n",
+ "Note that all pixel data has been normalized, and pixels have values between 0 and 1 only.\n",
+ "\n",
+ "Note that the distribution of pixel values matches the images shown above: mostly black (values near 0) pixels, with some white (values near 1), and a few grey (values in between 0 and 1)."
]
},
{
@@ -1162,7 +1214,9 @@
"id": "rmAXXIHOfUnD"
},
"source": [
- "## 3. Train Model: Dense Neural Network"
+ "> Figure 2: Two rows of five plots, each showing the distribution of pixel values (number of pixels of a given value) for the handwritten digit images shown in Figure 1.\n",
+ "\n",
+ "## 3. Train the model: dense neural network"
]
},
{
@@ -1178,7 +1232,7 @@
"id": "bKRmx2k2wNtE"
},
"source": [
- "### 3.1. Define Model Training Parameters"
+ "### 3.1. Define model training parameters"
]
},
{
@@ -1194,12 +1248,17 @@
"id": "bKRmx2k2wNtE"
},
"source": [
- "Define optimizer\n",
- "Define loss\n",
- "Define accuracy\n",
- "Define batch_size\n",
- "Define epochs\n",
- "Define metrics"
+ " * `epochs` : \n",
+ " * `batch_size` : \n",
+ " * `verbose` : When `True`, the code will write more output to screen. This can help users diagnose issues.\n",
+ " * `optimizer` : \n",
+ " * `loss` : \n",
+ " * `metrics` : \n",
+ " * `dropout_rate` : \n",
+ " * `learning_rate` : \n",
+ " * `momentum` :\n",
+ "\n",
+ "Define the variables to hold the model training parameters."
]
},
{
@@ -1218,15 +1277,14 @@
"metrics = ['accuracy']\n",
"dropout_rate = 0.3\n",
"learning_rate = 0.01\n",
- "momentum = 0.9\n",
- "seed = 1000"
+ "momentum = 0.9"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Set the random seed for neural network weight initialization"
+ "Set the random seed for neural network weight initialization."
]
},
{
@@ -1235,6 +1293,7 @@
"metadata": {},
"outputs": [],
"source": [
+ "seed = 1000\n",
"tf.keras.utils.set_random_seed(seed)"
]
},
@@ -1244,7 +1303,7 @@
"id": "rmAXXIHOfUnD"
},
"source": [
- "### 3.2. Define Model"
+ "### 3.2. Define the model"
]
},
{
@@ -1253,12 +1312,16 @@
"id": "rmAXXIHOfUnD"
},
"source": [
- "Define Sequential Model\n",
- "Define layers\n",
- "Define flat layer\n",
- "Define dense layers\n",
- "Define activation function; define types activation functions -- sigmoid and relu\n",
- "Define weights and biases"
+ " * layers : \n",
+ " * flat layer : \n",
+ " * dense layers : \n",
+ " * activation function : \n",
+ " * sigmoid : \n",
+ " * softmax : \n",
+ " * weights and biases : \n",
+ " * sequential model :\n",
+ "\n",
+ "Define `model_layers` and use it to set `model` as a sequential model."
]
},
{
@@ -1312,7 +1375,7 @@
"id": "bKRmx2k2wNtE"
},
"source": [
- "### 3.3. Compile and Train Model"
+ "### 3.3. Compile and train the model"
]
},
{
@@ -1383,29 +1446,31 @@
"id": "suHSArn6wb27"
},
"source": [
- "## 4. Diagnosing the Results of the Classification Model"
+ "## 4. Diagnosing the results of the classification model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "### 4.1. Key Terms for Diagnostic Metrics\n",
+ "### 4.1. Key terms for diagnostic metrics\n",
+ "\n",
+ "Use the following diagnostics to assess the status of the network optimization and efficacy. \n",
+ "\n",
+ "A good reference is this [scikit-learn page on metrics and scoring](https://scikit-learn.org/stable/modules/model_evaluation.html).\n",
+ "\n",
+ "**Metrics**\n",
+ " * Loss:\n",
+ " * Accuracy: Use as a rough indicator of model training progress/convergence for balanced datasets. For model performance, use only in combination with other metrics. Avoid for imbalanced datasets. Consider using another metric.\n",
+ " * True Positive Rate (TPR; \"Recall\"): Use when false negatives are more expensive than false positives.\n",
+ " * False Positive Rate (FPR): Use when false positives are more expensive than false negatives.\n",
+ " * Precision: Use when it's very important for positive predictions to be accurate.\n",
"\n",
- "We use the following diagnostics to assess the status of the network optimization and efficacy. \n",
- "https://scikit-learn.org/stable/modules/model_evaluation.html\n",
+ "**Generalization Error**: The Generalization Error (GE) is the difference in loss when the model is applied to training data versus when it is applied to validation data and test data.\n",
"\n",
+ "**Confusion Matrix**:\n",
"\n",
- "* **Metrics**\n",
- " * Loss:\n",
- " * Accuracy: Use as a rough indicator of model training progress/convergence for balanced datasets. For model performance, use only in combination with other metrics. Avoid for imbalanced datasets. Consider using another metric.\n",
- " * tpr (Recall): Use when false negatives are more expensive than false positives.\n",
- " * for: Use when false positives are more expensive than false negatives.\n",
- " * precision: Use when it's very important for positive predictions to be accurate.\n",
- " * \n",
- "* **Generalization Error**: The Generalization Error (GE) is the difference in loss when the model is applied to training data versus when it is applied to validation data and test data.\n",
- "* **Confusion Matrix**:\n",
- "* **Receiver Operator Characteristic (ROC) Curve**:\n"
+ "**Receiver Operator Characteristic (ROC) Curve**:\n"
]
},
{
@@ -1451,6 +1516,13 @@
"y_pred_tes_topchoice = y_pred_tes.argmax(axis=1)"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Verify that the first dimension of the `y_pred_tra` shape is 45000, matching the number of training objects in Section 2.4."
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -1465,7 +1537,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Histograms of prediction distributions by class"
+ "Use the `plotPredictionHistogram` function to plot histograms of prediction distributions by class."
]
},
{
@@ -1486,8 +1558,22 @@
" figsize=(12, 5),\n",
" file_prefix=file_prefix,\n",
" file_location=path_dict['dir_data_figures'],\n",
- " file_suffix=path_dict['file_figure_suffix'])\n",
- "\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 3: Histograms of the number of images for which the top-choice class was each number 0 through 9, for the training, validation, and test sets. Note that these are overlapping histograms, not stacked."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
" + \"Histograms_class_probabilities\"\\\n",
" + \"_\" + path_dict['run_label']\n",
@@ -1507,8 +1593,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Observations about these histograms ...\n",
- "1. very similar shapes across the data sets: that's good"
+ "> Figure 4: Histograms of the number of images (y-axis) that had a probability (x-axis) of being each class 0 through 9 (light to dark shades).\n",
+ "\n",
+ "In both Figures 3 and 4, the histograms show very similar shapes across the classification categories.\n",
+ "This is a good sign because it indicates the model does not 'prefer' a class."
]
},
{
@@ -1526,14 +1614,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### 4.3.1. Loss History: History of Loss and Accuracy during Training"
+ "### 4.3.1. Loss History: history of loss and accuracy during training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Plot the loss history for the validation and training sets. We reserve the test set for a 'blind' analysis."
+ "Plot the loss history for the validation and training sets, but reserve the test set for a 'blind' analysis."
]
},
{
@@ -1558,14 +1646,16 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### 4.3.2. Confusion Matrix: Bias in Trained Model?"
+ "> Figure 5: In the top panel, the loss history as a function of epoch for the training and validation sets decreases with time, as it should as the model improves. In the bottom panel, the loss residual (validation - training) shows a dip at epoch 2, indicating the model caused a divergence in the training and validation set classifications, but that this was rectified in later epochs.\n",
+ "\n",
+ "### 4.3.2. Confusion matrix: look for bias in the trained model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Compute confusion matrices"
+ "Compute the confusion matrices."
]
},
{
@@ -1583,7 +1673,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "plot confusion matrices for training, validation, and test samples (left, right, middle)"
+ "Plot the confusion matrices for the training, validation, and test samples (left, middle, right)."
]
},
{
@@ -1599,14 +1689,21 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### 4.3.4. Investigating Errant Classifications: Look at the examples\n"
+ "> Figure 6: Confusion matrices for the training, validation, and test sets (left to right). The number of images with a given true (y-axis) and predicted (x-axis) classification is written in each box, and boxes are colored based on the number (from purple to yellow). The diagonal represents images that were correctly classified."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.3.4. Investigating errant classifications: look at the examples"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Choose a digit/class (human option/choice) for examination."
+ "Choose to investigate the digit classification `class_value` = 2."
]
},
{
@@ -1649,7 +1746,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "plot examples of false positives"
+ "Display 10 images that exemplify:\n",
+ " * true positives (correctly classified digits; a true 2 classified as 2);\n",
+ " * false positives (another digit classified as 2);\n",
+ " * true negatives (another digit classified as another digit); and\n",
+ " * false negatives (a true 2 classified as another digit)."
]
},
{
@@ -1711,6 +1812,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
+ "> Figure 7: Four panels of 10 images each, representing true positives (top), false positives (second), true negatives (third), and false negatives (bottom), for classification category 2.\n",
+ "\n",
"Plot histograms of images pixels of true positives, false positives, true negatives, and false negatives."
]
},
@@ -1773,6 +1876,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
+ "> Figure 8: Histograms of the pixel flux values for the images shown in Figure 7.\n",
+ "\n",
"## 5. Exercises for the Learner"
]
},
@@ -1825,21 +1930,7 @@
{
"cell_type": "markdown",
"metadata": {},
- "source": [
- "## 7. AI is math, not magic.\n",
- "\n",
- "AI is firmly based in math, computer science, and statistics. Additionally, some of the approaches are inspired by concepts or notions in biology (e.g., the computational neuron) and in physics (e.g., the reverse Boltzmann machine). \n",
- "\n",
- "Much of the jargon in AI is anthropomorphic, which can make it appear that some other than math is happening. For example, consider the following list of terms that are very often used in AI -- and what these terms actually mean mathematically.\n",
- "\n",
- "1. learn $\\rightarrow$ fit\n",
- "2. hallucinate/lie $\\rightarrow$ predict incorrectly\n",
- "3. understand $\\rightarrow$ model has converged\n",
- "4. cheat $\\rightarrow$ more efficiently guesses the best weight parameters of the model\n",
- "5. believe $\\rightarrow$ predict/infer based on statistical priors\n",
- "\n",
- "When we over-anthropomorphize this mathematical tool, we obfuscate how it actually works, and that makes it harder to build and refine models. That is, AI models are not 'learning' or 'understanding'; they are large-parameter models that are being fit to data. The only learning that's happening is what we do with these models."
- ]
+ "source": []
}
],
"metadata": {
From b97dbc45b35e9fe01ff1cbee9ba3ddeacd20eff9 Mon Sep 17 00:00:00 2001
From: Brian Nord <184985+bnord@users.noreply.github.com>
Date: Mon, 10 Feb 2025 08:23:26 -0600
Subject: [PATCH 4/5] commit for review
---
DP02_16a_Introduction_to_AI.ipynb.ipynb | 3212 +++++++++++++++++++++++
1 file changed, 3212 insertions(+)
create mode 100644 DP02_16a_Introduction_to_AI.ipynb.ipynb
diff --git a/DP02_16a_Introduction_to_AI.ipynb.ipynb b/DP02_16a_Introduction_to_AI.ipynb.ipynb
new file mode 100644
index 0000000..bceac82
--- /dev/null
+++ b/DP02_16a_Introduction_to_AI.ipynb.ipynb
@@ -0,0 +1,3212 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " \n",
+ " Introduction to AI-based Image Classification with Pytorch \n",
+ "Contact author: Brian Nord \n",
+ "Last verified to run: 2024-07-01 \n",
+ "LSST Science Pipelines version: Weekly 2024_16 \n",
+ "Container size: medium \n",
+ "Targeted learning level: beginner "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Description:** An introduction to the classification of images with AI-based classification algorithms."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Skills:** Examine AI training data, prepare it for a classification task, perform classification with a neural network, and examine the diagnostics of the classification task."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**LSST Data Products:** None; MNIST data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Packages:** numpy, matplotlib, sklearn, pytorch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Credits and Acknowledgments:** We thank Ryan Lau and Melissa Graham for feedback on the notebook."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Get Support:**\n",
+ "Find DP0-related documentation and resources at dp0.lsst.io. Questions are welcome as new topics in the Support - Data Preview 0 Category of the Rubin Community Forum. Rubin staff will respond to all questions posted there."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Introduction\n",
+ "\n",
+ "This Jupyter Notebook introduces artificial intelligence (AI)-based image classification. It demonstrates how to perform a few key steps:\n",
+ "1. examine and prepare data for classification;\n",
+ "2. train an AI algorithm;\n",
+ "3. plot diagnostics of the training performance;\n",
+ "4. initially assess those diagnostics. \n",
+ "\n",
+ "AI is a class of algorithms for building statistical models. These algorithms primarily use data for training, as opposed to models that use analytic formulae or models that are based on physical reasoning. Machine learning is a subclass of AI algorithms -- e.g., random forests. Deep learning is a subclass of machine learning algorithms -- e.g., neural networks. \n",
+ "\n",
+ "In this notebook, we use `pytorch`, which is currently the library most often used in deep learning studies. `pytorch` is more complicated to use in some ways --- i.e., it requires acclimating to how it handles tensors, data sets, and model building. At the same time, there is more flexibility for overall model development: you can more easily get creative with model structures. \n",
+ "\n",
+ "Instead of using DP0 data, this tutorial uses [MNIST AI benchmarking data](https://en.wikipedia.org/wiki/MNIST_database), a large database of handwritten digits that is commonly used for training and testing machine learning algorithms. It is simple to understand, so that users who are new to deep learning can focus on learning AI. Later tutorials in this series will use stars and galaxies drawn from DP0 data."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.1. AI is math, not magic.\n",
+ "\n",
+ "AI is firmly based in math, computer science, and statistics. Additionally, some of the approaches are inspired by concepts or notions in biology (e.g., the computational neuron) and in physics (e.g., the Restricted Boltzmann Machine). \n",
+ "\n",
+ "Much of the jargon in AI is anthropomorphic, which can make it appear that something other than math is happening. For example, consider the following list of terms that are very often used in AI -- and what these terms actually mean mathematically.\n",
+ "\n",
+ "1. `learn` $\\rightarrow$ fit\n",
+ "2. `hallucinate`/`lie` $\\rightarrow$ predict incorrectly\n",
+ "3. `understand` $\\rightarrow$ model fit has converged\n",
+ "4. `cheat` $\\rightarrow$ more efficiently guesses the best weight parameters of the model\n",
+ "5. `believe` $\\rightarrow$ predict/infer based on statistical priors\n",
+ "\n",
+ "When we over-anthropomorphize these mathematical concepts, we obfuscate how they actually work. That makes it harder to build and refine models. That is, AI models are not 'learning' or 'understanding'; they are merely large-parameter models that are being fit to data, and AI includes novel methods for that fitting process. The only learning that's happening is what humans do with these models.\n",
+ "\n",
+ "Many of the most useful AI-related terms are defined throughout this tutorial."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.2. Import packages\n",
+ "\n",
+ "[`numpy`](https://numpy.org/) is used for computations and mathematical operations on multi-dimensional arrays.\n",
+ "\n",
+ "[`matplotlib`](https://matplotlib.org/) is a plotting library. \n",
+ "\n",
+ "[`sklearn`](https://scikit-learn.org/stable/) is a library for machine learning.\n",
+ "\n",
+ "[`torch`](https://www.pytorch.org) is used for fast tensor operations --- often used for building neural network models."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:40.895553Z",
+ "iopub.status.busy": "2025-02-09T23:00:40.895367Z",
+ "iopub.status.idle": "2025-02-09T23:00:49.712635Z",
+ "shell.execute_reply": "2025-02-09T23:00:49.711587Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:40.895535Z"
+ },
+ "id": "JyaaGkFE8VOl"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import time\n",
+ "import datetime\n",
+ "import os\n",
+ "\n",
+ "import matplotlib\n",
+ "import matplotlib.pyplot as plt\n",
+ "from matplotlib.pyplot import cm\n",
+ "from matplotlib.colors import LogNorm\n",
+ "import seaborn as sns\n",
+ "\n",
+ "from sklearn.metrics import confusion_matrix\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "import torch.nn.functional as F\n",
+ "import torchvision\n",
+ "from torch.utils.data import Dataset, DataLoader, Subset, random_split"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.3. Define functions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The following functions are defined and used throughout this notebook. It is not necessary to understand exactly what every function does to proceed with this tutorial. Execute all cells and move on to Section 2."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:49.714597Z",
+ "iopub.status.busy": "2025-02-09T23:00:49.713821Z",
+ "iopub.status.idle": "2025-02-09T23:00:49.863235Z",
+ "shell.execute_reply": "2025-02-09T23:00:49.862202Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:49.714555Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def normalizeInputs(x_temp, input_minimum, input_maximum):\n",
+ " \"\"\"Normalize a datum that is an input to the neural network\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " x_temp: `numpy.array`\n",
+ " image data\n",
+ " input_minimum: `float`\n",
+ " minimum value for normalization\n",
+ " input_maximum: `float`\n",
+ " maximum value for normalization\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " x_temp_norm: `numpy.array`\n",
+ " normalized image data\n",
+ " \"\"\"\n",
+ " x_temp_norm = (x_temp - input_minimum)/(input_maximum - input_minimum)\n",
+ " return x_temp_norm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:49.866230Z",
+ "iopub.status.busy": "2025-02-09T23:00:49.865802Z",
+ "iopub.status.idle": "2025-02-09T23:00:49.998210Z",
+ "shell.execute_reply": "2025-02-09T23:00:49.997647Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:49.866185Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def createFileUidTimestamp():\n",
+ " \"\"\"Create a timestamp for a filename.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " None\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " file_uid_timestamp : `string`\n",
+ " String from date and time.\n",
+ " \"\"\"\n",
+ " file_uid_timestamp = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
+ " return file_uid_timestamp"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:49.998965Z",
+ "iopub.status.busy": "2025-02-09T23:00:49.998779Z",
+ "iopub.status.idle": "2025-02-09T23:00:50.155942Z",
+ "shell.execute_reply": "2025-02-09T23:00:50.154901Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:49.998948Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def createFileName(file_prefix=\"\", file_location=\"Data/Sandbox/\",\n",
+ " file_suffix=\"\", useuid=True, verbose=True):\n",
+ " \"\"\"Create a file name.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " file_prefix: `string`\n",
+ " prefix of file name\n",
+ " file_location: `string`\n",
+ " path to file\n",
+ " file_suffix: `string`\n",
+ " suffix/extension of file name\n",
+ " useuid: `bool`\n",
+ " choose to use a unique id\n",
+ " verbose: `bool`\n",
+ " choose to print the file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " file_final: `string`\n",
+ " filename used for saving\n",
+ " \"\"\"\n",
+ " if useuid:\n",
+ " file_uid = createFileUidTimestamp()\n",
+ " else:\n",
+ " file_uid = \"\"\n",
+ "\n",
+ " file_final = file_location + file_prefix + \"_\" + file_uid + file_suffix\n",
+ "\n",
+ " if verbose:\n",
+ " print(file_final)\n",
+ "\n",
+ " return file_final"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:50.157469Z",
+ "iopub.status.busy": "2025-02-09T23:00:50.157104Z",
+ "iopub.status.idle": "2025-02-09T23:00:50.298215Z",
+ "shell.execute_reply": "2025-02-09T23:00:50.297620Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:50.157432Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def plotArrayImageExamples(subset_train,\n",
+ " num_row=3, num_col=3,\n",
+ " object_index_start=0,\n",
+ " save_file=False,\n",
+ " file_prefix=\"ImageExamples\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot an array of examples of images and labels\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " subset_train: `torch.utils.data.Dataset`\n",
+ " data set of (image, label) pairs\n",
+ " num_row: `int`, optional\n",
+ " number of rows to plot\n",
+ " num_col: `int`, optional\n",
+ " number of columns to plot\n",
+ " object_index_start: `int`, optional\n",
+ " starting index for set of images to plot\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " From: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " labels_map = {\n",
+ " 0: \"0\",\n",
+ " 1: \"1\",\n",
+ " 2: \"2\",\n",
+ " 3: \"3\",\n",
+ " 4: \"4\",\n",
+ " 5: \"5\",\n",
+ " 6: \"6\",\n",
+ " 7: \"7\",\n",
+ " 8: \"8\",\n",
+ " 9: \"9\",\n",
+ " }\n",
+ "\n",
+ " figure = plt.figure(figsize=(8, 8))\n",
+ "\n",
+ " for i in range(1, num_row * num_col + 1):\n",
+ " sample_idx = object_index_start + i - 1\n",
+ " img, label = subset_train[sample_idx]\n",
+ " figure.add_subplot(num_row, num_col, i)\n",
+ " plt.title(labels_map[label])\n",
+ " plt.axis(\"off\")\n",
+ " plt.imshow(img.squeeze(), cmap=\"gray\")\n",
+ "\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:50.299193Z",
+ "iopub.status.busy": "2025-02-09T23:00:50.298861Z",
+ "iopub.status.idle": "2025-02-09T23:00:50.463522Z",
+ "shell.execute_reply": "2025-02-09T23:00:50.462425Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:50.299174Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def plotArrayHistogramExamples(subset_train,\n",
+ " num_row=3, num_col=3,\n",
+ " n_bins=10,\n",
+ " object_index_start=0,\n",
+ " save_file=False,\n",
+ " file_prefix=\"HistogramExamples\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot histograms of image pixel values\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " subset_train: `torch.utils.data.Dataset`\n",
+ " data set of (image, label) pairs\n",
+ " num_row: `int`, optional\n",
+ " number of rows to plot\n",
+ " num_col: `int`, optional\n",
+ " number of columns to plot\n",
+ " n_bins: `int`, optional\n",
+ " number of bins in histogram\n",
+ " object_index_start: `int`, optional\n",
+ " starting index for set of images to plot\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " labels_map = {\n",
+ " 0: \"0\",\n",
+ " 1: \"1\",\n",
+ " 2: \"2\",\n",
+ " 3: \"3\",\n",
+ " 4: \"4\",\n",
+ " 5: \"5\",\n",
+ " 6: \"6\",\n",
+ " 7: \"7\",\n",
+ " 8: \"8\",\n",
+ " 9: \"9\",\n",
+ " }\n",
+ "\n",
+ " fig, axes = plt.subplots(num_row, num_col,\n",
+ " figsize=(1.5*num_col, 2*num_row))\n",
+ "\n",
+ " for i in range(num_row * num_col):\n",
+ " sample_idx = object_index_start + i\n",
+ " img, label = subset_train[sample_idx]\n",
+ " ax = axes[i//num_col, i%num_col]\n",
+ " img_temp = img[0, :, :]\n",
+ " img_temp = np.array(img_temp).flat\n",
+ " ax.hist(img_temp, bins=n_bins, color='gray')\n",
+ " ax.set_title(labels_map[label])\n",
+ " ax.set_xlabel(\"Pixel Values\")\n",
+ "\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:50.465449Z",
+ "iopub.status.busy": "2025-02-09T23:00:50.465057Z",
+ "iopub.status.idle": "2025-02-09T23:00:50.612203Z",
+ "shell.execute_reply": "2025-02-09T23:00:50.611165Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:50.465413Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def predict(dataloader, model, dataset_type):\n",
+ "    \"\"\"Predict labels of inputs\n",
+ "\n",
+ "    Relies on the notebook-level variables `device` and `loss_fn`.\n",
+ "\n",
+ "    Parameters\n",
+ "    ----------\n",
+ "    dataloader: `torch.utils.data.DataLoader`\n",
+ "        batches of images and labels\n",
+ "    model: `torch.nn.Module`\n",
+ "        model used to make the predictions\n",
+ "    dataset_type: `string`\n",
+ "        name of the data set, used in the printed summary\n",
+ "\n",
+ "    Returns\n",
+ "    -------\n",
+ "    y_prob_list: `numpy.ndarray`\n",
+ "        probabilities for each class for each input\n",
+ "    y_choice_list: `numpy.ndarray`\n",
+ "        highest-probability class for each input\n",
+ "    y_true_list: `numpy.ndarray`\n",
+ "        true class for each input\n",
+ "    x_list: `numpy.ndarray`\n",
+ "        input\n",
+ "    \"\"\"\n",
+ "    size = len(dataloader.dataset)\n",
+ "    num_batches = len(dataloader)\n",
+ "    model.eval()\n",
+ "\n",
+ "    y_prob_list = []\n",
+ "    y_choice_list = []\n",
+ "    y_true_list = []\n",
+ "    x_list = []\n",
+ "\n",
+ "    loss, accuracy = 0, 0\n",
+ "    with torch.no_grad():\n",
+ "        for inputs, labels in dataloader:\n",
+ "            inputs = inputs.to(device)\n",
+ "            labels = labels.to(device)\n",
+ "\n",
+ "            y = model(inputs)\n",
+ "            y_prob = torch.softmax(y, dim=1)\n",
+ "            y_choice = y.argmax(dim=1)\n",
+ "\n",
+ "            # Accumulate the batch loss and the number of correct labels.\n",
+ "            loss += loss_fn(y, labels).item()\n",
+ "            accuracy_temp = y_choice == labels\n",
+ "            accuracy += accuracy_temp.type(torch.float).sum().item()\n",
+ "\n",
+ "            # Move results to the CPU before converting to numpy.\n",
+ "            y_prob_list.append(y_prob.cpu().numpy())\n",
+ "            y_choice_list.append(y_choice.cpu().numpy())\n",
+ "            y_true_list.append(labels.cpu().numpy())\n",
+ "            x_list.append(inputs.cpu().numpy())\n",
+ "\n",
+ "    # Normalize once, after all batches have been processed.\n",
+ "    loss /= num_batches\n",
+ "    accuracy /= size\n",
+ "\n",
+ "    # Join the batches along the first (object) axis.\n",
+ "    y_prob_list = np.squeeze(np.concatenate(y_prob_list))\n",
+ "    y_choice_list = np.squeeze(np.concatenate(y_choice_list))\n",
+ "    y_true_list = np.squeeze(np.concatenate(y_true_list))\n",
+ "    x_list = np.squeeze(np.concatenate(x_list))\n",
+ "\n",
+ "    print(f\"{dataset_type : <10} data set ... \"\n",
+ "          f\"Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {loss:>8f}\")\n",
+ "\n",
+ "    return y_prob_list, y_choice_list, y_true_list, x_list"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:50.614078Z",
+ "iopub.status.busy": "2025-02-09T23:00:50.613693Z",
+ "iopub.status.idle": "2025-02-09T23:00:50.840445Z",
+ "shell.execute_reply": "2025-02-09T23:00:50.839384Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:50.614040Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def plotPredictionHistogram(y_prediction_a, y_prediction_b=None,\n",
+ " y_prediction_c=None, n_classes=None,\n",
+ " n_objects_a=None, n_colors=None,\n",
+ " title_a=None, title_b=None,\n",
+ " title_c=None, label_a=None,\n",
+ " label_b=None, label_c=None,\n",
+ " alpha=1.0, figsize=(12, 5),\n",
+ " save_file=False,\n",
+ " file_prefix=\"prediction_histogram\",\n",
+ " file_location=\"./\",\n",
+ " xlabel_plot=\"Class label\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot histogram of predicted labels\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " y_prediction_a: `numpy.ndarray`\n",
+ " y_prediction_b: `numpy.ndarray`, optional\n",
+ " y_prediction_c: `numpy.ndarray`, optional\n",
+ " n_classes: `int`, optional\n",
+ " n_objects_a: `int`, optional\n",
+ " n_colors: `int`, optional\n",
+ " title_a: `string`, optional\n",
+ " title_b: `string`, optional\n",
+ " title_c: `string`, optional\n",
+ " label_a: `string`, optional\n",
+ " label_b: `string`, optional\n",
+ " label_c: `string`, optional\n",
+ " alpha: `float`, optional\n",
+ " transparency\n",
+ " figsize: `tuple`, optional\n",
+ " figure size\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " ndim = y_prediction_a.ndim\n",
+ "\n",
+ " if ndim == 2:\n",
+ " fig, (axa, axb, axc) = plt.subplots(1, 3, figsize=figsize)\n",
+ " fig.subplots_adjust(wspace=0.35)\n",
+ " elif ndim == 1:\n",
+ " fig, ax = plt.subplots(figsize=figsize)\n",
+ "\n",
+ " shape_a = np.shape(y_prediction_a)\n",
+ "\n",
+ " if n_objects_a is None:\n",
+ " n_objects_a = shape_a[0]\n",
+ "\n",
+ " if ndim == 2:\n",
+ " if n_classes is None:\n",
+ " n_classes = shape_a[1]\n",
+ " if n_colors is None:\n",
+ " n_colors = n_classes\n",
+ " elif ndim == 1:\n",
+ " if n_colors is None:\n",
+ " n_colors = 1\n",
+ "\n",
+ " if ndim == 2:\n",
+ " colors = cm.Purples(np.linspace(0, 1, n_colors))\n",
+ " xlabel = \"Probability for Each Class\"\n",
+ "\n",
+ " axa.set_ylim(0, n_objects_a)\n",
+ " axa.set_xlabel(xlabel)\n",
+ " axa.set_title(title_a)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axa.hist(y_prediction_a[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " if y_prediction_b is not None:\n",
+ " shape_b = np.shape(y_prediction_b)\n",
+ " axb.set_ylim(0, shape_b[0])\n",
+ " axb.set_xlabel(xlabel)\n",
+ " axb.set_title(title_b)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axb.hist(y_prediction_b[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " if y_prediction_c is not None:\n",
+ " shape_c = np.shape(y_prediction_c)\n",
+ " axc.set_ylim(0, shape_c[0])\n",
+ " axc.set_xlabel(xlabel)\n",
+ " axc.set_title(title_c)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axc.hist(y_prediction_c[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " elif ndim == 1:\n",
+ " ya, xa, _ = plt.hist(y_prediction_a, alpha=alpha, color='orange',\n",
+ " label=label_a)\n",
+ " y_max_list = [max(ya)]\n",
+ "\n",
+ " if y_prediction_b is not None:\n",
+ " yb, xb, _ = plt.hist(y_prediction_b, alpha=alpha, color='green',\n",
+ " label=label_b)\n",
+ " y_max_list.append(max(yb))\n",
+ "\n",
+ " if y_prediction_c is not None:\n",
+ " yc, xc, _ = plt.hist(y_prediction_c, alpha=alpha, color='purple',\n",
+ " label=label_c)\n",
+ " y_max_list.append(max(yc))\n",
+ "\n",
+ " plt.ylim(0, np.max(y_max_list)*1.1)\n",
+ " plt.xlabel(xlabel_plot)\n",
+ "\n",
+ " plt.legend(loc='upper right')\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:50.842666Z",
+ "iopub.status.busy": "2025-02-09T23:00:50.842278Z",
+ "iopub.status.idle": "2025-02-09T23:00:51.020147Z",
+ "shell.execute_reply": "2025-02-09T23:00:51.019064Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:50.842629Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def plotTrueLabelHistogram(y_prediction_a, y_prediction_b=None,\n",
+ " y_prediction_c=None, n_classes=None,\n",
+ " n_objects_a=None, n_colors=None,\n",
+ " title_a=None, title_b=None,\n",
+ " title_c=None, label_a=None,\n",
+ " label_b=None, label_c=None,\n",
+ " alpha=1.0, figsize=(12, 5),\n",
+ " save_file=False,\n",
+ " file_prefix=\"true_label_histogram\",\n",
+ " file_location=\"./\",\n",
+ " xlabel_plot=\"Class label\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot histogram of true labels\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " y_prediction_a: `numpy.ndarray`\n",
+ " y_prediction_b: `numpy.ndarray`, optional\n",
+ " y_prediction_c: `numpy.ndarray`, optional\n",
+ " n_classes: `int`, optional\n",
+ " n_objects_a: `int`, optional\n",
+ " n_colors: `int`, optional\n",
+ " title_a: `string`, optional\n",
+ " title_b: `string`, optional\n",
+ " title_c: `string`, optional\n",
+ " label_a: `string`, optional\n",
+ " label_b: `string`, optional\n",
+ " label_c: `string`, optional\n",
+ " alpha: `float`, optional\n",
+ " transparency\n",
+ " figsize: `tuple`, optional\n",
+ " figure size\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ "\n",
+ " ndim = y_prediction_a.ndim\n",
+ "\n",
+ " if ndim == 2:\n",
+ " fig, (axa, axb, axc) = plt.subplots(1, 3, figsize=figsize)\n",
+ " fig.subplots_adjust(wspace=0.35)\n",
+ " elif ndim == 1:\n",
+ " fig, ax = plt.subplots(figsize=figsize)\n",
+ "\n",
+ " shape_a = np.shape(y_prediction_a)\n",
+ "\n",
+ " if n_objects_a is None:\n",
+ " n_objects_a = shape_a[0]\n",
+ "\n",
+ " if ndim == 2:\n",
+ " if n_classes is None:\n",
+ " n_classes = shape_a[1]\n",
+ " if n_colors is None:\n",
+ " n_colors = n_classes\n",
+ " elif ndim == 1:\n",
+ " if n_colors is None:\n",
+ " n_colors = 1\n",
+ "\n",
+ " if ndim == 2:\n",
+ " colors = cm.Purples(np.linspace(0, 1, n_colors))\n",
+ " xlabel = \"Probability for Each Class\"\n",
+ "\n",
+ " axa.set_ylim(0, n_objects_a)\n",
+ " axa.set_xlabel(xlabel)\n",
+ " axa.set_title(title_a)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axa.hist(y_prediction_a[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " if y_prediction_b is not None:\n",
+ " shape_b = np.shape(y_prediction_b)\n",
+ " axb.set_ylim(0, shape_b[0])\n",
+ " axb.set_xlabel(xlabel)\n",
+ " axb.set_title(title_b)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axb.hist(y_prediction_b[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " if y_prediction_c is not None:\n",
+ " shape_c = np.shape(y_prediction_c)\n",
+ " axc.set_ylim(0, shape_c[0])\n",
+ " axc.set_xlabel(xlabel)\n",
+ " axc.set_title(title_c)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axc.hist(y_prediction_c[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " elif ndim == 1:\n",
+ " ya, xa, _ = plt.hist(y_prediction_a, alpha=alpha, color='orange',\n",
+ " label=label_a)\n",
+ " y_max_list = [max(ya)]\n",
+ "\n",
+ " if y_prediction_b is not None:\n",
+ " yb, xb, _ = plt.hist(y_prediction_b, alpha=alpha, color='green',\n",
+ " label=label_b)\n",
+ " y_max_list.append(max(yb))\n",
+ "\n",
+ " if y_prediction_c is not None:\n",
+ " yc, xc, _ = plt.hist(y_prediction_c, alpha=alpha, color='purple',\n",
+ " label=label_c)\n",
+ " y_max_list.append(max(yc))\n",
+ "\n",
+ " plt.ylim(0, np.max(y_max_list)*1.1)\n",
+ " plt.xlabel(xlabel_plot)\n",
+ "\n",
+ " plt.legend(loc='upper right')\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:51.021678Z",
+ "iopub.status.busy": "2025-02-09T23:00:51.021279Z",
+ "iopub.status.idle": "2025-02-09T23:00:51.185198Z",
+ "shell.execute_reply": "2025-02-09T23:00:51.184585Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:51.021643Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def plotLossHistory(history, figsize=(8, 5),\n",
+ " save_file=False,\n",
+ " file_prefix=\"loss_history\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot loss history of the model as function of epoch\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " history: `dict`\n",
+ " losses at each epoch, stored under the keys 'loss' and 'val_loss'\n",
+ " figsize: `tuple`, optional\n",
+ " figure size\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=figsize)\n",
+ "\n",
+ " loss_tra = np.array(history['loss'])\n",
+ " loss_val = np.array(history['val_loss'])\n",
+ " loss_dif = loss_val - loss_tra\n",
+ "\n",
+ " ax1.plot(loss_tra, label='Training')\n",
+ " ax1.plot(loss_val, label='Validation')\n",
+ " ax1.legend()\n",
+ "\n",
+ " ax2.plot(loss_dif, color='red', label='residual')\n",
+ " ax2.axhline(y=0, color='grey', linestyle='dashed', label='zero bias')\n",
+ " ax2.sharex(ax1)\n",
+ " ax2.legend()\n",
+ "\n",
+ " ax1.set_title('Loss History')\n",
+ " ax1.set_ylabel('Loss')\n",
+ " ax2.set_ylabel('Loss Residual')\n",
+ " ax2.set_xlabel('Epoch')\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ "\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:51.186307Z",
+ "iopub.status.busy": "2025-02-09T23:00:51.186054Z",
+ "iopub.status.idle": "2025-02-09T23:00:51.343743Z",
+ "shell.execute_reply": "2025-02-09T23:00:51.342674Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:51.186287Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def plotArrayImageConfusion(x_tra, y_tra, y_pred_tra_topchoice,\n",
+ " title_main=None, num=10,\n",
+ " save_file=False,\n",
+ " file_prefix=\"image_confusion\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot images of example objects that are misclassified.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " x_tra: `numpy.ndarray`\n",
+ " training image data\n",
+ " y_tra: `numpy.ndarray`\n",
+ " training label data\n",
+ " y_pred_tra_topchoice: `numpy.ndarray`\n",
+ " top choice of the predicted labels\n",
+ " title_main: `string`, optional\n",
+ " title for the plot\n",
+ " num: `int`, optional\n",
+ " number of examples\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " num_row = 2\n",
+ " num_col = 5\n",
+ " images = x_tra[:num]\n",
+ " labels_true = y_tra[:num]\n",
+ " labels_pred = y_pred_tra_topchoice[:num]\n",
+ "\n",
+ " fig, axes = plt.subplots(num_row, num_col,\n",
+ " figsize=(1.5*num_col, 2*num_row))\n",
+ "\n",
+ " fig.patch.set_linewidth(5)\n",
+ " fig.patch.set_edgecolor('cornflowerblue')\n",
+ "\n",
+ " for i in range(num):\n",
+ " ax = axes[i//num_col, i%num_col]\n",
+ " ax.imshow(images[i], cmap='gray')\n",
+ " ax.set_title(r'True: {}'.format(labels_true[i]) + '\\n'\n",
+ " + 'Pred: {}'.format(labels_pred[i]))\n",
+ "\n",
+ " fig.suptitle(title_main)\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ "\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:51.348809Z",
+ "iopub.status.busy": "2025-02-09T23:00:51.348383Z",
+ "iopub.status.idle": "2025-02-09T23:00:51.591750Z",
+ "shell.execute_reply": "2025-02-09T23:00:51.591165Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:51.348771Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def plotArrayHistogramConfusion(x_tra, y_tra, y_pred_tra_topchoice,\n",
+ " title_main=None, num=10,\n",
+ " save_file=False,\n",
+ " file_prefix=\"histogram_confusion\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot histograms of pixel values for images that are misclassified.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " x_tra: `numpy.ndarray`\n",
+ " training image data\n",
+ " y_tra: `numpy.ndarray`\n",
+ " training label data\n",
+ " y_pred_tra_topchoice: `numpy.ndarray`\n",
+ " top choice of the predicted labels\n",
+ " title_main: `string`, optional\n",
+ " title of plot\n",
+ " num: `int`, optional\n",
+ " number of examples\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " n_bins = 10\n",
+ " num_row = 2\n",
+ " num_col = 5\n",
+ " images = x_tra[:num]\n",
+ " labels_true = y_tra[:num]\n",
+ " labels_pred = y_pred_tra_topchoice[:num]\n",
+ "\n",
+ " fig, axes = plt.subplots(num_row, num_col,\n",
+ " figsize=(1.5*num_col, 2*num_row))\n",
+ "\n",
+ " fig.patch.set_linewidth(5)\n",
+ " fig.patch.set_edgecolor('cornflowerblue')\n",
+ "\n",
+ " for i in range(num):\n",
+ " ax = axes[i//num_col, i%num_col]\n",
+ " images_temp = images[i, :, :].flat\n",
+ " ax.hist(images_temp, bins=n_bins, color='gray')\n",
+ " ax.set_title(r'True: {}'.format(labels_true[i]) + '\\n'\n",
+ " + 'Pred: {}'.format(labels_pred[i]))\n",
+ " ax.set_xlabel('Pixel Values')\n",
+ "\n",
+ " fig.suptitle(title_main)\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ "\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.4. Define paths for data and figures"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Neural network training (i.e., model fitting) typically requires many numerical experiments to achieve an ideal model. To facilitate the comparison of these experiments/models, it is helpful to organize data carefully. \n",
+ "\n",
+ "Set the variable `run_label` for each training run. \n",
+ "Set paths for the model weight parameters and diagnostic figures, and\n",
+ "save these paths in the dictionary `path_dict` to facilitate passing information to plotting functions.\n",
+ "Check whether the paths exist, and if not, create them."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:51.592497Z",
+ "iopub.status.busy": "2025-02-09T23:00:51.592305Z",
+ "iopub.status.idle": "2025-02-09T23:00:51.743517Z",
+ "shell.execute_reply": "2025-02-09T23:00:51.742419Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:51.592481Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "run_label = \"Run000\"\n",
+ "\n",
+ "path_temp = os.getenv(\"HOME\") + '/dp02_16a_temp/'\n",
+ "path_dict = {'run_label': run_label,\n",
+ " 'dir_data_model': path_temp + \"Data/Models/\",\n",
+ " 'dir_data_figures': path_temp + \"Data/Figures/\",\n",
+ " 'file_model_prefix': \"Model\",\n",
+ " 'file_figure_prefix': \"Figure\",\n",
+ " 'file_figure_suffix': \".png\",\n",
+ " 'file_model_suffix': \".pt\"\n",
+ " }\n",
+ "del path_temp\n",
+ "\n",
+ "if not os.path.exists(path_dict['dir_data_model']):\n",
+ " os.makedirs(path_dict['dir_data_model'])\n",
+ "\n",
+ "if not os.path.exists(path_dict['dir_data_figures']):\n",
+ " os.makedirs(path_dict['dir_data_figures'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Load and Prepare data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.1. Data set: MNIST handwritten digits"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The [MNIST handwritten digits dataset](https://ieeexplore.ieee.org/document/6296535) comprises grayscale images of handwritten digits in 10 classes, one for each digit. MNIST is one of a few canonical AI benchmark data sets for image classification, which makes it useful for learning the basics of neural networks and other AI algorithms."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.2. Obtain the dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The `torchvision` package provides the `datasets.MNIST` class, which downloads the MNIST data to local storage at no cost. When the data are loaded, a `transforms` pipeline normalizes them; normalized inputs make the model training more efficient."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 2.2.1. Normalize data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create a `transform` that converts the images to tensors that `pytorch` can use. The `ToTensor` transform also rescales the pixel values from the integer range $[0, 255]$ to the floating-point range $[0, 1]$."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:51.745068Z",
+ "iopub.status.busy": "2025-02-09T23:00:51.744707Z",
+ "iopub.status.idle": "2025-02-09T23:00:51.884959Z",
+ "shell.execute_reply": "2025-02-09T23:00:51.884286Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:51.745032Z"
+ },
+ "id": "kMzao03zcF5i",
+ "outputId": "0401d0af-d607-436a-908f-385ffc85812c"
+ },
+ "outputs": [],
+ "source": [
+ "transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 2.2.2. Download data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Download the data from a remote server."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:51.886408Z",
+ "iopub.status.busy": "2025-02-09T23:00:51.885839Z",
+ "iopub.status.idle": "2025-02-09T23:00:52.170282Z",
+ "shell.execute_reply": "2025-02-09T23:00:52.169224Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:51.886378Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "dataset = torchvision.datasets.MNIST(root='./newdata', train=True,\n",
+ " download=True, transform=transform)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.3. Split data into training, validation, and testing"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Split the data to allow a proper 'blind' analysis and optimization of an AI model.\n",
+ "\n",
+ "There are three primary data sets used in model development and optimization:\n",
+ "\n",
+ "* `Training` (tagged `_tra`) data is used directly by the algorithm to update the parameters of the AI model, for example the weights on the edges between the computational neurons in a neural network.\n",
+ "* `Validation` (`_val`) data is used indirectly to update the hyperparameters of the AI model, for example the batch size (`batchsize`), the learning rate, or the layers in the architecture of a neural network. Each time the neural network has completed training with the training data, a human compares the diagnostics computed on the training data with those computed on the validation data.\n",
+ "* `Test(ing)` (`_tes`) data is used only after the model has been trained and validated and will no longer be updated or trained further. It stands in for new data that has not been examined before, for example newly observed data that has not been previously characterized."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We use most of the data for training to maximize the accuracy and generalization of the model. We use a small amount of validation data, because only a little bit is needed to check the model optimization during training. We also use only a small amount of testing data, assuming that new data sets to examine are smaller than existing training data sets."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Set the fractions for the training, validation, and test sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:52.172134Z",
+ "iopub.status.busy": "2025-02-09T23:00:52.171340Z",
+ "iopub.status.idle": "2025-02-09T23:00:52.297172Z",
+ "shell.execute_reply": "2025-02-09T23:00:52.296634Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:52.172083Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "fraction_tra = 0.8\n",
+ "fraction_val = 0.1\n",
+ "fraction_tes = 0.1"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Split the data according to those fractions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:52.297980Z",
+ "iopub.status.busy": "2025-02-09T23:00:52.297800Z",
+ "iopub.status.idle": "2025-02-09T23:00:52.444477Z",
+ "shell.execute_reply": "2025-02-09T23:00:52.443440Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:52.297963Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "fraction_list = [fraction_tra, fraction_val, fraction_tes]\n",
+ "\n",
+ "data_tra_full, data_val_full, data_tes_full = \\\n",
+ " torch.utils.data.random_split(dataset, fraction_list)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Print the data set sizes to make sure there's enough data for optimizing the model. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:52.445854Z",
+ "iopub.status.busy": "2025-02-09T23:00:52.445490Z",
+ "iopub.status.idle": "2025-02-09T23:00:52.586869Z",
+ "shell.execute_reply": "2025-02-09T23:00:52.586295Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:52.445819Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "training: 48000\n",
+ "validation: 6000\n",
+ "test: 6000\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"training:\", len(data_tra_full.indices))\n",
+ "print(\"validation:\", len(data_val_full.indices))\n",
+ "print(\"test:\", len(data_tes_full.indices))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.4. Create a subset of the training data."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use only a subset of each data set to make the training faster for this tutorial. Consider increasing the sizes of these subsets in your exploration of the tutorial elements."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:52.587896Z",
+ "iopub.status.busy": "2025-02-09T23:00:52.587639Z",
+ "iopub.status.idle": "2025-02-09T23:00:52.723891Z",
+ "shell.execute_reply": "2025-02-09T23:00:52.722872Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:52.587879Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "subset_size = 5000\n",
+ "\n",
+ "subset_indices_tra = np.arange(subset_size)\n",
+ "subset_indices_val = np.arange(subset_size)\n",
+ "subset_indices_tes = np.arange(subset_size)\n",
+ "\n",
+ "data_tra = Subset(data_tra_full, subset_indices_tra)\n",
+ "data_val = Subset(data_val_full, subset_indices_val)\n",
+ "data_tes = Subset(data_tes_full, subset_indices_tes)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.4. Examine raw data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Review the raw data shapes by looking at a single datum in the training set. Each datum in the training set is an image-label pair. The image is a tensor, and the label is an integer. The image size in part determines the depth (number of layers) of the neural network. This will be discussed in a later section on model training."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:52.725316Z",
+ "iopub.status.busy": "2025-02-09T23:00:52.724942Z",
+ "iopub.status.idle": "2025-02-09T23:00:52.875460Z",
+ "shell.execute_reply": "2025-02-09T23:00:52.874383Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:52.725278Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The image shape is torch.Size([1, 28, 28]).\n",
+ "The label for this image is 6.\n"
+ ]
+ }
+ ],
+ "source": [
+ "sample_index = 0\n",
+ "image, label = data_tra[sample_index]\n",
+ "print(f\"The image shape is {image.shape}.\")\n",
+ "print(f\"The label for this image is {label}.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Print the image labels and their corresponding indices within a dataset object. This is useful for verifying that you understand your data set."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:52.876528Z",
+ "iopub.status.busy": "2025-02-09T23:00:52.876301Z",
+ "iopub.status.idle": "2025-02-09T23:00:53.000692Z",
+ "shell.execute_reply": "2025-02-09T23:00:53.000075Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:52.876509Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{0: '0 - zero',\n",
+ " 1: '1 - one',\n",
+ " 2: '2 - two',\n",
+ " 3: '3 - three',\n",
+ " 4: '4 - four',\n",
+ " 5: '5 - five',\n",
+ " 6: '6 - six',\n",
+ " 7: '7 - seven',\n",
+ " 8: '8 - eight',\n",
+ " 9: '9 - nine'}"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "{v: k for k, v in dataset.class_to_idx.items()}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot examples of the raw data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:53.001744Z",
+ "iopub.status.busy": "2025-02-09T23:00:53.001327Z",
+ "iopub.status.idle": "2025-02-09T23:00:53.833152Z",
+ "shell.execute_reply": "2025-02-09T23:00:53.832087Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:53.001716Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Example_Image_Array\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotArrayImageExamples(data_tra,\n",
+ " num_row=4, num_col=3,\n",
+ " save_file=False,\n",
+ " object_index_start=0,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 1: Four rows of three images, each a handwritten digit in white on a black background."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot the distributions of pixel values to understand the data further. All pixel data has been normalized, and pixels have values between 0 and 1 only.\n",
+ "\n",
+ "The distribution of pixel values matches the images shown above: mostly black (values near 0) pixels, with some white (values near 1), and a few grey (values in between 0 and 1)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:53.834798Z",
+ "iopub.status.busy": "2025-02-09T23:00:53.834404Z",
+ "iopub.status.idle": "2025-02-09T23:00:55.262207Z",
+ "shell.execute_reply": "2025-02-09T23:00:55.261170Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:53.834762Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Example_Histogram_Array\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotArrayHistogramExamples(data_tra,\n",
+ " num_row=2, num_col=5,\n",
+ " save_file=False,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 2: Two rows of five plots, each showing the distribution of pixel values (number of pixels of a given value) for the handwritten digit images shown in Figure 1."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2024-12-15T20:26:20.181221Z",
+ "iopub.status.busy": "2024-12-15T20:26:20.180895Z",
+ "iopub.status.idle": "2024-12-15T20:26:20.186166Z",
+ "shell.execute_reply": "2024-12-15T20:26:20.185452Z",
+ "shell.execute_reply.started": "2024-12-15T20:26:20.181199Z"
+ }
+ },
+ "source": [
+ "Use the `plotPredictionHistogram` function to plot histograms of true label distributions by class."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:55.263752Z",
+ "iopub.status.busy": "2025-02-09T23:00:55.263376Z",
+ "iopub.status.idle": "2025-02-09T23:00:57.303748Z",
+ "shell.execute_reply": "2025-02-09T23:00:57.302743Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:55.263717Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "y_tra = []\n",
+ "y_val = []\n",
+ "y_tes = []\n",
+ "\n",
+ "for i in np.arange(subset_size):\n",
+ " image, label_tra = data_tra[i]\n",
+ " image, label_val = data_val[i]\n",
+ " image, label_tes = data_tes[i]\n",
+ " y_tra.append(label_tra)\n",
+ " y_val.append(label_val)\n",
+ " y_tes.append(label_tes)\n",
+ "\n",
+ "y_tra = np.array(y_tra)\n",
+ "y_val = np.array(y_val)\n",
+ "y_tes = np.array(y_tes)\n",
+ "\n",
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Histograms_true_class\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotPredictionHistogram(y_tra,\n",
+ " y_prediction_b=y_val,\n",
+ " y_prediction_c=y_tes,\n",
+ " label_a=\"Training Set\",\n",
+ " label_b=\"Validation Set\",\n",
+ " label_c=\"Testing Set\",\n",
+ " figsize=(12, 5),\n",
+ " alpha=0.5,\n",
+ " xlabel_plot=\"True class label\",\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 3: The histograms of true class labels for each data set. Each histogram is for a different data set used during model training --- training data, validation data, and test data. Note that these are overlapping histograms, not stacked. Please compare to Figure 4, which shows histograms of the predicted class labels."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.5. Create Dataloaders for training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Set the batch size, which will be used in dataloaders and in training. To simplify data-handling, we set the batch size to the size of the subset that we select. This means that each training epoch uses exactly one batch, and therefore makes one weight update."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:57.305293Z",
+ "iopub.status.busy": "2025-02-09T23:00:57.304947Z",
+ "iopub.status.idle": "2025-02-09T23:00:57.453279Z",
+ "shell.execute_reply": "2025-02-09T23:00:57.452205Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:57.305259Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "batch_size = subset_size"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create dataloaders for the training, validation, and test sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:57.454746Z",
+ "iopub.status.busy": "2025-02-09T23:00:57.454386Z",
+ "iopub.status.idle": "2025-02-09T23:00:57.589881Z",
+ "shell.execute_reply": "2025-02-09T23:00:57.589215Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:57.454711Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "trainloader = torch.utils.data.DataLoader(data_tra, batch_size=batch_size,\n",
+ " shuffle=False)\n",
+ "\n",
+ "validloader = torch.utils.data.DataLoader(data_val, batch_size=batch_size,\n",
+ " shuffle=False)\n",
+ "\n",
+ "testloader = torch.utils.data.DataLoader(data_tes, batch_size=batch_size,\n",
+ " shuffle=False)"
+ ]
+ },
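+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To see how the batch size controls the number of batches per epoch, the sketch below builds two data loaders over a hypothetical ten-sample data set (`toy_data`, `full_loader`, and `mini_loader` are illustrative names only). With the batch size equal to the data set size there is a single batch; a smaller batch size yields several batches, the last of which may be smaller.\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "from torch.utils.data import DataLoader, TensorDataset\n",
+ "\n",
+ "# Hypothetical data set: ten blank 28x28 images with integer labels.\n",
+ "toy_data = TensorDataset(torch.zeros(10, 1, 28, 28), torch.arange(10))\n",
+ "\n",
+ "full_loader = DataLoader(toy_data, batch_size=10, shuffle=False)\n",
+ "mini_loader = DataLoader(toy_data, batch_size=4, shuffle=False)\n",
+ "\n",
+ "n_full = len(full_loader)  # number of batches per epoch\n",
+ "n_mini = len(mini_loader)\n",
+ "batch_shapes = [tuple(images.shape) for images, _ in mini_loader]\n",
+ "\n",
+ "print(n_full, n_mini, batch_shapes)\n",
+ "```\n",
+ "\n",
+ "Smaller batches mean more weight updates per epoch, at the cost of noisier gradient estimates."
+ ]
+ },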
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Train the model: Convolutional Neural Network"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2024-12-15T21:16:11.080821Z",
+ "iopub.status.busy": "2024-12-15T21:16:11.079938Z",
+ "iopub.status.idle": "2024-12-15T21:16:11.085835Z",
+ "shell.execute_reply": "2024-12-15T21:16:11.085097Z",
+ "shell.execute_reply.started": "2024-12-15T21:16:11.080794Z"
+ }
+ },
+ "source": [
+ "In `pytorch`, neural network models are defined as classes. This is slightly different from typical `tensorflow` usage, in which people build a `Sequential` model or use a pre-built `Model` class and add layers to that model.\n",
+ "\n",
+ "The other major difference between `pytorch` and `tensorflow` is the shape of the tensors and the inputs for each layer.\n",
+ "\n",
+ "In particular, in `pytorch` one has to explicitly match the output size of one layer to the input size of the next layer. Sometimes, this can be done with a calculation. But, more often, one must guess and check, for example by passing a sample through the layers and inspecting the shape of the result."
+ ]
+ },
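+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The calculation mentioned above can be sketched for the layer sizes used by the `ConvNet` class in Section 3.2.3 (two 3x3 convolutions with stride 1 and no padding, then 2x2 max pooling on a 28x28 image). The helper `conv_output_size` is a hypothetical name introduced only for this illustration. The second half checks the result by passing a blank sample through equivalent layers.\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "\n",
+ "def conv_output_size(size, kernel_size, stride=1, padding=0):\n",
+ "    # Output side length of a convolution along one dimension.\n",
+ "    return (size + 2 * padding - kernel_size) // stride + 1\n",
+ "\n",
+ "\n",
+ "size = 28                         # MNIST image side length\n",
+ "size = conv_output_size(size, 3)  # after conv1: 26\n",
+ "size = conv_output_size(size, 3)  # after conv2: 24\n",
+ "size = size // 2                  # after max_pool2d(x, 2): 12\n",
+ "n_features = 64 * size * size     # 64 output channels from conv2\n",
+ "\n",
+ "# Check by passing a blank sample through layers of the same shape.\n",
+ "x = torch.zeros(1, 1, 28, 28)\n",
+ "x = nn.Conv2d(1, 32, 3, 1)(x)\n",
+ "x = nn.Conv2d(32, 64, 3, 1)(x)\n",
+ "x = F.max_pool2d(x, 2)\n",
+ "n_flat = torch.flatten(x, 1).shape[1]\n",
+ "\n",
+ "print(n_features, n_flat)\n",
+ "```\n",
+ "\n",
+ "Both numbers equal 9216, which is the `in_features` value of the first linear layer in the model."
+ ]
+ },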
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3.1. Define model training hyperparameters"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Set the seed for the random initialization of the model weights."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:57.590774Z",
+ "iopub.status.busy": "2025-02-09T23:00:57.590515Z",
+ "iopub.status.idle": "2025-02-09T23:00:57.739275Z",
+ "shell.execute_reply": "2025-02-09T23:00:57.738219Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:57.590755Z"
+ },
+ "id": "Zsc_pqZtWftJ",
+ "outputId": "8238aa09-2afd-47ef-c93e-3910b342fe38"
+ },
+ "outputs": [],
+ "source": [
+ "seed = 1729\n",
+ "new = torch.manual_seed(seed)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Define the loss function. We choose cross-entropy, which is the standard loss function for classification of discrete labels."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:57.741224Z",
+ "iopub.status.busy": "2025-02-09T23:00:57.740823Z",
+ "iopub.status.idle": "2025-02-09T23:00:57.892170Z",
+ "shell.execute_reply": "2025-02-09T23:00:57.891536Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:57.741190Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "loss_fn = nn.CrossEntropyLoss()"
+ ]
+ },
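+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a small illustration of what this loss computes, the sketch below applies `nn.CrossEntropyLoss` to hand-written scores and compares it to the mean negative log of the softmax probability assigned to the true class. The names `toy_loss_fn`, `logits`, and `toy_labels` are hypothetical. Note that the loss expects raw scores (logits), not probabilities, and that by default it averages over the samples in the batch.\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "toy_loss_fn = nn.CrossEntropyLoss()\n",
+ "\n",
+ "# Two samples, three classes: raw scores and true class labels.\n",
+ "logits = torch.tensor([[2.0, 0.5, 0.1],\n",
+ "                       [0.1, 0.2, 3.0]])\n",
+ "toy_labels = torch.tensor([0, 2])\n",
+ "\n",
+ "toy_loss = toy_loss_fn(logits, toy_labels)\n",
+ "\n",
+ "# Same quantity by hand: mean of -log(softmax) at the true class.\n",
+ "log_probs = torch.log_softmax(logits, dim=1)\n",
+ "manual_loss = -log_probs[torch.arange(2), toy_labels].mean()\n",
+ "\n",
+ "# Equal scores for ten classes give a loss of ln(10), about 2.30.\n",
+ "uniform_loss = toy_loss_fn(torch.zeros(1, 10), torch.tensor([3]))\n",
+ "\n",
+ "print(toy_loss.item(), manual_loss.item(), uniform_loss.item())\n",
+ "```\n",
+ "\n",
+ "The value ln(10) is a useful reference: an untrained ten-class classifier should start with a per-sample loss near it."
+ ]
+ },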
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3.2. Define the neural network model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.1. Define model architecture elements"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Below is a list of important terms for elements of an architecture:\n",
+ "\n",
+ " * `activation function`: a function within a neuron that takes the weighted inputs from a previous layer and produces an output -- usually, this function is non-linear. Examples include \"sigmoid,\" \"softmax,\" and \"ReLU\" (rectified linear unit).\n",
+ " * `sigmoid`: an activation function that takes points from the Real line and maps them to the range (0,1).\n",
+ " * `softmax`: an activation function that takes a vector of real values and maps it to a vector of values in the range [0,1] that sum to 1. This is often used to obtain a 'probability score' for each class.\n",
+ " * `weights`: the multiplicative factors applied to the inputs of a neuron before the activation function.\n",
+ " * `biases`: the additive offsets applied to the weighted inputs of a neuron before the activation function.\n",
+ " * `layer`: one set of nodes/neurons that receive input data simultaneously.\n",
+ " * `linear (Dense) layer`: a layer in which every input is connected to every output through a weight and a bias. It usually follows the \"flattening\" of a higher-dimensional data vector, like an image -- as opposed to a convolutional layer, which applies a convolution operation.\n",
+ " * `convolutional layer`: a layer that applies a convolution operation to an input sample."
+ ]
+ },
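+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The output ranges of common activation functions can be checked numerically. This is a minimal sketch with hypothetical variable names; `tanh` is included only for comparison, because it is the common activation whose outputs span (-1,1).\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "\n",
+ "scores = torch.tensor([-5.0, 0.0, 5.0])\n",
+ "\n",
+ "out_sigmoid = torch.sigmoid(scores)         # each value in (0, 1)\n",
+ "out_tanh = torch.tanh(scores)               # each value in (-1, 1)\n",
+ "out_relu = torch.relu(scores)               # negative values become 0\n",
+ "out_softmax = torch.softmax(scores, dim=0)  # values in [0, 1], sum to 1\n",
+ "\n",
+ "print(out_sigmoid)\n",
+ "print(out_tanh)\n",
+ "print(out_relu)\n",
+ "print(out_softmax, out_softmax.sum())\n",
+ "```\n",
+ "\n",
+ "Unlike the other three, softmax acts on the whole vector at once, which is why its outputs can be read as scores across classes."
+ ]
+ },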
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.2. Implement model architecture elements in an object class"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Typically, there are two functions within this class.\n",
+ "\n",
+ "The constructor function (`__init__`) defines the available layers of the network. These layers require specific settings related to the data set shapes. \n",
+ "* `Conv2d`: defines a two-dimensional convolutional layer. Four inputs are considered here:\n",
+ " * `in_channels` (required): the number of input channels.\n",
+ " * `out_channels` (required): the number of output channels.\n",
+ " * `kernel_size` (required): the side length of the convolutional kernel (a single integer defines a square kernel).\n",
+ " * `stride` (optional): the stride of the convolution.\n",
+ "* `Dropout`: defines a dropout layer. The input is the fraction of neuron outputs that are randomly set to zero. Typically, dropout is active only during training.\n",
+ "* `Linear`: defines a linear layer. Two inputs are considered here:\n",
+ " * `in_features` (required): the size of the sample input to the layer.\n",
+ " * `out_features` (required): the size of the sample output from the layer.\n",
+ "\n",
+ "The function `forward` defines the order of operations during a forward pass of the model. During training, the `forward` function is applied to the input data to make predictions. After each round of predictions (one batch, which in this tutorial is also one epoch), the loss function measures the difference between the true labels and the predicted labels, and the optimizer uses the gradients of that loss to update the model weights. \n",
+ "\n",
+ "The `forward` function uses the layers defined in the constructor, as well as other layers that don't require inputs that depend on the data. Most of these layers are defined in the `torch.nn.functional` submodule, which contains predefined functions for layers that operate directly on the data and don't require an instance of that layer. These `functional` layers are:\n",
+ "* `relu`: applies the activation function, the `Rectified Linear Unit`. It requires one input, the sample from the previous layer.\n",
+ "* `max_pool2d`: the max pooling function is applied to the sample from the previous layer. It requires the input sample and the size of the kernel of the pooling.\n",
+ "* `flatten`: reshapes each input sample into a one-dimensional tensor. In the model class below, it is called as `torch.flatten`."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.3. Define the object class that represents the model. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:57.893245Z",
+ "iopub.status.busy": "2025-02-09T23:00:57.892802Z",
+ "iopub.status.idle": "2025-02-09T23:00:58.065150Z",
+ "shell.execute_reply": "2025-02-09T23:00:58.064028Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:57.893224Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "class ConvNet(nn.Module):\n",
+ " def __init__(self):\n",
+ " super(ConvNet, self).__init__()\n",
+ " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
+ " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
+ " self.dropout1 = nn.Dropout(0.25)\n",
+ " self.dropout2 = nn.Dropout(0.5)\n",
+ " self.fc1 = nn.Linear(9216, 128)\n",
+ " self.fc2 = nn.Linear(128, 10)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = self.conv1(x)\n",
+ " x = F.relu(x)\n",
+ " x = self.conv2(x)\n",
+ " x = F.relu(x)\n",
+ " x = F.max_pool2d(x, 2)\n",
+ " x = self.dropout1(x)\n",
+ " x = torch.flatten(x, 1)\n",
+ " x = self.fc1(x)\n",
+ " x = F.relu(x)\n",
+ " x = self.dropout2(x)\n",
+ " x = self.fc2(x)\n",
+ " output = x\n",
+ " return output"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Instantiate a neural network model object."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:58.066569Z",
+ "iopub.status.busy": "2025-02-09T23:00:58.066197Z",
+ "iopub.status.idle": "2025-02-09T23:00:58.205918Z",
+ "shell.execute_reply": "2025-02-09T23:00:58.204825Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:58.066533Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "model = ConvNet()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Define the device. This is important especially when a GPU is available. This is necessary because the model and the data get moved to that device.\n",
+ "\n",
+ "In our case, the device is a CPU because that's what is currently available on the RSP."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:58.207635Z",
+ "iopub.status.busy": "2025-02-09T23:00:58.207117Z",
+ "iopub.status.idle": "2025-02-09T23:00:58.346979Z",
+ "shell.execute_reply": "2025-02-09T23:00:58.345787Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:58.207575Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "device = torch.device('cuda') if torch.cuda.is_available()\\\n",
+ " else torch.device('cpu')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Put the model on the device where the computations will be performed.\n",
+ "\n",
+ "Because `model.to(device)` is the last statement in the cell, the notebook also displays a summary of the network architecture. Examine the shapes of the layers, which determine the number of parameters. Too few parameters may prevent the model from being flexible enough to model the data. Too many parameters could lead to overfitting of the model and a high computational cost."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:58.348374Z",
+ "iopub.status.busy": "2025-02-09T23:00:58.348049Z",
+ "iopub.status.idle": "2025-02-09T23:00:58.484854Z",
+ "shell.execute_reply": "2025-02-09T23:00:58.484286Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:58.348343Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model Summary:\n",
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "ConvNet(\n",
+ " (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n",
+ " (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n",
+ " (dropout1): Dropout(p=0.25, inplace=False)\n",
+ " (dropout2): Dropout(p=0.5, inplace=False)\n",
+ " (fc1): Linear(in_features=9216, out_features=128, bias=True)\n",
+ " (fc2): Linear(in_features=128, out_features=10, bias=True)\n",
+ ")"
+ ]
+ },
+ "execution_count": 34,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "print(\"Model Summary:\\n\")\n",
+ "model.to(device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.4. Define hyperparameter terms for the training schedule"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "* `learning rate`: A multiplicative factor defining the amount that the model weights will change in response to the size of the error (loss) in the model prediction for that epoch. The lower the learning rate, the smaller the change to the weights, and usually the longer it will take to train the model. However, if the learning rate is too high, the weights can change too quickly: then, the loss can fluctuate significantly and not decrease quickly.\n",
+ "* `momentum`: A multiplicative factor on the aggregate of previous gradients. This aggregate term is combined with the weight gradient term to define the total change in the value of the weights. The smaller the momentum, the lesser the influence of the aggregated gradients, and usually the longer it will take to train the model.\n",
+ "* `optimizer`: The method/algorithm used to update the network weights. Stochastic Gradient Descent (SGD) is one of the most commonly used methods.\n",
+ "* `epoch`: One loop of training the model. Each loop includes the entire data set (all the batches) once, and it includes at least one round of weight updates.\n",
+ "* `n_epochs`: The number of epochs to train the network."
+ ]
+ },
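+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The roles of the learning rate and momentum can be made concrete with a one-parameter example. The sketch below minimizes the hypothetical loss `w**2` for three steps, once with `optim.SGD` and once by hand, assuming the default `optim.SGD` settings (no dampening, no Nesterov momentum, no weight decay). All variable names are illustrative.\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "from torch import optim\n",
+ "\n",
+ "toy_lr, toy_momentum = 0.01, 0.9\n",
+ "\n",
+ "w = torch.tensor([1.0], requires_grad=True)\n",
+ "toy_optimizer = optim.SGD([w], lr=toy_lr, momentum=toy_momentum)\n",
+ "\n",
+ "w_manual, velocity = 1.0, 0.0\n",
+ "for _ in range(3):\n",
+ "    # Update performed by the optimizer.\n",
+ "    toy_optimizer.zero_grad()\n",
+ "    toy_loss = (w ** 2).sum()\n",
+ "    toy_loss.backward()\n",
+ "    toy_optimizer.step()\n",
+ "\n",
+ "    # The same update written out by hand.\n",
+ "    grad = 2.0 * w_manual                        # d(w^2)/dw\n",
+ "    velocity = toy_momentum * velocity + grad    # aggregate of gradients\n",
+ "    w_manual = w_manual - toy_lr * velocity      # step scaled by the rate\n",
+ "\n",
+ "print(w.item(), w_manual)\n",
+ "```\n",
+ "\n",
+ "The two values agree to floating-point precision, showing that momentum accumulates past gradients and the learning rate scales the resulting step."
+ ]
+ },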
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.5. Assign hyperparameters for the training schedule"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:58.485804Z",
+ "iopub.status.busy": "2025-02-09T23:00:58.485463Z",
+ "iopub.status.idle": "2025-02-09T23:00:58.610914Z",
+ "shell.execute_reply": "2025-02-09T23:00:58.609882Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:58.485785Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "learning_rate = 0.01\n",
+ "momentum = 0.9\n",
+ "optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)\n",
+ "n_epochs = 50"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3.3. Train the model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Set the model into a training context. This ensures that layers like \"batchnorm\" and \"dropout\" will be activated. In contrast, when the model is set to an evaluation context (later in this tutorial), those layers will be deactivated."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:58.612344Z",
+ "iopub.status.busy": "2025-02-09T23:00:58.611972Z",
+ "iopub.status.idle": "2025-02-09T23:00:58.756735Z",
+ "shell.execute_reply": "2025-02-09T23:00:58.755676Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:58.612300Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "model.train()"
+ ]
+ },
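+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The difference between the training and evaluation contexts can be seen on a single dropout layer. In this minimal sketch (names are hypothetical), dropout in training mode zeroes about half of the inputs and rescales the survivors by 1/(1-p), while in evaluation mode it passes the input through unchanged.\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "drop = nn.Dropout(0.5)\n",
+ "ones = torch.ones(1000)\n",
+ "\n",
+ "drop.train()           # training context: dropout is active\n",
+ "y_train = drop(ones)   # values are 0 (dropped) or 2 (kept, rescaled)\n",
+ "\n",
+ "drop.eval()            # evaluation context: dropout is inactive\n",
+ "y_eval = drop(ones)    # identical to the input\n",
+ "\n",
+ "print(y_train.unique(), y_eval.unique())\n",
+ "```\n",
+ "\n",
+ "Calling `train()` or `eval()` on a whole model switches every such layer inside it at once."
+ ]
+ },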
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Using a loop over epochs, optimize the network weight parameters.\n",
+ "\n",
+ "Use lists to track the loss values on the training data, the loss values on the testing data, and the accuracy values on the testing data. Note that the printed output labels the loss on the testing data as \"valid loss\". Define the \"history\" dictionary to hold those lists. We will visualize these later to study the fitting efficacy and generalization capacity of the network."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:58.758266Z",
+ "iopub.status.busy": "2025-02-09T23:00:58.757903Z",
+ "iopub.status.idle": "2025-02-09T23:13:23.377226Z",
+ "shell.execute_reply": "2025-02-09T23:13:23.376192Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:58.758217Z"
+ },
+ "id": "zB1OY3o8VWF_",
+ "outputId": "7bdd7e85-cef2-47af-f16b-5f8bbd66f67f"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch ( 0): accuracy (9.78 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 1): accuracy (10.26 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 2): accuracy (10.44 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 3): accuracy (11.30 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 4): accuracy (12.58 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 5): accuracy (14.20 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 6): accuracy (16.96 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 7): accuracy (18.42 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 8): accuracy (19.68 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 9): accuracy (21.68 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 10): accuracy (23.16 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 11): accuracy (24.86 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 12): accuracy (24.88 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 13): accuracy (26.24 %), train loss (0.0005), valid loss (0.0004) \n",
+ "Epoch ( 14): accuracy (28.46 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 15): accuracy (28.84 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 16): accuracy (29.86 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 17): accuracy (32.38 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 18): accuracy (33.24 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 19): accuracy (36.54 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 20): accuracy (38.86 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 21): accuracy (39.54 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 22): accuracy (42.22 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 23): accuracy (44.86 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 24): accuracy (48.12 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 25): accuracy (49.74 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 26): accuracy (53.24 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 27): accuracy (53.98 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 28): accuracy (56.04 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 29): accuracy (56.62 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 30): accuracy (59.60 %), train loss (0.0004), valid loss (0.0003) \n",
+ "Epoch ( 31): accuracy (60.46 %), train loss (0.0003), valid loss (0.0003) \n",
+ "Epoch ( 32): accuracy (59.72 %), train loss (0.0003), valid loss (0.0003) \n",
+ "Epoch ( 33): accuracy (61.94 %), train loss (0.0003), valid loss (0.0003) \n",
+ "Epoch ( 34): accuracy (63.48 %), train loss (0.0003), valid loss (0.0003) \n",
+ "Epoch ( 35): accuracy (64.74 %), train loss (0.0003), valid loss (0.0002) \n",
+ "Epoch ( 36): accuracy (66.14 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 37): accuracy (67.70 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 38): accuracy (68.60 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 39): accuracy (69.16 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 40): accuracy (71.06 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 41): accuracy (72.12 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 42): accuracy (72.66 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 43): accuracy (74.16 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 44): accuracy (75.68 %), train loss (0.0002), valid loss (0.0001) \n",
+ "Epoch ( 45): accuracy (77.04 %), train loss (0.0002), valid loss (0.0001) \n",
+ "Epoch ( 46): accuracy (77.06 %), train loss (0.0001), valid loss (0.0001) \n",
+ "Epoch ( 47): accuracy (78.44 %), train loss (0.0001), valid loss (0.0001) \n",
+ "Epoch ( 48): accuracy (79.26 %), train loss (0.0001), valid loss (0.0001) \n",
+ "Epoch ( 49): accuracy (79.56 %), train loss (0.0001), valid loss (0.0001) \n",
+ "Total training time: 744.4417\n"
+ ]
+ }
+ ],
+ "source": [
+ "time_start = time.time()\n",
+ "\n",
+ "loss_train_list = []\n",
+ "loss_test_list = []\n",
+ "accuracy_test_list = []\n",
+ "\n",
+ "for epoch in np.arange(n_epochs):\n",
+ "\n",
+ " loss_train = 0\n",
+ " for inputs, labels in trainloader:\n",
+ " inputs = inputs.to(device)\n",
+ " labels = labels.to(device)\n",
+ " y_pred = model(inputs)\n",
+ " loss = loss_fn(y_pred, labels)\n",
+ " optimizer.zero_grad()\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " loss_train += loss.item()\n",
+ "\n",
+ " loss_train /= batch_size\n",
+ "\n",
+ " accuracy_test = 0\n",
+ " loss_test = 0\n",
+ " count_test = 0\n",
+ " for inputs, labels in testloader:\n",
+ " inputs = inputs.to(device)\n",
+ " labels = labels.to(device)\n",
+ " y_pred = model(inputs)\n",
+ " loss = loss_fn(y_pred, labels)\n",
+ " accuracy_test += (torch.argmax(y_pred, 1) == labels).float().sum()\n",
+ " loss_test += loss.item()\n",
+ " count_test += len(labels)\n",
+ "\n",
+ " accuracy_test /= count_test\n",
+ " loss_test /= batch_size\n",
+ "\n",
+ " loss_train_list.append(loss_train)\n",
+ " loss_test_list.append(loss_test)\n",
+ " accuracy_test_list.append(accuracy_test)\n",
+ "\n",
+ " output = f\"Epoch ({epoch:3d}): accuracy ({accuracy_test*100:.2f} %),\\\n",
+ " train loss ({loss_train:.4f}), valid loss ({loss_test:.4f}) \"\n",
+ " print(output)\n",
+ "\n",
+ "time_end = time.time()\n",
+ "\n",
+ "time_difference = time_end - time_start\n",
+ "\n",
+ "history = {\"loss\": loss_train_list,\n",
+ " \"val_loss\": loss_test_list,\n",
+ " \"accuracy_test\": accuracy_test_list}\n",
+ "\n",
+ "print(f\"Total training time: {time_difference:2.4f}\")"
+ ]
+ },
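+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Two details of the loop above are worth noting. First, `nn.CrossEntropyLoss` already averages over the samples in a batch by default, so dividing by `batch_size` again rescales the loss; this is consistent with the printed starting value of about 0.0005, which is close to ln(10)/5000. Second, the evaluation pass runs while the model is in the training context. The sketch below shows one alternative way to write an evaluation pass; it is not the method used in this tutorial, and `evaluate`, `toy_model`, and `toy_loader` are hypothetical names.\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.utils.data import DataLoader, TensorDataset\n",
+ "\n",
+ "\n",
+ "def evaluate(model, loader, loss_fn, device):\n",
+ "    # Return the per-sample mean loss and the accuracy over a loader.\n",
+ "    was_training = model.training\n",
+ "    model.eval()                       # deactivate dropout\n",
+ "    total_loss, n_correct, n_samples = 0.0, 0, 0\n",
+ "    with torch.no_grad():              # no gradients needed here\n",
+ "        for inputs, labels in loader:\n",
+ "            inputs, labels = inputs.to(device), labels.to(device)\n",
+ "            logits = model(inputs)\n",
+ "            # Undo the batch average so batches of any size combine.\n",
+ "            total_loss += loss_fn(logits, labels).item() * len(labels)\n",
+ "            n_correct += (logits.argmax(dim=1) == labels).sum().item()\n",
+ "            n_samples += len(labels)\n",
+ "    model.train(was_training)          # restore the previous context\n",
+ "    return total_loss / n_samples, n_correct / n_samples\n",
+ "\n",
+ "\n",
+ "# Usage on a hypothetical three-class toy problem.\n",
+ "toy_model = nn.Sequential(nn.Flatten(), nn.Linear(4, 3))\n",
+ "toy_loader = DataLoader(TensorDataset(torch.randn(10, 1, 2, 2),\n",
+ "                                      torch.randint(0, 3, (10,))),\n",
+ "                        batch_size=4)\n",
+ "mean_loss, accuracy = evaluate(toy_model, toy_loader,\n",
+ "                               nn.CrossEntropyLoss(), torch.device('cpu'))\n",
+ "print(mean_loss, accuracy)\n",
+ "```\n",
+ "\n",
+ "Returning plain Python floats also makes the tracked values easy to store in lists and plot later."
+ ]
+ },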
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
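+ "Save the model to file: \"pt\" is the common suffix used for pytorch model files. Because `model.state_dict()` is passed to `torch.save`, this saves the weights of the model but not its architecture; the `ConvNet` class is still needed to load the weights back in."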
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:23.378889Z",
+ "iopub.status.busy": "2025-02-09T23:13:23.378481Z",
+ "iopub.status.idle": "2025-02-09T23:13:23.551547Z",
+ "shell.execute_reply": "2025-02-09T23:13:23.550467Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:23.378864Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/home/bnord/dp02_16a_temp/Data/Models/Model_Run000_20250209_231323.pt\n"
+ ]
+ }
+ ],
+ "source": [
+ "file_prefix = path_dict['file_model_prefix'] + \"_\" + path_dict['run_label']\n",
+ "file_name_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_model'],\n",
+ " file_suffix=path_dict['file_model_suffix'],\n",
+ " useuid=True,\n",
+ " verbose=True)\n",
+ "\n",
+ "torch.save(model.state_dict(), file_name_final)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
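+ "Load the saved parameters from the \"pt\" file into the existing model instance. The model class must already be defined and instantiated, because the file does not contain the architecture."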
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:23.553120Z",
+ "iopub.status.busy": "2025-02-09T23:13:23.552767Z",
+ "iopub.status.idle": "2025-02-09T23:13:23.711074Z",
+ "shell.execute_reply": "2025-02-09T23:13:23.709943Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:23.553086Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "model.load_state_dict(torch.load(file_name_final, weights_only=True))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-01-14T23:05:09.388295Z",
+ "iopub.status.busy": "2025-01-14T23:05:09.387717Z",
+ "iopub.status.idle": "2025-01-14T23:05:09.481703Z",
+ "shell.execute_reply": "2025-01-14T23:05:09.481106Z",
+ "shell.execute_reply.started": "2025-01-14T23:05:09.388255Z"
+ }
+ },
+ "source": [
+ "Set the model to evaluation mode so that \"dropout\" layers are deactivated and \"batchnorm\" layers use their stored running statistics instead of per-batch statistics. This is necessary for consistent inference."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:23.712652Z",
+ "iopub.status.busy": "2025-02-09T23:13:23.712268Z",
+ "iopub.status.idle": "2025-02-09T23:13:24.006950Z",
+ "shell.execute_reply": "2025-02-09T23:13:24.005916Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:23.712613Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "model.eval()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Diagnosing the Results of Model Training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.1. Key terms for diagnostics and model evaluation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use the following diagnostics to assess the status of the network optimization and efficacy. The [scikit-learn page on metrics and scoring](https://scikit-learn.org/stable/modules/model_evaluation.html) provides a good in-depth reference for these terms.\n",
+ "\n",
+ "Model Predictions:\n",
+ " * `Classification threshold`: A user-chosen value in the range $[0,1]$; a probability score above this value is treated as a positive classification. \n",
+ " * `Probability score`: The per-class output of the classifier neural network. A `softmax` activation function in the last layer of the NN typically maps the outputs to the range $[0,1]$, with the scores for each input summing to 1.\n",
+ " * `Classification score`: The predicted class label, i.e., the class that received the highest probability score.\n",
+ "\n",
+ "Metrics:\n",
+ " * `Loss`: A function of the difference between the true labels and predicted labels.\n",
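+ " * `Accuracy`: The fraction of all predictions that are correct. A rough indicator of model training progress/convergence for balanced datasets. For model performance, use only in combination with other metrics, and prefer another metric when the training dataset is unbalanced.\n",
+ " * `True Positive Rate (TPR; \"Recall\")`: $TP/(TP+FN)$, the fraction of true members of a class that are recovered (TP, FP, TN, and FN are defined in Section 4.3.4.1). Use when false negatives are more expensive than false positives.\n",
+ " * `False Positive Rate (FPR)`: $FP/(FP+TN)$, the fraction of non-members that are incorrectly assigned to the class. Use when false positives are more expensive than false negatives.\n",
+ " * `Precision`: $TP/(TP+FP)$, the fraction of positive predictions that are correct. Use when positive predictions need to be accurate.\n",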
+ "\n",
+ "The `Generalization Error` (GE) is the difference in loss when the model is applied to training data versus when applied to validation and test data.\n",
+ "\n",
+ "The `Confusion Matrix` is a visual representation of the classification accuracy. In the scikit-learn convention, each row corresponds to a true class and each column to a predicted class. Values along the diagonal indicate correct predictions (true positives for each class). Every off-diagonal value is a misclassification: for a given class, the off-diagonal values in its row are false negatives and the off-diagonal values in its column are false positives. The optimal scenario is one in which the off-diagonal values are all zero.\n",
+ "\n",
+ "The `Receiver Operating Characteristic (ROC) Curve` compares the true positive rate (y-axis) with the false positive rate (x-axis): for a given false positive rate, it shows the fraction of true positives that are recovered. Each point on the curve corresponds to a distinct choice of the classification threshold for the probability score in $[0,1]$: the choice of classification threshold determines which objects are classified as positive. The optimal scenario is one in which the ROC curve is constant at a true positive rate $=1$. If the curve lies along the diagonal (lower left to upper right), the model performance is equivalent to random guessing.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.2. Predict classifications with the trained model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Predict classification probabilities on the training, validation, and test sets. Produce both the probabilities of each digit and the top choice for each prediction."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:24.013950Z",
+ "iopub.status.busy": "2025-02-09T23:13:24.012874Z",
+ "iopub.status.idle": "2025-02-09T23:13:36.594110Z",
+ "shell.execute_reply": "2025-02-09T23:13:36.593536Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:24.013906Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "test data set ... Accuracy: 86.4%, Avg loss: 0.467137\n",
+ "validation data set ... Accuracy: 86.2%, Avg loss: 0.487235\n",
+ "training data set ... Accuracy: 86.0%, Avg loss: 0.487080\n"
+ ]
+ }
+ ],
+ "source": [
+ "y_prob_tes, y_choice_tes, y_true_tes, x_tes = predict(testloader, model,\n",
+ " \"test\")\n",
+ "y_prob_val, y_choice_val, y_true_val, x_val = predict(validloader, model,\n",
+ " \"validation\")\n",
+ "y_prob_tra, y_choice_tra, y_true_tra, x_tra = predict(trainloader, model,\n",
+ " \"training\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Print the shapes and verify that the shape of `y_prob_tes` matches the length of the input data `x_tes` and the number of classes (as in Section 2.4)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:36.595037Z",
+ "iopub.status.busy": "2025-02-09T23:13:36.594822Z",
+ "iopub.status.idle": "2025-02-09T23:13:36.743092Z",
+ "shell.execute_reply": "2025-02-09T23:13:36.742096Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:36.595018Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The input data has the shape (5000, 28, 28): there are 5000 images, and each image has 28 pixels on a side.\n",
+ "The predicted probability score array has the shape (5000, 10): there are 5000 predictions, with 10 probability scores predicted for each input image.\n",
+ "The predicted classes array has the shape (5000,): there is one top choice (highest probability score) for each prediction.\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"The input data has the shape {np.shape(x_tes)}: there are {subset_size}\\\n",
+ " images, and each image has 28 pixels on a side.\")\n",
+ "print(f\"The predicted probability score array has the shape\\\n",
+ " {np.shape(y_prob_tes)}: there are {subset_size} predictions,\\\n",
+ " with 10 probability scores predicted for each input image.\")\n",
+ "print(f\"The predicted classes array has the shape {np.shape(y_choice_tes)}:\\\n",
+ " there is one top choice (highest probability score) for each prediction.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use the `plotPredictionHistogram` function to plot histograms of prediction distributions by class."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:36.744923Z",
+ "iopub.status.busy": "2025-02-09T23:13:36.744545Z",
+ "iopub.status.idle": "2025-02-09T23:13:37.112883Z",
+ "shell.execute_reply": "2025-02-09T23:13:37.111880Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:36.744887Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Histograms_top_choice\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotPredictionHistogram(y_choice_tra,\n",
+ " y_prediction_b=y_choice_val,\n",
+ " y_prediction_c=y_choice_tes,\n",
+ " label_a=\"Training Set\",\n",
+ " label_b=\"Validation Set\",\n",
+ " label_c=\"Testing Set\",\n",
+ " figsize=(12, 5),\n",
+ " alpha=0.5,\n",
+ " xlabel_plot=\"Predicted class label\",\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 4: Histograms of the number of images for which the top-choice class was each number 0 through 9. Each histogram is for a different data set used during model training --- training data, validation data, and test data. Note that these are overlapping histograms, not stacked. Please compare to Figure 3, which shows the distributions of true class labels for each data set. Consider which classes are represented differently between the true labels and the predicted labels."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:37.114356Z",
+ "iopub.status.busy": "2025-02-09T23:13:37.114005Z",
+ "iopub.status.idle": "2025-02-09T23:13:38.178779Z",
+ "shell.execute_reply": "2025-02-09T23:13:38.178135Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:37.114322Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Histograms_class_probabilities\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotPredictionHistogram(y_prob_tra,\n",
+ " y_prediction_b=y_prob_val,\n",
+ " y_prediction_c=y_prob_tes,\n",
+ " title_a='Training Set',\n",
+ " title_b='Validation Set',\n",
+ " title_c='Testing Set',\n",
+ " figsize=(15, 4),\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 5: Histograms of the number of images (y-axis) that had a probability (x-axis) of being each class 0 through 9 (light to dark shades)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In both Figures 3 and 4, the histograms show very similar shapes across the classification categories.\n",
+ "This is a good sign because it indicates the model is not heavily biased toward a particular class."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.3. Generalization Error\n",
+ "\n",
+ "The primary task in optimizing a network is to minimize the Generalization Error. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.3.1. Loss History: History of Loss and Accuracy during Training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot the loss history for the validation and training sets. We reserve the test set for a 'blind' analysis."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:38.179887Z",
+ "iopub.status.busy": "2025-02-09T23:13:38.179456Z",
+ "iopub.status.idle": "2025-02-09T23:13:38.694470Z",
+ "shell.execute_reply": "2025-02-09T23:13:38.693818Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:38.179865Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"LossHistory\"\\\n",
+ " + \"_\"\\\n",
+ " + path_dict['run_label']\n",
+ "\n",
+ "plotLossHistory(history,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 6: In the top panel, the loss history as a function of epoch for the training and validation sets decreases with time, as it should as the model improves. In the bottom panel, the loss residual (validation - training) shows a dip at epoch 2, indicating a temporary divergence between the training and validation set classifications that was rectified in later epochs."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2024-10-26T18:50:41.751869Z",
+ "iopub.status.busy": "2024-10-26T18:50:41.751239Z",
+ "iopub.status.idle": "2024-10-26T18:50:41.757103Z",
+ "shell.execute_reply": "2024-10-26T18:50:41.756503Z",
+ "shell.execute_reply.started": "2024-10-26T18:50:41.751843Z"
+ }
+ },
+ "source": [
+ "### 4.3.2. Confusion Matrix: Bias in Trained Model?"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compute and plot the confusion matrices for the training, validation, and test samples (left, middle, right)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:38.695466Z",
+ "iopub.status.busy": "2025-02-09T23:13:38.695205Z",
+ "iopub.status.idle": "2025-02-09T23:13:40.428933Z",
+ "shell.execute_reply": "2025-02-09T23:13:40.427883Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:38.695446Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')\n",
+ "figsize = (15, 3)\n",
+ "linewidths = 0.01\n",
+ "linecolor = 'white'\n",
+ "ylabel = \"True Label\"\n",
+ "xlabel = \"Predicted Label\"\n",
+ "\n",
+ "cm_tra = confusion_matrix(y_true_tra, y_choice_tra)\n",
+ "cm_val = confusion_matrix(y_true_val, y_choice_val)\n",
+ "cm_tes = confusion_matrix(y_true_tes, y_choice_tes)\n",
+ "\n",
+ "df_cm_tra = pd.DataFrame(cm_tra / np.sum(cm_tra, axis=1)[:, None],\n",
+ " index=[i for i in classes],\n",
+ " columns=[i for i in classes])\n",
+ "df_cm_val = pd.DataFrame(cm_val / np.sum(cm_val, axis=1)[:, None],\n",
+ " index=[i for i in classes],\n",
+ " columns=[i for i in classes])\n",
+ "df_cm_tes = pd.DataFrame(cm_tes / np.sum(cm_tes, axis=1)[:, None],\n",
+ " index=[i for i in classes],\n",
+ " columns=[i for i in classes])\n",
+ "\n",
+ "fig, (axa, axb, axc) = plt.subplots(1, 3, figsize=figsize)\n",
+ "fig.subplots_adjust(wspace=0.5)\n",
+ "\n",
+ "ax1 = sns.heatmap(df_cm_tra, annot=False, linewidths=linewidths,\n",
+ " linecolor=linecolor, square=True, ax=axa)\n",
+ "_ = ax1.set(xlabel=xlabel, ylabel=ylabel, title=\"Training Data\")\n",
+ "\n",
+ "ax2 = sns.heatmap(df_cm_val, annot=False, linewidths=linewidths,\n",
+ " linecolor=linecolor, square=True, ax=axb)\n",
+ "_ = ax2.set(xlabel=xlabel, ylabel=ylabel, title=\"Validation Data\")\n",
+ "\n",
+ "ax3 = sns.heatmap(df_cm_tes, annot=False, linewidths=linewidths,\n",
+ " linecolor=linecolor, square=True, ax=axc)\n",
+ "_ = ax3.set(xlabel=xlabel, ylabel=ylabel, title=\"Test Data\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 7: Confusion matrices for the training, validation, and test data sets (left to right). Each cell shows the fraction (see the color bars) of the objects with a given true label that were predicted to have a given label. The diagonal represents images that were correctly classified. Every off-diagonal cell is a misclassification: it counts as a false negative for the true class and as a false positive for the predicted class. Consider some examples. First, the class \"0\" is almost always predicted to be \"0\": the top-most, left-most cell has a lighter color, and the other cells for true label \"0\" are completely dark. Second, consider the cell that represents a prediction \"4\" when the true label is \"9\": that cell is not completely dark, because a \"4\" has a similar morphology or shape to a \"9\"."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vvAqrZwjVYBt"
+ },
+ "source": [
+ "### 4.3.4. Investigating predictions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 4.3.4.1. Definition of classification metrics"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-01-26T21:19:35.897282Z",
+ "iopub.status.busy": "2025-01-26T21:19:35.896488Z",
+ "iopub.status.idle": "2025-01-26T21:19:36.705672Z",
+ "shell.execute_reply": "2025-01-26T21:19:36.705022Z",
+ "shell.execute_reply.started": "2025-01-26T21:19:35.897242Z"
+ }
+ },
+ "source": [
+ "Consider the example in which the class of interest is \"2\". Then, we define the following metrics.\n",
+ "\n",
+ "* `True Positive (TP)`: a \"2\" classified as \"2\";\n",
+ "* `False Positive (FP)`: another digit classified as \"2\";\n",
+ "* `True Negative (TN)`: another digit classified as something other than \"2\"; and\n",
+ "* `False Negative (FN)`: a \"2\" classified as something other than \"2\"."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 4.3.4.2. Explore the classification of the training data for an example class value"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Investigate the case in which the true digit label is \"2\"."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:40.430615Z",
+ "iopub.status.busy": "2025-02-09T23:13:40.430250Z",
+ "iopub.status.idle": "2025-02-09T23:13:40.575406Z",
+ "shell.execute_reply": "2025-02-09T23:13:40.574407Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:40.430564Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "class_value = 2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Using the predictions for the training set, obtain indices for the TPs, FPs, TNs, and FNs with respect to that class value. Create subsets of the data according to those indices."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:40.576471Z",
+ "iopub.status.busy": "2025-02-09T23:13:40.576232Z",
+ "iopub.status.idle": "2025-02-09T23:13:40.727095Z",
+ "shell.execute_reply": "2025-02-09T23:13:40.725923Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:40.576451Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "TP count: 439\n",
+ "FP count: 58\n",
+ "TN count: 4435\n",
+ "FN count: 68\n"
+ ]
+ }
+ ],
+ "source": [
+ "ind_class_tp_tra = np.where((y_true_tra == class_value)\n",
+ " & (y_choice_tra == class_value))[0]\n",
+ "\n",
+ "ind_class_fp_tra = np.where((y_true_tra != class_value)\n",
+ " & (y_choice_tra == class_value))[0]\n",
+ "\n",
+ "ind_class_tn_tra = np.where((y_true_tra != class_value)\n",
+ " & (y_choice_tra != class_value))[0]\n",
+ "\n",
+ "ind_class_fn_tra = np.where((y_true_tra == class_value)\n",
+ " & (y_choice_tra != class_value))[0]\n",
+ "\n",
+ "x_tra_tp = x_tra[ind_class_tp_tra]\n",
+ "y_true_tra_tp = y_true_tra[ind_class_tp_tra]\n",
+ "y_choice_tra_tp = y_choice_tra[ind_class_tp_tra]\n",
+ "\n",
+ "x_tra_fp = x_tra[ind_class_fp_tra]\n",
+ "y_true_tra_fp = y_true_tra[ind_class_fp_tra]\n",
+ "y_choice_tra_fp = y_choice_tra[ind_class_fp_tra]\n",
+ "\n",
+ "x_tra_tn = x_tra[ind_class_tn_tra]\n",
+ "y_true_tra_tn = y_true_tra[ind_class_tn_tra]\n",
+ "y_choice_tra_tn = y_choice_tra[ind_class_tn_tra]\n",
+ "\n",
+ "x_tra_fn = x_tra[ind_class_fn_tra]\n",
+ "y_true_tra_fn = y_true_tra[ind_class_fn_tra]\n",
+ "y_choice_tra_fn = y_choice_tra[ind_class_fn_tra]\n",
+ "\n",
+ "n_tp = len(ind_class_tp_tra)\n",
+ "n_fp = len(ind_class_fp_tra)\n",
+ "n_tn = len(ind_class_tn_tra)\n",
+ "n_fn = len(ind_class_fn_tra)\n",
+ "\n",
+ "print(f\"TP count: {n_tp:4d}\")\n",
+ "print(f\"FP count: {n_fp:4d}\")\n",
+ "print(f\"TN count: {n_tn:4d}\")\n",
+ "print(f\"FN count: {n_fn:4d}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:40.728669Z",
+ "iopub.status.busy": "2025-02-09T23:13:40.728301Z",
+ "iopub.status.idle": "2025-02-09T23:13:45.386481Z",
+ "shell.execute_reply": "2025-02-09T23:13:45.385906Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:40.728632Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "if n_tp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_TruePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_tp,\n",
+ " y_true_tra_tp,\n",
+ " y_choice_tra_tp,\n",
+ " title_main=\"True Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_FalsePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_fp,\n",
+ " y_true_tra_fp,\n",
+ " y_choice_tra_fp,\n",
+ " title_main=\"False Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_tn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_TrueNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_tn,\n",
+ " y_true_tra_tn,\n",
+ " y_choice_tra_tn,\n",
+ " title_main=\"True Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_FalseNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_fn,\n",
+ " y_true_tra_fn,\n",
+ " y_choice_tra_fn,\n",
+ " title_main=\"False Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 8: Four panels of 10 images each, representing true positives (top), false positives (second), true negatives (third), and false negatives (bottom), for classification category 2.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot histograms of the image pixel values for true positives, false positives, true negatives, and false negatives."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:45.387543Z",
+ "iopub.status.busy": "2025-02-09T23:13:45.387337Z",
+ "iopub.status.idle": "2025-02-09T23:13:50.807474Z",
+ "shell.execute_reply": "2025-02-09T23:13:50.806324Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:45.387525Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAuQAAAGMCAYAAACS67fPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABVaklEQVR4nO3deXQUZdr+8avJTggEkrwJkSVAEFBWQZY4bMqqKIqorMMSEEVF1IC7CbK5DswgQWUwAYVhEeFVZBEFAg6iIDCy+FNUEBACsiZsCUme3x8M/dKEJUsl1Z18P+fUOfTT1dV3lZddd1dXVRxDph41AgAAAGCLMnYXAAAAAJRmNOQAAACAjbwvH5g+vJIddQAAAAClxtDEY85/c4QcAAAAsBENOQAAAGAjGnIAAADARjTkAAAAgI1oyAEAAAAb0ZADAAAANqIhB1BiOByOPE1r1qyxtc527drJ4XCoS5cuuZ7bs2ePHA6H3nrrLRsqy23nzp1KSEjQnj17cj03cOBARUVFFXtNAFDS5LoPOQB4qm+++cbl8dixY7V69WqtWrXKZfymm24qzrKuasWKFVq1apVuv/12u0u5qp07d2rMmDFq165drub75Zdf1pNPPmlPYQBQgtCQAygxWrZs6fI4LCxMZcqUyTV+uTNnzqhs2bJFWVouN954o7KysjR69Ght3LhRDoejWN/fCrVq1bK7BAAoEThlBUCp0q5dO9WvX19r165VTEyMypYtq8GDB0u6cMpLQkJCrtdERUVp4MCBLmOpqakaNmyYqlSpIl9fX9WoUUNjxoxRVlZWnurw8fHR+PHj9f3332vevHnXnT+v77d//3717NlTQUFBCg4OVt++fZ0Nf3JysnO+TZs2qVevXoqKilJAQICioqLUu3dv/f777855kpOT9cADD0iS2rdv7zzl5+JyLj9lpUmTJmrdunWu2rOzs3XDDTeoR48ezrHMzEyNGzdOdevWlZ+fn8LCwjRo0CD9+eefLq9dtWqV2rVrp5CQEAUEBKhatWq6//77debMmetuMwDwFBwhB1DqHDx4UP369dPo0aM1YcIElSmTv2MTqampat68ucqUKaNXXnlFtWrV0jfffKNx48Zpz549SkpKytNyHnroIb311lt66aWXdP/998vHx6dQ73f69Gm1b99ex44d0+uvv67o6GgtX75cDz30UK5l7tmzR3Xq1FGvXr1UqVIlHTx4UNOmTdOtt96qnTt3KjQ0VHfddZcmTJigF154QVOnTtUtt9wi6epHxgcNGqQnn3xSu3btUu3atZ3jX3zxhQ4cOKBBgwZJknJyctS9e3etW7dOo0ePVkxMjH7//XfFx8erXbt22rRpkwICArRnzx7dddddat26tT744AMFBwfrjz/+0PLly5WZmVnsv2oAQFGhIQdQ6hw7dkwLFiwo8LnbCQkJOn78uHbs2KFq1apJku644w4FBAQoLi5Oo0aNytN56g6HQ6+//ro6dOig9957T48//nih3m/mzJn65ZdftGzZMucFo506ddKZM2f03nvvuSyzZ8+e6tmzp/Nxdna2unXrpvDwcM2ZM0cjRoxQWFiYs7G+6aabrnvqT9++fTVq1CglJydr/PjxzvHk5GSFh4era9eukqT58+dr+fLlWrhwoctR80aNGunWW29VcnKyHn30UX3//fc6d+6c3nzzTTVq1Mg5X58+fa67bQHAk3DKCoBSp2LFioW6kHLJkiVq3769IiMjlZWV5ZwuNpwpKSl5XtYdd9yhTp066dVXX1V6enqh3i8lJUVBQUG57t7Su3fvXMs8deqUnn32WUVHR8vb21ve3t4qV66cTp8+rR9//DHP9V8qJCREd999t2bOnKmcnBxJ0vHjx/W///u/+utf/ypvb2/n+gQHB+vuu+92WZ/GjRsrIiLCeRecxo0by9fXVw8//LBmzpyp3377rUB1AYC7oyEHUOpUrly5UK8/dOiQPvvsM/n4+LhM
N998syTpyJEj+Vre66+/riNHjlz1Vod5fb+jR48qPDw81+uvNNanTx+98847GjJkiFasWKHvvvtOGzduVFhYmM6ePZuv+i81ePBg/fHHH1q5cqUk6V//+pcyMjJczsE/dOiQTpw4IV9f31zrlJqa6lyfWrVq6csvv9T//M//6LHHHlOtWrVUq1Yt/f3vfy9wfQDgjjhlBUCpc7U7mvj5+SkjIyPX+NGjR10eh4aGqmHDhi6nZVwqMjIyX/U0btxYvXv31t/+9jfdeeeduZ7P6/uFhITou+++y/V8amqqy+OTJ09qyZIlio+P13PPPeccz8jI0LFjx/JV++U6d+6syMhIJSUlqXPnzkpKSlKLFi1cTuEJDQ1VSEiIli9ffsVlBAUFOf/dunVrtW7dWtnZ2dq0aZOmTJmikSNHKjw8XL169SpUrQDgLmjIAeC/oqKi9MMPP7iMrVq1SqdOnXIZ69atm5YuXapatWqpYsWKlrz3uHHj9PHHH2vMmDG5nsvr+7Vt21bz58/XsmXLnKezSNLcuXNd5nM4HDLGyM/Pz2X8n//8p7Kzs13GLs6T16PmXl5e6t+/vyZPnqx169Zp06ZNuc5f79atm+bOnavs7Gy1aNEiz8tt0aKF6tatq9mzZ2vz5s005ABKDBpyAPiv/v376+WXX9Yrr7yitm3baufOnXrnnXdUoUIFl/leffVVrVy5UjExMRoxYoTq1Kmjc+fOac+ePVq6dKneffddValSJV/vXaNGDT366KNXPB0jr+83YMAATZo0Sf369dO4ceMUHR2tZcuWacWKFZLkvJtM+fLl1aZNG7355psKDQ1VVFSUUlJSNGPGDAUHB7u8d/369SVJ77//voKCguTv768aNWooJCTkqusyePBgvf766+rTp48CAgJy3eWlV69emj17tu688049+eSTat68uXx8fLR//36tXr1a3bt313333ad3331Xq1at0l133aVq1arp3Llz+uCDDyRJHTp0yNf2BQB3xjnkAPBfo0aNct4l5O6779bChQs1f/78XE1q5cqVtWnTJnXq1ElvvvmmunTpov79++uDDz5Q48aNC3zU/KWXXlL58uVzjef1/QIDA5337R49erTuv/9+7d27V4mJiZLksh5z5sxR+/btNXr0aPXo0UObNm3SypUrc335qFGjhiZPnqz//Oc/ateunW699VZ99tln11yPG2+8UTExMdq/f7969OiRa5leXl769NNP9cILL+iTTz7Rfffdp3vvvVevvfaa/P391aBBA0kXTuXJyspSfHy8unbtqv79++vPP//Up59+qk6dOuV7+wKAu3IMmXrUXDowfXglu2oBABSBCRMm6KWXXtLevXvzfeQeAFA0hib+3zU7nLICACXIO++8I0mqW7euzp8/r1WrVukf//iH+vXrRzMOAG6KhhwASpCyZctq0qRJ2rNnjzIyMlStWjU9++yzeumll+wuDQBwFTTkAFCCDB48WIMHD7a7DABAPnBRJwAAAGAjGnIAAADARjTkAAAAgI1oyAEAAAAb0ZADAAAANqIhBwAAAGxEQw4AAADYiIYcAAAAsBENOQAAAGAjGnIAAADARjTkAAAAgI1oyAEAAAAb0ZADAAAANqIhBwAAAGxEQw4AAADYiIYcAAAAsBENOQAAAGAjGnIAAADARjTkAAAAgI1oyAEAAAAb0ZADAAAANqIh/y+Hw5Gnac2aNXaXelVHjhzRk08+qaioKPn5+Sk8PFxdu3bVsWPH7C6tVPH0LKWlpenFF1/UjTfeqLJly+qGG27QAw88oB07dthdWqnj6VmSpLlz56px48by9/dXZGSkRo4cqVOnTtldVqlDlmAVT8/SrFmz1KtXL9WpU0dlypRRVFSU3SVJkrztLsBdfPPNNy6Px44dq9WrV2vVqlUu4zfddFNxlpVnBw4cUOvWreXt7a2XX35ZtWvX1pEjR7R69WplZmbaXV6p4ulZuvvuu7Vp0yYlJCSoWbNm2r9/
v1599VW1atVK27ZtU/Xq1e0usdTw9CzNnj1b/fr105AhQzRp0iT9/PPPevbZZ7Vz50598cUXdpdXqpAlWMXTs/Thhx8qNTVVzZs3V05Ojs6fP293SZJoyJ1atmzp8jgsLExlypTJNX65M2fOqGzZskVZWp4MHz5cGRkZ2rRpkypWrOgc79Gjh41VlU6enKVffvlFa9eu1UsvvaRRo0Y5x6OjoxUTE6NPPvlETz31lI0Vli6enKXs7GyNGjVKnTp10vTp0yVJ7du3V1BQkPr27atly5apa9euttZYmpAlWMWTsyRJK1asUJkyF04Q6datm7Zv325zRRdwyko+tGvXTvXr19fatWsVExOjsmXLavDgwZIu/ISTkJCQ6zVRUVEaOHCgy1hqaqqGDRumKlWqyNfXVzVq1NCYMWOUlZVVoLr27NmjTz/9VEOHDnVpxuG+3DVLPj4+kqQKFSq4jAcHB0uS/P39C7RcFB13zdKGDRt08OBBDRo0yGX8gQceULly5bRo0aICLRdFhyzBKu6aJUnOZtzduGdVbuzgwYPq16+f+vTpo6VLl2r48OH5ev3Fn0lWrFihV155RcuWLVNsbKwmTpyooUOHusw7cOBAORwO7dmz55rLXLdunYwxioyMVO/evVWuXDn5+/urXbt2uX5agvtwxyxVr15d3bt316RJk7R69WqdOnVK/+///T+NGDFC1apVU69evfK7migG7pili0edGjZs6DLu4+OjunXrus1RKbgiS7CKO2bJnXHKSj4dO3ZMCxYs0O23316g1yckJOj48ePasWOHqlWrJkm64447FBAQoLi4OI0aNcp53pWXl5e8vLzkcDiuucw//vhDkhQXF6f27dtr4cKFOn36tMaMGaPbb79d3377ba4PMtjPHbMkSQsWLNBjjz3mUlfDhg2VkpLCLzBuyh2zdPToUUlSpUqVcj1XqVIlj95xlmRkCVZxxyy5M46Q51PFihULHC5JWrJkidq3b6/IyEhlZWU5p4vnv6WkpDjnnTFjhrKysq57EV1OTo4kqUqVKlq4cKE6d+6sHj16aPny5SpTpozeeOONAteLouOOWZKkRx99VAsXLtSkSZOUkpKiefPmydfXV7fffrt+//33AteLouOuWZJ01R2kJ+84SzKyBKu4c5bcEUfI86ly5cqFev2hQ4f02WefOc/VvdyRI0fyvcyQkBBJUocOHeTl5eUcr1y5sho1aqTNmzcXrFgUKXfM0vLlyzVjxgwtWLBAPXv2dI536tRJUVFRSkhIUFJSUoFrRtFwxyxd/Fw6evSowsPDXZ47duzYFY92wn5kCVZxxyy5MxryfLraN3E/Pz9lZGTkGr/4U9tFoaGhatiwocaPH3/F5URGRua7pmudjmKMcdsLGEo7d8zS1q1bJUm33nqry3hwcLCio6M5V9NNuWOWGjRoIEnatm2by+3PsrKy9P/+3/9T7969871MFD2yBKu4Y5bcGQ25RaKiovTDDz+4jK1atSrXHy3o1q2bli5dqlq1all2Pm6LFi1UpUoVffHFF8rOznYeJT9w4ID+85//qE+fPpa8D4qHnVm6+AG3YcMGl5/+jh49qp9//ll33HGHJe+D4mH351LlypWVnJyshx56yDn+8ccf69SpU9yS1cOQJVjFziy5Mw6dWqR///5atmyZXnnlFX311VeaMmWKHn300Vy3j3v11Vfl4+OjmJgYTZs2TatWrdLSpUuVmJiobt26af/+/c55Y2Nj5e3tfd3zdsuUKaNJkybpp59+Uvfu3fX5559r/vz56ty5s3x9ffX8888XyTqjaNiZpR49eqh69ep69NFH9fbbb2v16tWaM2eOOnTooDNnzujJJ58sknVG0bAzS15eXnrjjTe0fPlyDRs2TGvWrNH06dP16KOPqmPHjurSpUuRrDOKBlmCVezMkiTt3LlTH3/8sT7++GOlpqbqzJkzzsc7d+60fH3ziiPkFhk1apTS0tKUnJyst956S82bN9f8
+fPVvXt3l/kqV66sTZs2aezYsXrzzTe1f/9+BQUFqUaNGurSpYvLt8Ds7GxlZ2fLGHPd9+/Zs6cWLVqk8ePHq2fPnvLz81Pbtm01b9481apVy/L1RdGxM0vlypXThg0bNH78eL377rvav3+/KlWqpCZNmmjatGnX/cMPcC92fy7169dPXl5eeu2115ScnKxKlSrpr3/961V/gob7Ikuwit1Zmj9/vsaMGeMy9sADD0iS4uPjr3iP9OLgGDL1qEv104dzcQQAAABQlIYmHnP+m1NWAAAAABvRkAMAAAA2oiEHAAAAbERDDgAAANiIhvy/kpOT5XA4nJO3t7eqVKmiQYMG6Y8//iiWGqKiojRw4MBCLWP79u164IEHFBYWJj8/P0VFRWn48OHWFIg8KQlZ+uWXX9S/f39Vq1ZNAQEBqlWrlp5++ulcf7gBRaskZOn8+fMaM2aMoqKi5Ofnp7p162rKlCnWFYg8KQlZktjHuYOSkKWXXnpJ3bp10w033CCHw1HoXFqB2x5eJikpSXXr1tXZs2e1du1aTZw4USkpKdq2bZsCAwPtLu+aVq9erbvuukutW7fWu+++q9DQUO3du1dbtmyxu7RSyVOz9Oeff6ply5YqX768xo4dq2rVqmnLli2Kj4/X6tWr9f333/PXX4uZp2ZJkoYPH64PP/xQY8eO1a233qoVK1boySefVHp6ul544QW7yyt1PDlL7OPciydnadKkSWrYsKHuueceffDBB3aXI4mGPJf69eurWbNmkqT27dsrOztbY8eO1eLFi9W3b98rvubMmTMqW7ZscZZ5xRr69u2r22+/XZ999pnLn6zt37+/jZWVXp6apf/93//V0aNHNW/ePOdf5mzfvr0yMjL0wgsv6D//+Y+aNGlia42ljadmaceOHZoxY4bGjx+vUaNGSZLatWuno0ePaty4cXrkkUdUqRK32i1Onpol9nHux1OzJEnp6enOA0sffvihzdVcwGGu67j4h1Au/vWngQMHqly5ctq2bZs6deqkoKAgZ9OSmZmpcePGqW7duvLz81NYWJgGDRqkP//802WZ58+f1+jRoxUREaGyZcvqL3/5i7777rtC1blgwQIdPHhQo0aNcvmggvvwlCz5+PhIUq6/mhYcHCxJ8vf3L9TyUXiekqXFixfLGKNBgwa5jA8aNEhnz57V8uXLC7V8FJ6nZIl9nPvzlCxJcstfeTlCfh2//PKLJCksLMw5lpmZqXvuuUfDhg3Tc889p6ysLOXk5Kh79+5at26dRo8erZiYGP3++++Kj49Xu3bttGnTJgUEBEiShg4dqlmzZikuLk4dO3bU9u3b1aNHD6Wnp+d6/6ioKEnSnj17rlnn2rVrJV34a1UXAxsYGKguXbro7bffVmRkpAVbA4XhKVm69957Va1aNT3zzDNKTExU9erVtXnzZr322mu6++67Va9ePWs2CArMU7K0fft2hYWFKSIiwmW8YcOGzudhL0/JEvs49+cpWXJbQ6YeNZdOpVVSUpKRZDZs2GDOnz9v0tPTzZIlS0xYWJgJCgoyqampxhhjBgwYYCSZDz74wOX1//rXv4wks3DhQpfxjRs3GkkmMTHRGGPMjz/+aCSZp556ymW+2bNnG0lmwIABLuO1atUytWrVum79nTt3NpJMcHCwGT16tFm1apV59913TUhIiImOjjanT5/O7yZBAXl6lowx5sCBA6ZVq1ZGknN64IEHzLlz5/KzKVBInp6ljh07mjp16lzxOV9fX/Pwww9fdxmwhqdniX2c+/D0LF0uMDAw17KKy6X9t/sds7dZy5Yt5ePjo6CgIHXr1k0RERFatmyZwsPDXea7//77XR4vWbJEwcHBuvvuu5WVleWcGjdurIiICK1Zs0bShYtSJOU6v+rBBx+Ut3fuHyx++eUX57fOa8nJyZEkPfTQQ3r99dfVvn17DRs2TDNmzNAvv/yiOXPm5HkbwBqemqXjx4+re/fuSktL0+zZs7V27VolJibq
66+/1j333KOsrKz8bAZYwFOzJOmapxdw6kHx89QssY9zP56aJXfFKSuXmTVrlurVqydvb2+Fh4ercuXKueYpW7asypcv7zJ26NAhnThxQr6+vldc7pEjRyTJedu4y3/C9fb2VkhISIHrvvjazp07u4x37txZDodDmzdvLvCyUTCemqXXX39dW7du1e+//+6suXXr1qpbt65uv/12zZ49WwMGDCjw8pF/npqlkJAQbd26Ndf46dOnlZmZyQWdNvDkLEns49yJp2bJXdGQX6ZevXrOq4av5kpHdUJDQxUSEnLVi5SCgoIk/d+HSmpqqm644Qbn81lZWYW6x3PDhg01d+7cqz7vjhcwlHSemqWtW7fqhhtuyPXheuutt0rivF87eGqWGjRooLlz5yo1NdVlp7pt2zZJF+7SgOLlqVliH+d+PDVL7ooEW6Rbt246evSosrOz1axZs1xTnTp1JF245ZckzZ492+X18+fPL9SpAPfdd58cDoeWLVvmMr5s2TIZY5xXP8P92Z2lyMhI7d+/P9cfePjmm28kSVWqVCnwslG87M5S9+7d5XA4NHPmTJfx5ORkBQQEqEuXLgVeNoqX3VliH1dy2J0ld8URcov06tVLs2fP1p133qknn3xSzZs3l4+Pj/bv36/Vq1ere/fuuu+++1SvXj3169dPkydPlo+Pjzp06KDt27frrbfeyvWzjiRFR0dL0nXPi6pbt64ee+wxJSYmKigoSF27dtXPP/+sl156SU2aNNGDDz5YJOsN69mdpccee0yzZ89Wx44d9dxzz6lq1aravn27xo0bp/Dw8KveXxbux+4s3XzzzYqNjVV8fLy8vLx066236osvvtD777+vcePGccqKB7E7S+zjSg67syRJKSkpzlssZmdn6/fff9fHH38sSWrbtq3LnWKKDXdZueDiVcMbN2685nwDBgwwgYGBV3zu/Pnz5q233jKNGjUy/v7+ply5cqZu3bpm2LBhZteuXc75MjIyzDPPPGP+53/+x/j7+5uWLVuab775xlSvXj3Xlb7Vq1c31atXz9M6ZGVlmddee81ER0cbHx8fU7lyZfPoo4+a48eP5+n1sEZJyNLmzZvNfffdZ6pUqWL8/PxMzZo1zZAhQ8zevXvz9HpYoyRkKTMz08THx5tq1aoZX19fc+ONN5p//OMfeXotrFMSssQ+zj2UhCy1bdvW5S5il06rV6/O0zKscGn/7Rgy9ai5tEGfPpwjFgAAAEBRGpp4zPlvziEHAAAAbERDDgAAANiIhhwAAACwEQ05AAAAYCMacgAAAMBGNOQAAACAjTzyDwPl5OTowIEDCgoKuuKfZUXRMsYoPT1dkZGRHv/nismSvcgSrEKWYJWSlCWJPNkpP1nyyIb8wIEDqlq1qt1llHr79u3z+D+jTpbcA1mCVcgSrFISsiSRJ3eQlyx5ZEMeFBQk6cIKXunPp6JopaWlqWrVqs7/Dp6MLNmLLMEqZAlWKUlZksiTnfKTJY9syC/+5FK+fHnCZaOS8NMXWXIPZAlWIUuwSknIkkSe3EFesuT5J0cBAAAAHswjj5BfzZgxY675fHx8fDFVAk9HlmAVsgSrXC9LEnlC3pAl98MRcgAAAMBGNOQAAACAjWjIAQAAABvRkAMAAAA2oiEHAAAAbERDDgAAANiIhhwAAACwEQ05AAAAYCMacgAAAMBGNOQAAACAjWjIAQAAABvRkMMWCQkJcjgcLlNERITzeWOMEhISFBkZqYCAALVr1047duxwWUZGRoaeeOIJhYaGKjAwUPfcc4/2799f3KsCm5ElAICnoyGHbW6++WYdPHjQOW3bts353BtvvKG//e1veuedd7Rx40ZFRESoY8eOSk9Pd84zcuRILVq0SHPnztXXX3+tU6dOqVu3bsrOzrZjdWAjsgQA8GTedheA0svb29vlSOZFxhhNnjxZL774onr06CFJmjlzpsLDwzVnzhwNGzZMJ0+e1IwZM/Thhx+qQ4cOkqSPPvpI
VatW1ZdffqnOnTtf8T0zMjKUkZHhfJyWllYEa4biRpYAAJ6MI+Swza5duxQZGakaNWqoV69e+u233yRJu3fvVmpqqjp16uSc18/PT23bttX69eslSd9//73Onz/vMk9kZKTq16/vnOdKJk6cqAoVKjinqlWrFtHaoTiRJQCAJ6Mhhy1atGihWbNmacWKFZo+fbpSU1MVExOjo0ePKjU1VZIUHh7u8prw8HDnc6mpqfL19VXFihWvOs+VPP/88zp58qRz2rdvn8VrhuJGlmAVrkcAYJd8NeR8WMEqXbt21f33368GDRqoQ4cO+vzzzyVdOJ3gIofD4fIaY0yusctdbx4/Pz+VL1/eZYJnI0uwEtcjALBDvo+Q82GFohAYGKgGDRpo165dzi95lx+dPHz4sPNIZ0REhDIzM3X8+PGrzoPSiSyhMC5ej3BxCgsLk5T7eoT69etr5syZOnPmjObMmSNJzusR3n77bXXo0EFNmjTRRx99pG3btunLL7+0c7VQzDiAifzKd0Nux4dVRkaG0tLSXCaULBkZGfrxxx9VuXJl1ahRQxEREVq5cqXz+czMTKWkpCgmJkaS1LRpU/n4+LjMc/DgQW3fvt05D0onsoTCsON6BPZxJRMHMJEf+W7IuXgKVoiLi1NKSop2796tb7/9Vj179lRaWpoGDBggh8OhkSNHasKECVq0aJG2b9+ugQMHqmzZsurTp48kqUKFCoqNjdUzzzyjr776Slu2bFG/fv2cpy2g9CBLsIpd1yOwjyuZ+LUF+ZGv2x5e/LC68cYbdejQIY0bN04xMTHasWPHNT+sfv/9d0mFu3jq6aefdj5OS0vjA8vD7d+/X71799aRI0cUFhamli1basOGDapevbokafTo0Tp79qyGDx+u48ePq0WLFvriiy8UFBTkXMakSZPk7e2tBx98UGfPntUdd9yh5ORkeXl52bVasAFZglW6du3q/HeDBg3UqlUr1apVSzNnzlTLli0lFc31COzjSqaLBzD9/PzUokULTZgwQTVr1rzuAcxhw4Zd9wDm1W7HKnFLVk+Vr4bcrg8rPz8/+fn55adUuLm5c+de83mHw6GEhAQlJCRcdR5/f39NmTJFU6ZMsbg6eBKyhKJy6fUI9957r6QLB5YqV67snOdq1yNceuDp8OHD1zz9iX1cyWPXAUzpwi8uY8aMsXBtUBwKddtDLp4CAJRUXI+AgrLr7k8St2T1VIVqyPmwAgCUFFyPgKJSnAcwuSWrZ8pXQ86HFQCgpLp4PUKdOnXUo0cP+fr65roeYeTIkRo+fLiaNWumP/7444rXI9x777168MEHddttt6ls2bL67LPPuB6hlOMAJq4nX+eQc/EUAKCk4noEWCUuLk533323qlWrpsOHD2vcuHFXPIBZu3Zt1a5dWxMmTLjqAcyQkBBVqlRJcXFxHMAswfLVkPNhBQAAcG0cwER+5ashBwAAwLVxABP5VaiLOgEAAAAUDg05AAAAYCMacgAAAMBGNOQAAACAjWjIAQAAABvRkAMAAAA2oiEHAAAAbERDDgAAANiIhhwAAACwEQ05AAAAYCMacgAAAMBGNOQAAACAjWjIAQAAABvRkAMAAAA2oiEHAAAAbERDDgAAANiIhhwAAACwEQ05AAAAYCMacgAAAMBGNOQAAACAjWjIAQAAABvRkAMAAAA2oiEHAAAAbERDDgAAANiIhhwAAACwEQ05AAAAYCMacgAAAMBGNOQAAACAjWjIAQAAABvRkAMAAAA2oiEHAAAAbERDDgAAANiIhhwAAACwEQ05AAAAYCMacgAAAMBGNOQAAACAjWjIAQAAABvRkAMAAAA2oiEHAAAAbORtdwEAAFxuzJgx150nPj6+GCoBgKJHQw7AMjRRAADkH6esAAAAADaiIQcAAABsxCkruO5pBpxiAADwVJxKB09AQw4AAABcR1F+ubP1lJXExETVqFFD/v7+atq0qdatW2dnOfBgZAlWIUuwClmCVchSyWfb
EfJ58+Zp5MiRSkxM1G233ab33ntPXbt21c6dO1WtWjW7yoIHIkvXx2lJeWNHltzp5/TiqiUv7+Pp+FyyBp9dZKm0sK0h/9vf/qbY2FgNGTJEkjR58mStWLFC06ZN08SJE13mzcjIUEZGhvPxyZMnJUlpaWku8507d+6a7/n8889bUboly7l8HQvyPnlZhhUu384XHxtjiuX9r8eTs+QuinN9Lt3WZClvLl/mldj1eVAQVmwTiSxdiRX/fYorS1aw6rOrpGRJylue8pKl623bvGx7K3qd4mJV7gucpSFTj5pLp+KQkZFhvLy8zCeffOIyPmLECNOmTZtc88fHxxtJTG427du3r1jyci1kqWRMZImJLNm/7Zg8P0vGkCd3nK6WpUv7b1uOkB85ckTZ2dkKDw93GQ8PD1dqamqu+Z9//nk9/fTTzsc5OTk6duyYQkJC5HA4JF34FlK1alXt27dP5cuXL9oVKEWutF2NMUpPT1dkZKTN1ZElT3P5tiVLKCiyBKuUpCxJ188TWSo6hcmSrXdZufhBc5ExJteYJPn5+cnPz89lLDg4+IrLLF++PAErApdv1woVKthYTW5kybNcum3JEgqDLMEqJSFLUt7zRJaKTkGyZMtdVkJDQ+Xl5ZXr293hw4dzfQsEroUswSpkCVYhS7AKWSo9bGnIfX191bRpU61cudJlfOXKlYqJibGjJHgosgSrkCVYhSzBKmSp9LDtlJWnn35a/fv3V7NmzdSqVSu9//772rt3rx555JECLc/Pz0/x8fG5fqZB4XjCdiVLnsPdty1Z8hzuvm3Jkudw921LljxHYbatY8jUo+bSgenDK1lW2PUkJibqjTfe0MGDB1W/fn1NmjRJbdq0Kbb3R8lBlmAVsgSrkCVYhSyVTEMTjzn/bWtDDgAAAJRGlzbktpxDDgAAAOACGnIAAADARjTkAAAAgI1oyAEAAAAblYiGPDExUTVq1JC/v7+aNm2qdevW2V1SibB27VrdfffdioyMlMPh0OLFi+0uqciRpaJBlsiSVcgSWbIKWSJLVrEiSx7fkM+bN08jR47Uiy++qC1btqh169bq2rWr9u7da3dpHu/06dNq1KiR3nnnHbtLKRZkqeiQJbJkFbJElqxClsiSVazIksff9rBFixa65ZZbNG3aNOdYvXr1dO+992rixIk2VlayOBwOLVq0SPfee6/dpRQZslQ8yBJZsgpZIktWIUtkySr5yVKJue1hZmamvv/+e3Xq1MllvFOnTlq/fr1NVcETkSVYhSzBKmQJViFL7s+jG/IjR44oOztb4eHhLuPh4eFKTU21qSp4IrIEq5AlWIUswSpkyf15dEN+kcPhcHlsjMk1BuQFWYJVyBKsQpZgFbLkvjy6IQ8NDZWXl1eub3eHDx/O9S0QuBayBKuQJViFLMEqZMn9eXRD7uvrq6ZNm2rlypUu4ytXrlRMTIxNVcETkSVYhSzBKmQJViFL7s/b7gIK6+mnn1b//v3VrFkztWrVSu+//7727t2rRx55xO7SPN6pU6f0yy+/OB/v3r1bW7duVaVKlVStWjUbKysaZKnokCWyZBWyRJasQpbIklUsydKQqUfNpZMnmjp1qqlevbrx9fU1t9xyi0lJSbG7pBJh9erVRlKuacCAAXaXVmTIUtEgS2TJKmSJLFmFLJElqxQ0S5f23x5/H3IAAADA05SY+5ADAAAAno6GHAAAALARDTkAAABgIxpyAAAAwEY05AAAAICNaMgBAAAAG9GQAwAAADYqtQ15QkKCGjdubNny1qxZI4fDoRMnTli2zCtxOBxavHhxkb4H8ocswSpkCVYhS7AKWSoeJbYhHzhwoBwOhxwOh3x8fFSzZk3FxcXp9OnTkqS4uDh99dVXxVJLZmamQkNDNW7cuCs+P3HiRIWGhiozM7NY6kH+kCVYhSzBKmQJViFL7qHENuSS1KVLFx08eFC//fabxo0bp8TERMXFxUmSypUrp5CQkGKpw9fXV/369VNy
crKMMbmeT0pKUv/+/eXr61ss9SD/yBKsQpZgFbIEq5Al+5XohtzPz08RERGqWrWq+vTpo759+zp/vrj0J5hz587p5ptv1sMPP+x87e7du1WhQgVNnz5dkmSM0RtvvKGaNWsqICBAjRo10scff5znWmJjY/Xrr79q7dq1LuPr1q3Trl27FBsbq40bN6pjx44KDQ1VhQoV1LZtW23evPmqy7zSzz5bt26Vw+HQnj17nGPr169XmzZtFBAQoKpVq2rEiBHOb76SlJiYqNq1a8vf31/h4eHq2bNnntertCBLF5ClwiNLF5ClwiNLF5ClwiNLF9iZpRLdkF8uICBA58+fzzXu7++v2bNna+bMmVq8eLGys7PVv39/tW/fXkOHDpUkvfTSS0pKStK0adO0Y8cOPfXUU+rXr59SUlLy9N4NGjTQrbfeqqSkJJfxDz74QM2bN1f9+vWVnp6uAQMGaN26ddqwYYNq166tO++8U+np6QVe523btqlz587q0aOHfvjhB82bN09ff/21Hn/8cUnSpk2bNGLECL366qv66aeftHz5crVp06bA71dakCWyZBWyRJasQpbIklXIkg1ZGjL1qLl0KikGDBhgunfv7nz87bffmpCQEPPggw8aY4yJj483jRo1cnnNG2+8YUJDQ80TTzxhIiIizJ9//mmMMebUqVPG39/frF+/3mX+2NhY07t3b2OMMatXrzaSzPHjx69a07Rp00xgYKBJT083xhiTnp5uAgMDzXvvvXfF+bOyskxQUJD57LPPnGOSzKJFi676nlu2bDGSzO7du40xxvTv3988/PDDLstdt26dKVOmjDl79qxZuHChKV++vElLS7tq3aUdWdptjCFLViBLu40xZMkKZGm3MYYsWYEs7TbG2JOlS/vvEn2EfMmSJSpXrpz8/f3VqlUrtWnTRlOmTLnq/M8884zq1KmjKVOmKCkpSaGhoZKknTt36ty5c+rYsaPKlSvnnGbNmqVff/01z/X07t1bOTk5mjdvniRp3rx5MsaoV69ekqTDhw/rkUce0Y033qgKFSqoQoUKOnXqlPbu3VvgbfD9998rOTnZpe7OnTsrJydHu3fvVseOHVW9enXVrFlT/fv31+zZs3XmzJkCv19JRZbIklXIElmyClkiS1YhS/ZnyduyJbmh9u3ba9q0afLx8VFkZKR8fHyuOf/hw4f1008/ycvLS7t27VKXLl0kSTk5OZKkzz//XDfccIPLa/z8/PJcT4UKFdSzZ08lJSUpNjZWSUlJ6tmzp8qXLy/pwpXOf/75pyZPnqzq1avLz89PrVq1uurVxGXKXPg+ZS658OHyn5hycnI0bNgwjRgxItfrq1WrJl9fX23evFlr1qzRF198oVdeeUUJCQnauHGjgoOD87xuJR1ZIktWIUtkySpkiSxZhSzZn6US3ZAHBgYqOjo6z/MPHjxY9evX19ChQxUbG6s77rhDN910k2666Sb5+flp7969atu2baFqio2NVbt27bRkyRL9+9//1oQJE5zPrVu3TomJibrzzjslSfv27dORI0euuqywsDBJ0sGDB1WxYkVJFy5SuNQtt9yiHTt2XHM7eHt7q0OHDurQoYPi4+MVHBysVatWqUePHgVdzRKHLJElq5AlsmQVskSWrEKW7M9SiW7I82Pq1Kn65ptv9MMPP6hq1apatmyZ+vbtq2+//VZBQUGKi4vTU089pZycHP3lL39RWlqa1q9fr3LlymnAgAF5fp+2bdsqOjpaf/3rXxUdHe1yQUB0dLQ+/PBDNWvWTGlpaRo1apQCAgKuuqzo6GhVrVpVCQkJGjdunHbt2qW3337bZZ5nn31WLVu21GOPPaahQ4cqMDBQP/74o1auXKkpU6ZoyZIl+u2339SmTRtVrFhRS5cuVU5OjurUqZP/jQhJZIksWYcskSWrkCWyZBWyVERZKi0XdV7u0osUfvzxRxMQEGDmzJnjfP7kyZMmKirKjB492hhjTE5Ojvn73/9u6tSp
Y3x8fExYWJjp3LmzSUlJMcbk7SKFiyZMmGAkmQkTJriMb9682TRr1sz4+fmZ2rVrmwULFpjq1aubSZMmOefRJRcpGGPM119/bRo0aGD8/f1N69atzYIFC1wuUjDGmO+++8507NjRlCtXzgQGBpqGDRua8ePHG2MuXLDQtm1bU7FiRRMQEGAaNmxo5s2bd911KE3I0m7nPGSpcMjSbuc8ZKlwyNJu5zxkqXDI0m7nPMWdpUv7b8eQqUdd7rw+fXglazp9AAAAAFc0NPGY898l+i4rAAAAgLujIQcAAABsREMOAAAA2IiGHAAAALARDTkAAABgIxpyAAAAwEY05AAAAICNaMgBAAAAG9GQAwAAADaiIQcAAABsREMOAAAA2IiGHAAAALARDTkAAABgIxpyAAAAwEY05AAAAICNaMgBAAAAG9GQAwAAADaiIf8vh8ORp2nNmjV2l5rLmjVrrlnzI488YneJpYonZ+miI0eO6Mknn1RUVJT8/PwUHh6url276tixY3aXVqqUhCxddOjQIYWEhMjhcOjjjz+2u5xSx9OzFBUVxf7NTXh6loYMGaL69esrODhYAQEBuvHGGzVq1CgdOXLE1rq8bX13N/LNN9+4PB47dqxWr16tVatWuYzfdNNNxVlWntxyyy256pekadOmadasWbrvvvtsqKr08uQsSdKBAwfUunVreXt76+WXX1bt2rV15MgRrV69WpmZmXaXV6p4epYu9dhjj8nf39/uMkqtkpCl2267TW+99ZbLWHh4uE3VlF6enqXTp0/r4YcfVnR0tPz9/bVp0yaNHz9eS5cu1ZYtW+Tr62tLXTTk/9WyZUuXx2FhYSpTpkyu8cudOXNGZcuWLcrSrqt8+fK56jTGqG/fvqpevbo6duxoU2WlkydnSZKGDx+ujIwMbdq0SRUrVnSO9+jRw8aqSidPz9JFCxcu1IoVKzR16lQNGDDA7nJKpZKQpeDg4OvWi6Ln6Vn617/+5fL49ttvV1BQkIYPH66vv/5at99+uy11ccpKPrRr107169fX2rVrFRMTo7Jly2rw4MGSLvyEk5CQkOs1UVFRGjhwoMtYamqqhg0bpipVqsjX11c1atTQmDFjlJWVZVmtq1ev1m+//aZBgwapTBn+M7sbd83Snj179Omnn2ro0KEuzTjcl7tm6aJjx47pscce0/jx41WtWrVCLQtFy92zBM/haVkKCwuTJHl723ecmk4tnw4ePKh+/fqpT58+Wrp0qYYPH56v16empqp58+ZasWKFXnnlFS1btkyxsbGaOHGihg4d6jLvwIED5XA4tGfPnnzXOWPGDJUpU0aDBg3K92tRPNwxS+vWrZMxRpGRkerdu7fKlSsnf39/tWvX7oqnRcE9uGOWLhoxYoRq1Kihxx9/PF81wR7unKW1a9cqKChIPj4+uummm/T2228rOzs7X/Wh+LhzliQpKytLp0+f1r///W+9/PLL+stf/qLbbrstXzVaiVNW8unYsWNasGBBgX/SSEhI0PHjx7Vjxw7n0aI77rhDAQEBiouL06hRo5znXXl5ecnLy0sOhyNf73HixAl98skn6tixI0ek3Jg7ZumPP/6QJMXFxal9+/ZauHChTp8+rTFjxuj222/Xt99+q4YNGxaoXhQdd8ySJH3++eeaP3++Nm/ezC91HsJds3TXXXepWbNmqlWrlo4fP64FCxYoLi5OW7du1YcffligWlG03DVLkrRhwwa1atXK+fjOO+/U3Llz5eXlVaBarcAnZD5VrFixUOcXLVmyRO3bt1dkZKSysrKcU9euXSVJKSkpznlnzJihrKwsVa9ePV/vMXv2bJ07d05DhgwpcJ0oeu6YpZycHElSlSpVtHDhQnXu3Fk9evTQ8uXLVaZMGb3xxhsFrhdFxx2zdPLkSQ0bNkzPPvus6tevX+DaULzcMUuSNHXqVA0aNEht2rRR9+7d9dFHH+nxxx/XRx99pC1bthS4XhQdd82SJDVo0EAbN25USkqK
/v73v2vLli3q2LGjzpw5U+B6C4sj5PlUuXLlQr3+0KFD+uyzz+Tj43PF56247c6MGTMUFham7t27F3pZKDrumKWQkBBJUocOHVyOFFSuXFmNGjXS5s2bC1YsipQ7ZunFF1+Uj4+PHn/8cZ04cUKSdOrUKUkXLu46ceKEKlSokO9fAFG03DFLV9OvXz+988472rBhg5o0aWLZcmENd85SYGCgmjVrJklq06aNWrRooZYtW+q9997TU089VeDlFgYNeT5dbefh5+enjIyMXONHjx51eRwaGqqGDRtq/PjxV1xOZGRkoerbsmWLtmzZomeeeeaqIYZ7cMcsXet0FGMMpx24KXfM0vbt27Vnzx5FRETkeu7inVaOHz+u4ODgfC8bRccds3Q1xhhJ4nPJTXlSlpo1a6YyZcro559/tmyZ+UVDbpGoqCj98MMPLmOrVq1yHhG6qFu3blq6dKlq1apVJHexmDFjhiQpNjbW8mWjeNiZpRYtWqhKlSr64osvlJ2d7TxKfuDAAf3nP/9Rnz59LHkfFA87szR58mTnkfGLtm7dqqeeekoJCQlq27atypUrZ8l7oei5yz7uUrNmzZKU+zZ8cG/umKWUlBTl5OQoOjq6SN/nWvhaaZH+/ftr2bJleuWVV/TVV19pypQpevTRR1WhQgWX+V599VX5+PgoJiZG06ZN06pVq7R06VIlJiaqW7du2r9/v3Pe2NhYeXt76/fff89TDefOndOcOXMUExOjevXqWbp+KD52ZqlMmTKaNGmSfvrpJ3Xv3t15UV7nzp3l6+ur559/vkjWGUXDziw1btxY7dq1c5kaN24sSbr55pvVrl07W28xhvyxM0tz5sxRz549lZSUpFWrVumTTz5R7969NW3aNA0cOFCNGjUqknVG0bAzS0uWLFH37t01Y8YMffnll1q2bJnGjh2rBx54QNHR0bZee8enoUVGjRqltLQ0JScn66233lLz5s01f/78XOdxV65cWZs2bdLYsWP15ptvav/+/QoKClKNGjXUpUsXl2+B2dnZys7Odv4sdz2ffPKJjh8/zsWcHs7uLPXs2VOLFi3S+PHj1bNnT/n5+alt27aaN2+eatWqZfn6oujYnSWUHHZmqWbNmjpx4oReeOEFHT16VD4+Prr55puVmJioYcOGFcn6oujYmaXo6Gj5+vpq7NixOnTokKQLR+xjY2P13HPP5fpSUJwcQ6Yedal++vBKdtUCAAAAlApDE485/80pKwAAAICNaMgBAAAAG9GQAwAAADaiIQcAAABsREMOAAAA2IiG/L+Sk5PlcDick7e3t6pUqaJBgwbpjz/+KJYaoqKiNHDgwAK9NiEhwaX+y6e5c+daWyyuytOzdNH27dv1wAMPKCwsTH5+foqKitLw4cOtKRB5UlKydNGXX37pXBcr/4Q6rs/Ts7Rnzx72b27C07O0b98+3XfffapZs6YCAwNVoUIFNWnSRO+8846ysrKsLTQfuA/5ZZKSklS3bl2dPXtWa9eu1cSJE5WSkqJt27YpMDDQ7vKuasiQIerSpUuu8aFDh+rXX3+94nMoWp6aJUlavXq17rrrLrVu3VrvvvuuQkNDtXfvXm3ZssXu0kolT87SRadOndLQoUMVGRmpAwcO2F1OqeXpWXriiSdy/cXg2rVr21RN6eapWTp9+rTKly+vl19+WdWqVVNmZqaWLl2qJ554Qlu3btU///lPW+qiIb9M/fr11axZM0lS+/btlZ2drbFjx2rx4sXq27fvFV9z5swZlS1btjjLzKVKlSqqUqWKy9iePXu0Y8cO9e3bV8HBwfYUVop5apbOnDmjvn376vbbb9dnn30mh8PhfK5///42VlZ6eWqWLvXcc8+pYsWKuuuuuzRu3Di7yym1PD1L1apVU8uWLe0uA/LcLNWtW1czZ850GevatasOHz6smTNnaurUqfLz8yv2ujhl5Tou/o9/8c+xDhw4UOXKldO2bdvUqVMnBQUF6Y477pAkZWZmaty4
capbt678/PwUFhamQYMG6c8//3RZ5vnz5zV69GhFRESobNmy+stf/qLvvvvO8to/+OADGWP4y51uwlOytGDBAh08eFCjRo1yacbhPjwlSxetW7dO77//vv75z3/Ky8vLkmXCGp6WJbgvT89SWFiYypQpY9tnFEfIr+OXX36RdOE/1EWZmZm65557NGzYMD333HPKyspSTk6OunfvrnXr1mn06NGKiYnR77//rvj4eLVr106bNm1SQECApAunkcyaNUtxcXHq2LGjtm/frh49eig9PT3X+0dFRUm6cLQ7P3JycpScnKzo6Gi1bdu2YCsPS3lKltauXSvpwp8ivvjhFxgYqC5duujtt99WZGSkBVsDheEpWZKks2fPKjY2ViNHjtQtt9yiTz/9tPAbAJbxpCxJ0muvvaYXXnhB3t7euuWWWzR69Gjdc889hdsIsISnZckYo+zsbKWnp+uLL75QcnKynnnmGXl729QaD5l61Fw6lVZJSUlGktmwYYM5f/68SU9PN0uWLDFhYWEmKCjIpKamGmOMGTBggJFkPvjgA5fX/+tf/zKSzMKFC13GN27caCSZxMREY4wxP/74o5FknnrqKZf5Zs+ebSSZAQMGuIzXqlXL1KpVK9/rs2zZMiPJTJw4Md+vReF4epY6d+5sJJng4GAzevRos2rVKvPuu++akJAQEx0dbU6fPp3fTYIC8vQsGWPMM888Y2rWrGnOnDljjDEmPj7eSDJ//vlnnrcDCs/Ts3TgwAEzdOhQM3/+fLNu3Toze/Zs07JlSyPJTJ8+Pb+bA4Xg6Vm6aOLEiUaSkWQcDod58cUX8/xaq1zaf9OQ/9fFgF0+NWjQwHz99dfO+S4G7OTJky6v79u3rwkODjaZmZnm/PnzLlNERIR58MEHjTHGJCYmGklm06ZNLq8/f/688fb2zhWwgurZs6fx9vY2Bw8etGR5yDtPz1LHjh2NJDNs2DCX8cWLF7PzK2aenqVvv/3WeHl5mZUrVzrHaMjt4elZupLMzEzTpEkTExISYs6fP2/ZcnFtJSVLBw8eNBs3bjQrVqwwzz77rPH19TWPP/54oZaZX5f235yycplZs2apXr168vb2Vnh4uCpXrpxrnrJly6p8+fIuY4cOHdKJEyfk6+t7xeVevMXX0aNHJUkREREuz3t7eyskJMSKVdCRI0f06aef6q677sr1Pig+npqli6/t3Lmzy3jnzp3lcDi0efPmAi8bBeOpWRo8eLB69OihZs2a6cSJE5Kkc+fOSZLS0tLk5+enoKCgAi8f+eepWboSHx8fPfTQQ3ruuee0a9cu1atXz9Ll49o8PUsRERHOZXfq1EkVK1bUc889p8GDB6tJkyaFXn5+0ZBfpl69es6rhq/mShe6hYaGKiQkRMuXL7/iay7udC6GKDU1VTfccIPz+aysLGf4CuvDDz9UZmYmF3PazFOz1LBhw2ve17dMGa4FL26emqUdO3Zox44dWrBgQa7natWqpUaNGmnr1q0FXj7yz1OzdDXGGEl8LtmhpGWpefPmkqSff/6ZhtyTdevWTXPnzlV2drZatGhx1fnatWsnSZo9e7aaNm3qHJ8/f75lN6SfMWOGIiMj1bVrV0uWh+Jld5buu+8+vfjii1q2bJnuu+8+5/iyZctkjOGWYx7E7iytXr0611hycrJmzpypxYsXu+xk4d7sztKVnD9/XvPmzVNoaKiio6MtXTaKjjtmSfq/zyu7skRDbpFevXpp9uzZuvPOO/Xkk0+qefPm8vHx0f79+7V69Wp1795d9913n+rVq6d+/fpp8uTJ8vHxUYcOHbR9+3a99dZbuX7Wkf4vGBevXr6eb7/9Vjt27NALL7zA7cU8lN1Zqlu3rh577DElJiYqKChIXbt21c8//6yXXnpJTZo00YMPPlgk6w3r2Z2lizvUS61Zs0aSdNtttyk0NLTQ64jiYXeWnn76aZ0/f1633XabIiIitG/fPk2ZMkVbt25VUlIS+zsPYneW4uPj
dejQIbVp00Y33HCDTpw4oeXLl2v69Ol64IEHXJr/4kRDbhEvLy99+umn+vvf/64PP/xQEydOdP452bZt26pBgwbOeWfMmKHw8HAlJyfrH//4hxo3bqyFCxeqV69euZab32+BM2bMkMPhUGxsbKHXCfZwhyxNnjxZVapU0T//+U9NmTJFoaGh6tWrlyZMmHDV8/7gftwhSygZ7M5S/fr19d5772nOnDlKS0tTUFCQmjdvrhUrVqhTp06WrSeKnt1Zatasmf7xj39o8eLFOnr0qPz9/XXTTTdp0qRJevTRRy1bz/xyDJl61Fw6MH14JbtqAQAAAEqFoYnHnP/mKggAAADARjTkAAAAgI1oyAEAAAAb0ZADAAAANqIhBwAAAGzkkbc9zMnJ0YEDBxQUFHTFvwKFomWMUXp6uiIjIz3+r6ORJXuRJViFLMEqJSlLEnmyU36y5JEN+YEDB1S1alW7yyj19u3bpypVqthdRqGQJfdAlmAVsgSrlIQsSeTJHeQlSx7ZkAcFBUm6sIJX+mtNKFppaWmqWrWq87+DJyNL9iJLsApZglVKUpYk8mSn/GTJIxvyiz+5lC9fnnDZqCT89EWW3ANZglXIEqxSErIkkSd3kJcsef7JUQAAAIAH88gj5FczZsyYaz4fHx9fTJXA05ElWIUswSrXy5JEnpA3ZMn9cIQcAAAAsBENOQAAAGAjGnIAAADARjTkAAAAgI1oyAEAAAAb0ZADAAAANqIhBwAAAGxEQw4AAADYiIYcAAAAsBENOQAAAGAjGnIAAADARjTkAABISkhIkMPhcJkiIiKczxtjlJCQoMjISAUEBKhdu3basWOHyzIyMjL0xBNPKDQ0VIGBgbrnnnu0f//+4l4VAB6Ghhy2YMcHwB3dfPPNOnjwoHPatm2b87k33nhDf/vb3/TOO+9o48aNioiIUMeOHZWenu6cZ+TIkVq0aJHmzp2rr7/+WqdOnVK3bt2UnZ1tx+oA8BD5ashpomAldnywAp9LsJK3t7ciIiKcU1hYmKQLOZo8ebJefPFF9ejRQ/Xr19fMmTN15swZzZkzR5J08uRJzZgxQ2+//bY6dOigJk2a6KOPPtK2bdv05ZdfXvU9MzIylJaW5jIBKF3yfYScJgpWYccHq/C5BKvs2rVLkZGRqlGjhnr16qXffvtNkrR7926lpqaqU6dOznn9/PzUtm1brV+/XpL0/fff6/z58y7zREZGqn79+s55rmTixImqUKGCc6patWoRrR0Ad5XvhpwmClZhxwer2PG5hJKnRYsWmjVrllasWKHp06crNTVVMTExOnr0qFJTUyVJ4eHhLq8JDw93PpeamipfX19VrFjxqvNcyfPPP6+TJ086p3379lm8ZgDcXb4bcpooWIEdH6xkx+cSBwpKnq5du+r+++9XgwYN1KFDB33++eeSpJkzZzrncTgcLq8xxuQau9z15vHz81P58uVdJgClS74acpooWIUdH6xi1+cSBwpKvsDAQDVo0EC7du1yXpdweSYOHz7szFdERIQyMzN1/Pjxq86D0oFrW5Bf+WrIaaJQVNjxoaDs+lziQEHJl5GRoR9//FGVK1dWjRo1FBERoZUrVzqfz8zMVEpKimJiYiRJTZs2lY+Pj8s8Bw8e1Pbt253zoPTg2hbkR6Fue0gTBauw44NViutziQMFJU9cXJxSUlK0e/duffvtt+rZs6fS0tI0YMAAORwOjRw5UhMmTNCiRYu0fft2DRw4UGXLllWfPn0kSRUqVFBsbKyeeeYZffXVV9qyZYv69evn/LKI0oVrW5AfhWrIaaJQUOz4UFT4XEJB7d+/X71791adOnXUo0cP+fr6asOGDapevbokafTo0Ro5cqSGDx+uZs2a6Y8//tAXX3yhoKAg5zImTZqke++9Vw8++KBuu+02lS1bVp999pm8vLzsWi3YxI5rWySub/FU3vmZOS4uTnfffbeqVaumw4cPa9y4cVdsomrXrq3atWtrwoQJV22iQkJC
VKlSJcXFxdFElUIXd3xHjhxRWFiYWrZsmWvHd/bsWQ0fPlzHjx9XixYtrrjj8/b21oMPPqizZ8/qjjvuUHJyMju+UobPJVhl7ty513ze4XAoISFBCQkJV53H399fU6ZM0ZQpUyyuDp7k4rUtN954ow4dOqRx48YpJiZGO3bsuOa1Lb///rukgl/bIl24vmXMmDEWrg2KQ74acpooWIUdH6zC5xIAd9O1a1fnvxs0aKBWrVqpVq1amjlzplq2bCmpaK5tkS5c3/L00087H6elpXHRuQfIV0NOEwXA3fC5BMDdXXpty7333ivpwlHwypUrO+e52rUtlx4lP3z48HVPpfPz85Ofn5/1K4EiVahzyAEAAHBtXNuC68nXEXIAAABcG9e2IL9oyAEAACzEtS3ILxpyAAAAC3FtC/KLc8gBAAAAG9GQAwAAADaiIQcAAABsREMOAAAA2IiGHAAAALARDTkAAABgIxpyAAAAwEY05AAAAICNaMgBAAAAG9GQAwAAADaiIQcAAABsREMOAAAA2IiGHAAAALARDTkAAABgIxpyAAAAwEY05AAAAICNaMgBAAAAG9GQAwAAADaiIQcAAABsREMOAAAA2IiGHAAAALARDTkAAABgIxpyAAAAwEY05AAAAICNaMgBAAAAG9GQAwAAADaiIQcAAABsREMOAAAA2IiGHAAAALARDTkAAABgIxpyAAAAwEY05AAAAICNaMgBAAAAG9GQAwAAADaiIQcAAABsREMOAAAA2IiGHAAAALCRt90FAAAAFJUxY8Zcd574+PhiqAS4OhpyAJZhxwcAQP5xygoAAABgI46Q47pHNTmiCQAAUHRoyAEAbofTnwCUJpyyAgAAANiIhhwAAACwEQ05AAAAYCNbG/LExETVqFFD/v7+atq0qdatW2dnOfBgZAlWIUuwClmCVchSyWfbRZ3z5s3TyJEjlZiYqNtuu03vvfeeunbtqp07d6patWp2lQUPRJZgFbJ0fVxsmTdkCVYhS+6jKD//bGvI//a3vyk2NlZDhgyRJE2ePFkrVqzQtGnTNHHixCJ5z7xsyLzwpJ2NVevsztw1S56Uk7wgS0WTJZRMdmWppN3GtjR87lyPu34uudN+0J1qKShbGvLMzEx9//33eu6551zGO3XqpPXr1+eaPyMjQxkZGc7HJ0+elCSlpaW5zHfu3LkiqDa3559/vljex11cvp0vPjbG2FGOC3fOUmnLSV5duq3JUsnMSXGtE1nKv5KYNyuUlCxJectTXrJkRVbcKW/u/rlkS0N+5MgRZWdnKzw83GU8PDxcqampueafOHHiFb/9VK1atchqxP957bXXrjienp6uChUqFHM1rsiS57lSnsgSCoIswSolJUsSebJbQbNk6x8GcjgcLo+NMbnGpAvfap5++mnn45ycHB07dkwhISHO+dPS0lS1alXt27dP5cuXL9rCS5ErbVdjjNLT0xUZGWlzdf+HLHmGy7ctWUJBkSVYpSRlSbp+nshS0SlMlmxpyENDQ+Xl5ZXr293hw4dzfQuUJD8/P/n5+bmMBQcHX3HZ5cuXJ2BF4PLtavdRg4vIkme6dNuSJRQGWYJVSkKWpLzniSwVnYJkyZbbHvr6+qpp06ZauXKly/jKlSsVExNjR0nwUGQJViFLsApZglXIUulh2ykrTz/9tPr3769mzZqpVatWev/997V371498sgjdpUED0WWYBWyBKuQJViFLJUOtjXkDz30kI4ePapXX31VBw8eVP369bV06VJVr169QMvz8/NTfHx8rp9pUDiesF3Jkudw921LljyHu29bsuQ53H3bkiXPUZht6xgy9ajLvVimD69kWWEAAAAAchuaeMz5b1vOIQcAAABwAQ05AAAAYCMacgAAAMBGNOQAAACAjUpEQ56YmKgaNWrI399fTZs21bp16+wuqURYu3at7r77bkVGRsrhcGjx4sV2l1TkyFLRIEtkySpkiSxZhSyRJatYkSWPb8jnzZunkSNH
6sUXX9SWLVvUunVrde3aVXv37rW7NI93+vRpNWrUSO+8847dpRQLslR0yBJZsgpZIktWIUtkySpWZMnjb3vYokUL3XLLLZo2bZpzrF69err33ns1ceJEGysrWRwOhxYtWqR7773X7lKKDFkqHmSJLFmFLJElq5AlsmSV/GSpxNz2MDMzU99//706derkMt6pUyetX7/epqrgicgSrEKWYBWyBKuQJffn0Q35kSNHlJ2drfDwcJfx8PBwpaam2lQVPBFZglXIEqxClmAVsuT+PLohv8jhcLg8NsbkGgPygizBKmQJViFLsApZcl8e3ZCHhobKy8sr17e7w4cP5/oWCFwLWYJVyBKsQpZgFbLk/jy6Iff19VXTpk21cuVKl/GVK1cqJibGpqrgicgSrEKWYBWyBKuQJffnbXcBhfX000+rf//+atasmVq1aqX3339fe/fu1SOPPGJ3aR7v1KlT+uWXX5yPd+/era1bt6pSpUqqVq2ajZUVDbJUdMgSWbIKWSJLViFLZMkqlmRpyNSj5tLJE02dOtVUr17d+Pr6mltuucWkpKTYXVKJsHr1aiMp1zRgwAC7SysyZKlokCWyZBWyRJasQpbIklUKmqVL+2+Pvw85AAAA4GlKzH3IAQAAAE9HQw4AAADYiIYcAAAAsBENOQAAAGAjGnIAAADARjTkAAAAgI1oyAEAAAAb0ZADAAAANiq1DXlCQoIaN25s2fLWrFkjh8OhEydOWLbMK3E4HFq8eHGRvgfyhyzBKmQJViFLsApZKh4ltiEfOHCgHA6HHA6HfHx8VLNmTcXFxen06dOSpLi4OH311VfFUktmZqZCQ0M1bty4Kz4/ceJEhYaGKjMzs1jqQf6QJViFLMEqZAlWIUvuocQ25JLUpUsXHTx4UL/99pvGjRunxMRExcXFSZLKlSunkJCQYqnD19dX/fr1U3JysowxuZ5PSkpS//795evrWyz1IP/IEqxClmAVsgSrkCX7leiG3M/PTxEREapatar69Omjvn37On++uPQnmHPnzunmm2/Www8/7Hzt7t27VaFCBU2fPl2SZIzRG2+8oZo1ayogIECNGjXSxx9/nOdaYmNj9euvv2rt2rUu4+vWrdOuXbsUGxurjRs3qmPHjgoNDVWFChXUtm1bbd68+arLvNLPPlu3bpXD4dCePXucY+vXr1ebNm0UEBCgqlWrasSIEc5vvpKUmJio2rVry9/fX+Hh4erZs2ee16u0IEsXkKXCI0sXkKXCI0sXkKXCI0sX2JmlEt2QXy4gIEDnz5/PNe7v76/Zs2dr5syZWrx4sbKzs9W/f3+1b99eQ4cOlSS99NJLSkpK0rRp07Rjxw499dRT6tevn1JSUvL03g0aNNCtt96qpKQkl/EPPvhAzZs3V/369ZWenq4BAwZo3bp12rBhg2rXrq0777xT6enpBV7nbdu2qXPnzurRo4d++OEHzZs3T19//bUef/xxSdKmTZs0YsQIvfrqq/rpp5+0fPlytWnTpsDvV1qQJbJkFbJElqxClsiSVciSDVkaMvWouXQqKQYMGGC6d+/ufPztt9+akJAQ8+CDDxpjjImPjzeNGjVyec0bb7xhQkNDzRNPPGEiIiLMn3/+aYwx5tSpU8bf39+sX7/eZf7Y2FjTu3dvY4wxq1evNpLM8ePHr1rTtGnTTGBgoElPTzfGGJOenm4CAwPNe++9d8X5s7KyTFBQkPnss8+cY5LMokWLrvqeW7ZsMZLM7t27jTHG9O/f3zz88MMuy123bp0pU6aMOXv2rFm4cKEpX768SUtLu2rdpR1Z2m2MIUtWIEu7jTFkyQpkabcxhixZgSztNsbYk6VL++8SfYR8yZIlKleunPz9/dWqVSu1adNGU6ZMuer8zzzzjOrUqaMpU6YoKSlJoaGhkqSdO3fq3Llz6tixo8qVK+ecZs2apV9//TXP9fTu3Vs5OTmaN2+eJGnevHkyxqhXr16SpMOHD+uRRx7RjTfeqAoVKqhC
hQo6deqU9u7dW+Bt8P333ys5Odml7s6dOysnJ0e7d+9Wx44dVb16ddWsWVP9+/fX7NmzdebMmQK/X0lFlsiSVcgSWbIKWSJLViFL9mfJ27IluaH27dtr2rRp8vHxUWRkpHx8fK45/+HDh/XTTz/Jy8tLu3btUpcuXSRJOTk5kqTPP/9cN9xwg8tr/Pz88lxPhQoV1LNnTyUlJSk2NlZJSUnq2bOnypcvL+nClc5//vmnJk+erOrVq8vPz0+tWrW66tXEZcpc+D5lLrnw4fKfmHJycjRs2DCNGDEi1+urVasmX19fbd68WWvWrNEXX3yhV155RQkJCdq4caOCg4PzvG4lHVkiS1YhS2TJKmSJLFmFLNmfpRLdkAcGBio6OjrP8w8ePFj169fX0KFDFRsbqzvuuEM33XSTbrrpJvn5+Wnv3r1q27ZtoWqKjY1Vu3bttGTJEv373//WhAkTnM+tW7dOiYmJuvPOOyVJ+/bt05EjR666rLCwMEnSwYMHVbFiRUkXLlK41C233KIdO3Zcczt4e3urQ4cO6tChg+Lj4xUcHKxVq1apR48eBV3NEocskSWrkCWyZBWyRJasQpbsz1KJbsjzY+rUqfrmm2/0ww8/qGrVqlq2bJn69u2rb7/9VkFBQYqLi9NTTz2lnJwc/eUvf1FaWprWr1+vcuXKacCAAXl+n7Zt2yo6Olp//etfFR0d7XJBQHR0tD788EM1a9ZMaWlpGjVqlAICAq66rOjoaFWtWlUJCQkaN26cdu3apbfffttlnmeffVYtW7bUY489pqFDhyowMFA//vijVq5cqSlTpmjJkiX67bff1KZNG1WsWFFLly5VTk6O6tSpk/+NCElkiSxZhyyRJauQJbJkFbJURFkqLRd1Xu7SixR+/PFHExAQYObMmeN8/uTJkyYqKsqMHj3aGGNMTk6O+fvf/27q1KljfHx8TFhYmOncubNJSUkxxuTtIoWLJkyYYCSZCRMmuIxv3rzZNGvWzPj5+ZnatWubBQsWmOrVq5tJkyY559ElFykYY8zXX39tGjRoYPz9/U3r1q3NggULXC5SMMaY7777znTs2NGUK1fOBAYGmoYNG5rx48cbYy5csNC2bVtTsWJFExAQYBo2bGjmzZt33XUoTcjSbuc8ZKlwyNJu5zxkqXDI0m7nPGSpcMjSbuc8xZ2lS/tvx5CpR13uvD59eCVrOn0AAAAAVzQ08Zjz3yX6LisAAACAu6MhBwAAAGxEQw4AAADYiIYcAAAAsBENOQAAAGAjGnIAAADARjTkAAAAgI1oyAEAAAAbeV8+cOlNygEAAAAULY6QAwAAADaiIQcAAABs9P8BqVWGUuKhHU4AAAAASUVORK5CYII=",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "if n_tp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ "        + \"_\" + \"ExampleImages_TruePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_tp,\n",
+ " y_true_tra_tp,\n",
+ " y_choice_tra_tp,\n",
+ " title_main=\"True Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ "        + \"_\" + \"ExampleImages_FalsePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_fp,\n",
+ " y_true_tra_fp,\n",
+ " y_choice_tra_fp,\n",
+ " title_main=\"False Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_tn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleImages_TrueNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_tn,\n",
+ " y_true_tra_tn,\n",
+ " y_choice_tra_tn,\n",
+ " title_main=\"True Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleImages_FalseNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_fn,\n",
+ " y_true_tra_fn,\n",
+ " y_choice_tra_fn,\n",
+ " title_main=\"False Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 9: Histograms of the pixel flux values for the images shown in Figure 7. Here, it is difficult to infer reasons for network classification errors like false positives and false negatives. If a particular digit is expected to have a particular distribution of pixel brightnesses, but the histogram for the predicted digit has a different distribution, that difference may help you identify a pattern. This may be a more useful diagnostic for physics-related data, like galaxy light profiles."
+ ]
+ },
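+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The four groups of example images plotted above can be written compactly in one-vs-rest form. This is a sketch under assumed notation that does not appear elsewhere in this notebook: $c$ stands for the class of interest (`class_value`), $y_i$ for the true label of image $i$, $\\hat{y}_i$ for the label chosen by the model, and $N$ for the number of images considered.\n",
+ "\n",
+ "$$\n",
+ "\\begin{aligned}\n",
+ "\\mathrm{TP} &= \\{\\, i : y_i = c,\\ \\hat{y}_i = c \\,\\} \\\\\n",
+ "\\mathrm{FP} &= \\{\\, i : y_i \\neq c,\\ \\hat{y}_i = c \\,\\} \\\\\n",
+ "\\mathrm{TN} &= \\{\\, i : y_i \\neq c,\\ \\hat{y}_i \\neq c \\,\\} \\\\\n",
+ "\\mathrm{FN} &= \\{\\, i : y_i = c,\\ \\hat{y}_i \\neq c \\,\\}\n",
+ "\\end{aligned}\n",
+ "$$\n",
+ "\n",
+ "With these definitions, every image falls into exactly one group, so the sizes of the four sets sum to $N$. The counts `n_tp`, `n_fp`, `n_tn`, and `n_fn` used in the cell above presumably correspond to those four sizes."
+ ]
+ },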
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 4.3.4.3. Investigate morphological features of images with feature maps"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The last diagnostic of this tutorial is the `feature map`, which shows how one layer processes the input image after the weights have been set during model training. The morphological elements in the feature maps are the features, such as lines and edges, that the trained layer favors.\n",
+ "\n",
+ "The feature map is obtained by passing an input image through one layer of the trained neural network model. More specifically, we first choose an input datum (image) by setting the array index of interest. Then, we make the image a `tensor` with the appropriate number of dimensions using `unsqueeze`. Then, we transfer it to the `device`. Next, we choose the `conv1` layer defined in the section above where we defined the `model`. Finally, we plot one feature map for each of the 32 convolutional kernels in the `conv1` layer of the model."
+ ]
+ },
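+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a rough sketch of what the layer computes, each feature map can be written as a discrete cross-correlation of the input image with one kernel. The symbols are introduced here for illustration and are assumptions rather than names from the notebook: $X$ is the single-channel input image of size $H \\times W$, $W_k$ and $b_k$ are the weights and bias of the $k$-th kernel, $K$ is the kernel side length, and $F_k$ is the resulting feature map. The expression also assumes a stride of 1 and no padding, which may differ from the actual settings of `conv1`.\n",
+ "\n",
+ "$$\n",
+ "F_k[i, j] = b_k + \\sum_{u=0}^{K-1} \\sum_{v=0}^{K-1} W_k[u, v] \\, X[i + u,\\ j + v], \\qquad k = 1, \\ldots, 32\n",
+ "$$\n",
+ "\n",
+ "Under those assumptions, each feature map has $(H - K + 1) \\times (W - K + 1)$ pixels. For example, a $28 \\times 28$ MNIST image and a hypothetical $K = 3$ kernel would give $26 \\times 26$ feature maps. Because every kernel has its own $W_k$ and $b_k$, the 32 maps emphasize different parts of the same input."
+ ]
+ },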
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:50.809088Z",
+ "iopub.status.busy": "2025-02-09T23:13:50.808737Z",
+ "iopub.status.idle": "2025-02-09T23:13:55.192290Z",
+ "shell.execute_reply": "2025-02-09T23:13:55.191710Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:50.809054Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
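+ "# Choose which image to inspect and set the number of kernels in `conv1`.\n",
+ "index_input_image = 1\n",
+ "number_conv_kernels = 32\n",
+ "\n",
+ "# Convert the image to a float tensor, add a leading dimension with\n",
+ "# `unsqueeze(0)`, and move it to the same device as the model.\n",
+ "input_image = dataset.data[index_input_image].type(torch.float32)\n",
+ "input_image = input_image.clone().detach()\n",
+ "input_image = input_image.unsqueeze(0)\n",
+ "input_image = input_image.to(device)\n",
+ "\n",
+ "# Pass the image through `conv1` only, without tracking gradients.\n",
+ "with torch.no_grad():\n",
+ "    feature_maps = model.conv1(input_image).cpu()\n",
+ "\n",
+ "# Plot one feature map per kernel on a 4 x 8 grid.\n",
+ "fig, ax = plt.subplots(4, 8, sharex=True, sharey=True, figsize=(16, 8))\n",
+ "\n",
+ "for i in range(0, number_conv_kernels):\n",
+ "    row, col = i//8, i%8\n",
+ "    ax[row][col].imshow(feature_maps[i])\n",
+ "\n",
+ "plt.show()"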
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 10: `Feature maps` of one input from the `conv1` layer of the `model`. The units on the x- and y-axes are the pixel indices. There are 32 images because there are 32 convolutional kernels in the `conv1` layer.\n",
+ ">\n",
+ "> The feature maps derived from the trained model show which morphological features --- e.g., lines and edges --- are favored by the model. When the model is accurate, these features and feature maps will accurately reflect the input image. In this example, the handwritten digit is \"0,\" and the feature maps all appear as \"0\"s. Some maps have more clearly defined edges and \"0\"-like features than others, because each convolutional kernel has its own set of weight parameters."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Exercises for the Learner"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Each time you train a new model, re-run all the diagnostic plots.\n",
+ "\n",
+ "1. How do the loss and accuracy histories change when batch size is small or large? Why?\n",
+ "2. Does the NN take more or less time (more or fewer epochs) to converge if the input image data are normalized or not normalized? Why?\n",
+ "3. How does the size of the training set affect the model's accuracy and loss -- keeping the number of epochs the same? Why?\n",
+ "4. How does the random seed for the weight initialization affect the model's accuracy and loss -- keeping the number of epochs the same?\n",
+ "5. Use the `time` module to estimate the time for the model fitting. Record that time. Increase and then decrease the number of weights in the NN by an order of magnitude. Train the NN for each of those models and record the times. How does the number of weights in the neural network affect the training time and model loss and accuracy?\n",
+ "6. Use the `time` module to estimate the time for the model fitting. Record that time. Increase and then decrease the number of layers in the NN. Train the NN for each of those models and record the times. How does the number of layers in the neural network affect the training time and model loss and accuracy?\n",
+ "7. Use the `time` module to estimate the time for the model fitting. Record that time. Add a convolutional layer to the NN. Train the new model and record the time. How does the additional convolutional layer affect the training time and model loss and accuracy?"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "machine_shape": "hm",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
From 291451e4d4eb88973069d659cf13054e5fcff699 Mon Sep 17 00:00:00 2001
From: Brian Nord <184985+bnord@users.noreply.github.com>
Date: Tue, 8 Apr 2025 08:38:45 -0500
Subject: [PATCH 5/5] updated notebook after review
renamed from "AI0_Intro_AI_ImageClassificationWithTensorflow_Draft.ipynb"
Addressed Lau comments
Added more section-by-section outlines
Added the ROC Curve Diagnostic
---
DP02_16a_Introduction_to_AI.ipynb | 2672 +++++++++++++++++++++++++++++
1 file changed, 2672 insertions(+)
create mode 100644 DP02_16a_Introduction_to_AI.ipynb
diff --git a/DP02_16a_Introduction_to_AI.ipynb b/DP02_16a_Introduction_to_AI.ipynb
new file mode 100644
index 0000000..1f45658
--- /dev/null
+++ b/DP02_16a_Introduction_to_AI.ipynb
@@ -0,0 +1,2672 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " \n",
+ " Introduction to AI-based Image Classification with PyTorch \n",
+ "Contact author: Brian Nord \n",
+ "Last verified to run: 2025-03-25 \n",
+ "LSST Science Pipelines version: Weekly 2025_09 \n",
+ "Container size: medium \n",
+ "Targeted learning level: beginner "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Description:** An introduction to the classification of images with AI-based classification algorithms."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Skills:** Examine AI training data, prepare it for a classification task, perform classification with a neural network, and examine the diagnostics of the classification task."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**LSST Data Products:** None; MNIST data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Packages:** numpy, matplotlib, sklearn, pytorch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Credits and Acknowledgments:** We thank Ryan Lau and Melissa Graham for feedback on the notebook."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Get Support:**\n",
+ "Find DP0-related documentation and resources at dp0.lsst.io. Questions are welcome as new topics in the Support - Data Preview 0 Category of the Rubin Community Forum. Rubin staff will respond to all questions posted there."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Introduction\n",
+ "\n",
+ "This Jupyter Notebook introduces artificial intelligence (AI)-based image classification. It demonstrates how to perform a few key steps:\n",
+ "1. examine and prepare data for classification;\n",
+ "2. train an AI algorithm;\n",
+ "3. plot diagnostics of the training performance;\n",
+ "4. initially assess those diagnostics.\n",
+ "\n",
+ "\n",
+ "In this section, you will find:\n",
+ "1. a general definition of AI;\n",
+ "2. commentary on terminology in AI;\n",
+ "3. a description of the software and data used in this tutorial;\n",
+ "4. package imports;\n",
+ "5. function definitions;\n",
+ "6. definition of paths for data and figures."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.1. Definition of AI\n",
+ "\n",
+ "AI is a class of algorithms for building statistical models. These algorithms primarily use data for training, as opposed to models that use analytic formulae or models that are based on physical reasoning. Machine learning is a subclass of AI algorithms (e.g., random forests). Deep learning is a subclass of machine learning algorithms (e.g., neural networks). "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.2. AI is math, not magic.\n",
+ "\n",
+ "AI is firmly based in math, computer science, and statistics. Additionally, some of the approaches are inspired by concepts or notions in biology (e.g., the computational neuron) and in physics (e.g., the Restricted Boltzmann Machine). \n",
+ "\n",
+ "Much of the jargon in AI is anthropomorphic, which can make it appear that something other than math is happening. For example, consider the following list of terms that are very often used in AI, and what these terms actually mean mathematically.\n",
+ "\n",
+ "1. `learn` $\\rightarrow$ fit\n",
+ "2. `hallucinate`/`lie` $\\rightarrow$ predict incorrectly\n",
+ "3. `understand` $\\rightarrow$ model fit has converged\n",
+ "4. `cheat` $\\rightarrow$ more efficiently guesses the best weight parameters of the model\n",
+ "5. `believe` $\\rightarrow$ predict/infer based on statistical priors\n",
+ "\n",
+ "When we over-anthropomorphize these mathematical concepts, we obfuscate how they actually work. That makes it harder to build and refine models. That is, AI models are not 'learning' or 'understanding'; they are merely large-parameter models that are being fit to data, and AI includes novel methods for that fitting process. The only learning that's happening is what humans do with these models.\n",
+ "\n",
+ "Many of the most useful AI-related terms are defined throughout this tutorial."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.3. Software and Data Description\n",
+ "\n",
+ "In this notebook, we use `pytorch`, which is currently the library most often used in deep learning studies. `pytorch` is among the state-of-the-art python libraries for tensor manipulation in general and neural network model development in particular. \n",
+ "\n",
+ "Instead of using DP0 data, this tutorial uses [handwritten digits AI benchmarking data from MNIST (Modified National Institute of Standards and Technology)](https://en.wikipedia.org/wiki/MNIST_database), a large database of handwritten digits that is commonly used for training and testing machine learning algorithms. The [`MNIST handwritten digits dataset`](https://ieeexplore.ieee.org/document/6296535) comprises 10 classes, one for each digit, where each image is a picture of the digit written by hand. This is a useful dataset for learning the basics of neural networks and other AI algorithms. MNIST is one of a few canonical AI benchmark data sets for image classification. \n",
+ "\n",
+ "Later tutorials in this series will use stars and galaxies drawn from DP0 data."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.4. Import packages\n",
+ "\n",
+ "[`numpy`](https://numpy.org/) is used for computations and mathematical operations on multi-dimensional arrays.\n",
+ "\n",
+ "[`pandas`](https://pandas.pydata.org/) is used to organize and manage data.\n",
+ "\n",
+ "[`matplotlib`](https://matplotlib.org/) is a plotting library. \n",
+ "\n",
+ "[`seaborn`](https://seaborn.pydata.org/) is used for visualizations.\n",
+ "\n",
+ "[`sklearn`](https://scikit-learn.org/stable/) is a library for machine learning.\n",
+ "\n",
+ "[`torch`](https://www.pytorch.org) is used for fast tensor operations and for building neural network models; its companion package `torchvision` provides image datasets and transforms."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "JyaaGkFE8VOl"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import time\n",
+ "import datetime\n",
+ "import os\n",
+ "\n",
+ "import matplotlib\n",
+ "import matplotlib.pyplot as plt\n",
+ "from matplotlib.pyplot import cm\n",
+ "from matplotlib.colors import LogNorm\n",
+ "import seaborn as sns\n",
+ "\n",
+ "from sklearn.metrics import confusion_matrix, RocCurveDisplay\n",
+ "from sklearn.preprocessing import LabelBinarizer\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "import torch.nn.functional as F\n",
+ "import torchvision\n",
+ "from torch.utils.data import Dataset, DataLoader, Subset, random_split"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.5. Define functions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The following functions are defined and used throughout this notebook. It is not necessary to understand exactly what every function does to proceed with this tutorial. Execute all cells and move on to Section 2."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def normalizeInputs(x_temp, input_minimum, input_maximum):\n",
+ " \"\"\"Normalize a datum that is an input to the neural network\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " x_temp: `numpy.array`\n",
+ " image data\n",
+ " input_minimum: `float`\n",
+ " minimum value for normalization\n",
+ " input_maximum: `float`\n",
+ " maximum value for normalization\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " x_temp_norm: `numpy.array`\n",
+ " normalized image data\n",
+ " \"\"\"\n",
+ "    x_temp_norm = (x_temp - input_minimum)/(input_maximum - input_minimum)\n",
+ " return x_temp_norm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def createFileUidTimestamp():\n",
+ " \"\"\"Create a timestamp for a filename.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " None\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " file_uid_timestamp : `string`\n",
+ " String from date and time.\n",
+ " \"\"\"\n",
+ " file_uid_timestamp = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
+ " return file_uid_timestamp"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def createFileName(file_prefix=\"\", file_location=\"Data/Sandbox/\",\n",
+ " file_suffix=\"\", useuid=True, verbose=True):\n",
+ " \"\"\"Create a file name.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " file_prefix: `string`\n",
+ " prefix of file name\n",
+ " file_location: `string`\n",
+ " path to file\n",
+ " file_suffix: `string`\n",
+ " suffix/extension of file name\n",
+ "    useuid: `bool`\n",
+ "        choose to use a unique id\n",
+ "    verbose: `bool`\n",
+ " choose to print the file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " file_final: `string`\n",
+ " filename used for saving\n",
+ " \"\"\"\n",
+ " if useuid:\n",
+ " file_uid = createFileUidTimestamp()\n",
+ " else:\n",
+ " file_uid = \"\"\n",
+ "\n",
+ " file_final = file_location + file_prefix + \"_\" + file_uid + file_suffix\n",
+ "\n",
+ " if verbose:\n",
+ " print(file_final)\n",
+ "\n",
+ " return file_final"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotArrayImageExamples(subset_train,\n",
+ " num_row=3, num_col=3,\n",
+ " object_index_start=0,\n",
+ " figsize=(7, 7),\n",
+ " save_file=False,\n",
+ " file_prefix=\"ImageExamples\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot an array of examples of images and labels\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ "    subset_train: `torch.utils.data.Dataset`\n",
+ "        training data set of (image, label) pairs\n",
+ " num_row: `int`, optional\n",
+ " number of rows to plot\n",
+ " num_col: `int`, optional\n",
+ " number of columns to plot\n",
+ " figsize: `tuple`, optional\n",
+ " size of figure\n",
+ " object_index_start: `int`, optional\n",
+ " starting index for set of images to plot\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " From: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " labels_map = {\n",
+ " 0: \"0\",\n",
+ " 1: \"1\",\n",
+ " 2: \"2\",\n",
+ " 3: \"3\",\n",
+ " 4: \"4\",\n",
+ " 5: \"5\",\n",
+ " 6: \"6\",\n",
+ " 7: \"7\",\n",
+ " 8: \"8\",\n",
+ " 9: \"9\",\n",
+ " }\n",
+ "\n",
+ " figure = plt.figure(figsize=figsize)\n",
+ "\n",
+ " for i in range(0, num_row * num_col):\n",
+ " sample_idx = object_index_start + i\n",
+ " img, label = subset_train[sample_idx]\n",
+ " figure.add_subplot(num_row, num_col, i + 1)\n",
+ " plt.title(\"label (digit): \" + labels_map[label])\n",
+ " plt.axis(\"off\")\n",
+ " plt.imshow(img.squeeze(), cmap=\"gray\")\n",
+ "\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotArrayHistogramExamples(subset_train,\n",
+ " num_row=3, num_col=3,\n",
+ " n_bins=10,\n",
+ " object_index_start=0,\n",
+ " figsize=(7, 7),\n",
+ " save_file=False,\n",
+ " file_prefix=\"HistogramExamples\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot histograms of image pixel values\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ "    subset_train: `torch.utils.data.Dataset`\n",
+ "        training data set of (image, label) pairs\n",
+ " num_row: `int`, optional\n",
+ " number of rows to plot\n",
+ " num_col: `int`, optional\n",
+ " number of columns to plot\n",
+ " n_bins: `int`, optional\n",
+ " number of bins in histogram \n",
+ " object_index_start: `int`, optional\n",
+ " starting index for set of images to plot\n",
+ " figsize: `tuple`, optional\n",
+ " size of figure\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " labels_map = {\n",
+ " 0: \"0\",\n",
+ " 1: \"1\",\n",
+ " 2: \"2\",\n",
+ " 3: \"3\",\n",
+ " 4: \"4\",\n",
+ " 5: \"5\",\n",
+ " 6: \"6\",\n",
+ " 7: \"7\",\n",
+ " 8: \"8\",\n",
+ " 9: \"9\",\n",
+ " }\n",
+ "\n",
+ " fig, axes = plt.subplots(num_row, num_col,\n",
+ " figsize=figsize)\n",
+ "\n",
+ " for i in range(0, num_row * num_col):\n",
+ " sample_idx = object_index_start + i\n",
+ " img, label = subset_train[sample_idx]\n",
+ " ax = axes[i//num_col, i%num_col]\n",
+ " img_temp = img[0, :, :]\n",
+ " img_temp = np.array(img_temp).flat\n",
+ " ax.hist(img_temp, bins=n_bins, color='gray')\n",
+ " ax.set_title(\"label (digit): \" + labels_map[label])\n",
+ " ax.set_xlabel(\"Pixel Values\")\n",
+ "\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def predict(dataloader, model, dataset_type):\n",
+ " \"\"\"Predict labels of inputs\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ "    dataloader: `torch.utils.data.DataLoader`\n",
+ "        batches of inputs and true labels\n",
+ "    model: `torch.nn.Module`\n",
+ "        model used to make the predictions\n",
+ "    dataset_type: `string`\n",
+ "        name of the data set, used in the printed summary\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " y_prob_list: `numpy.ndarray`\n",
+ " probabilities for each class for each input\n",
+ " y_choice_list: `numpy.ndarray`\n",
+ " highest-probability class for each input\n",
+ " y_true_list: `numpy.ndarray`\n",
+ " true class for each input\n",
+ " x_list: `numpy.ndarray`\n",
+ " input\n",
+ " \"\"\"\n",
+ " size = len(dataloader.dataset)\n",
+ " num_batches = len(dataloader)\n",
+ " model.eval()\n",
+ "\n",
+ " y_prob_list = []\n",
+ " y_choice_list = []\n",
+ " y_true_list = []\n",
+ " x_list = []\n",
+ "\n",
+ " i = 0\n",
+ " loss, accuracy = 0, 0\n",
+ " with torch.no_grad():\n",
+ " for inputs, labels in dataloader:\n",
+ " inputs = inputs.to(device)\n",
+ " labels = labels.to(device)\n",
+ "\n",
+ " y = model(inputs)\n",
+ " y_prob = torch.softmax(y, dim=1)\n",
+ " y_choice = (torch.max(torch.exp(y), 1)[1]).data.cpu().numpy()\n",
+ "\n",
+ "            # Weight each batch so the totals are means over the data set.\n",
+ "            loss += loss_fn(y, labels).item() / num_batches\n",
+ "            accuracy_temp = y.argmax(1) == labels\n",
+ "            accuracy_temp = accuracy_temp.type(torch.float).sum().item()\n",
+ "            accuracy += accuracy_temp / size\n",
+ "\n",
+ "            # Store results as numpy arrays on the CPU.\n",
+ "            y_prob_list.append(y_prob.detach().cpu().numpy())\n",
+ "            y_choice_list.append(y_choice)\n",
+ "            y_true_list.append(labels.detach().cpu().numpy())\n",
+ "            x_list.append(inputs.detach().cpu().numpy())\n",
+ "\n",
+ " i += 1\n",
+ "\n",
+ " y_prob_list = np.array(y_prob_list)\n",
+ " y_choice_list = np.array(y_choice_list)\n",
+ " y_true_list = np.array(y_true_list)\n",
+ " x_list = np.array(x_list)\n",
+ "\n",
+ " y_prob_list = np.squeeze(y_prob_list)\n",
+ " y_choice_list = np.squeeze(y_choice_list)\n",
+ " y_true_list = np.squeeze(y_true_list)\n",
+ " x_list = np.squeeze(x_list)\n",
+ "\n",
+ " print(f\"{dataset_type : <10} data set ...\\\n",
+ " Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {loss:>8f}\")\n",
+ "\n",
+ " return y_prob_list, y_choice_list, y_true_list, x_list"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotPredictionHistogram(y_prediction_a, y_prediction_b=None,\n",
+ " y_prediction_c=None, n_classes=None,\n",
+ " n_objects_a=None, n_colors=None,\n",
+ " title_a=None, title_b=None,\n",
+ " title_c=None, label_a=None,\n",
+ " label_b=None, label_c=None,\n",
+ " alpha=1.0, figsize=(12, 5),\n",
+ " save_file=False,\n",
+ " file_prefix=\"prediction_histogram\",\n",
+ " file_location=\"./\",\n",
+ " xlabel_plot=\"Class label\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot histogram of predicted labels\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " y_prediction_a: `numpy.ndarray`\n",
+ " y_prediction_b: `numpy.ndarray`, optional\n",
+ " y_prediction_c: `numpy.ndarray`, optional\n",
+ " n_classes: `int`, optional\n",
+ " n_objects_a: `int`, optional\n",
+ " n_colors: `int`, optional\n",
+ " title_a: `string`, optional\n",
+ " title_b: `string`, optional\n",
+ " title_c: `string`, optional\n",
+ " label_a: `string`, optional\n",
+ " label_b: `string`, optional\n",
+ " label_c: `string`, optional\n",
+ " alpha: `float`, optional\n",
+ " transparency\n",
+ " figsize: `tuple`, optional\n",
+ " figure size\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " ndim = y_prediction_a.ndim\n",
+ "\n",
+ " if ndim == 2:\n",
+ " fig, (axa, axb, axc) = plt.subplots(1, 3, figsize=figsize)\n",
+ " fig.subplots_adjust(wspace=0.35)\n",
+ " elif ndim == 1:\n",
+ " fig, ax = plt.subplots(figsize=figsize)\n",
+ "\n",
+ " shape_a = np.shape(y_prediction_a)\n",
+ "\n",
+ " if n_objects_a is None:\n",
+ " n_objects_a = shape_a[0]\n",
+ "\n",
+ " if ndim == 2:\n",
+ " if n_classes is None:\n",
+ " n_classes = shape_a[1]\n",
+ " if n_colors is None:\n",
+ " n_colors = n_classes\n",
+ " elif ndim == 1:\n",
+ " if n_colors is None:\n",
+ " n_colors = 1\n",
+ "\n",
+ " if ndim == 2:\n",
+ " colors = cm.Purples(np.linspace(0, 1, n_colors))\n",
+ " xlabel = \"Probability for Each Class\"\n",
+ "\n",
+ " axa.set_ylim(0, n_objects_a)\n",
+ " axa.set_xlabel(xlabel)\n",
+ " axa.set_title(title_a)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axa.hist(y_prediction_a[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " if y_prediction_b is not None:\n",
+ " shape_b = np.shape(y_prediction_b)\n",
+ " axb.set_ylim(0, shape_b[0])\n",
+ " axb.set_xlabel(xlabel)\n",
+ " axb.set_title(title_b)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axb.hist(y_prediction_b[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " if y_prediction_c is not None:\n",
+ " shape_c = np.shape(y_prediction_c)\n",
+ " axc.set_ylim(0, shape_c[0])\n",
+ " axc.set_xlabel(xlabel)\n",
+ " axc.set_title(title_c)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axc.hist(y_prediction_c[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " elif ndim == 1:\n",
+ " ya, xa, _ = plt.hist(y_prediction_a, alpha=alpha, color='orange',\n",
+ " label=label_a)\n",
+ " y_max_list = [max(ya)]\n",
+ "\n",
+ " if y_prediction_b is not None:\n",
+ " yb, xb, _ = plt.hist(y_prediction_b, alpha=alpha, color='green',\n",
+ " label=label_b)\n",
+ " y_max_list.append(max(yb))\n",
+ "\n",
+ " if y_prediction_c is not None:\n",
+ " yc, xc, _ = plt.hist(y_prediction_c, alpha=alpha, color='purple',\n",
+ " label=label_c)\n",
+ " y_max_list.append(max(yc))\n",
+ "\n",
+ " plt.ylim(0, np.max(y_max_list)*1.1)\n",
+ " plt.xlabel(xlabel_plot)\n",
+ "\n",
+ " plt.legend(loc='upper right')\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotTrueLabelHistogram(y_prediction_a, y_prediction_b=None,\n",
+ " y_prediction_c=None, n_classes=None,\n",
+ " n_objects_a=None, n_colors=None,\n",
+ " title_a=None, title_b=None,\n",
+ " title_c=None, label_a=None,\n",
+ " label_b=None, label_c=None,\n",
+ " alpha=1.0, figsize=(12, 5),\n",
+ " save_file=False,\n",
+ " file_prefix=\"prediction_histogram\",\n",
+ " file_location=\"./\",\n",
+ " xlabel_plot=\"Class label\",\n",
+ " file_suffix=\".png\"):\n",
+ "    \"\"\"Plot histogram of true labels\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " y_prediction_a: `numpy.ndarray`\n",
+ " y_prediction_b: `numpy.ndarray`, optional\n",
+ " y_prediction_c: `numpy.ndarray`, optional\n",
+ " n_classes: `int`, optional\n",
+ " n_objects_a: `int`, optional\n",
+ " n_colors: `int`, optional\n",
+ " title_a: `string`, optional\n",
+ " title_b: `string`, optional\n",
+ " title_c: `string`, optional\n",
+ " label_a: `string`, optional\n",
+ " label_b: `string`, optional\n",
+ " label_c: `string`, optional\n",
+ " alpha: `float`, optional\n",
+ " transparency\n",
+ " figsize: `tuple`, optional\n",
+ " figure size\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ "\n",
+ " ndim = y_prediction_a.ndim\n",
+ "\n",
+ " if ndim == 2:\n",
+ " fig, (axa, axb, axc) = plt.subplots(1, 3, figsize=figsize)\n",
+ " fig.subplots_adjust(wspace=0.35)\n",
+ " elif ndim == 1:\n",
+ " fig, ax = plt.subplots(figsize=figsize)\n",
+ "\n",
+ " shape_a = np.shape(y_prediction_a)\n",
+ "\n",
+ " if n_objects_a is None:\n",
+ " n_objects_a = shape_a[0]\n",
+ "\n",
+ " if ndim == 2:\n",
+ " if n_classes is None:\n",
+ " n_classes = shape_a[1]\n",
+ " if n_colors is None:\n",
+ " n_colors = n_classes\n",
+ " elif ndim == 1:\n",
+ " if n_colors is None:\n",
+ " n_colors = 1\n",
+ "\n",
+ " if ndim == 2:\n",
+ " colors = cm.Purples(np.linspace(0, 1, n_colors))\n",
+ " xlabel = \"Probability for Each Class\"\n",
+ "\n",
+ " axa.set_ylim(0, n_objects_a)\n",
+ " axa.set_xlabel(xlabel)\n",
+ " axa.set_title(title_a)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axa.hist(y_prediction_a[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " if y_prediction_b is not None:\n",
+ " shape_b = np.shape(y_prediction_b)\n",
+ " axb.set_ylim(0, shape_b[0])\n",
+ " axb.set_xlabel(xlabel)\n",
+ " axb.set_title(title_b)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axb.hist(y_prediction_b[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " if y_prediction_c is not None:\n",
+ " shape_c = np.shape(y_prediction_c)\n",
+ " axc.set_ylim(0, shape_c[0])\n",
+ " axc.set_xlabel(xlabel)\n",
+ " axc.set_title(title_c)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axc.hist(y_prediction_c[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " elif ndim == 1:\n",
+ " ya, xa, _ = plt.hist(y_prediction_a, alpha=alpha, color='orange',\n",
+ " label=label_a)\n",
+ " y_max_list = [max(ya)]\n",
+ "\n",
+ " if y_prediction_b is not None:\n",
+ " yb, xb, _ = plt.hist(y_prediction_b, alpha=alpha, color='green',\n",
+ " label=label_b)\n",
+ " y_max_list.append(max(yb))\n",
+ "\n",
+ " if y_prediction_c is not None:\n",
+ " yc, xc, _ = plt.hist(y_prediction_c, alpha=alpha, color='purple',\n",
+ " label=label_c)\n",
+ " y_max_list.append(max(yc))\n",
+ "\n",
+ " plt.ylim(0, np.max(y_max_list)*1.1)\n",
+ " plt.xlabel(xlabel_plot)\n",
+ "\n",
+ " plt.legend(loc='upper right')\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotLossHistory(history, figsize=(8, 5),\n",
+ " save_file=False,\n",
+ " file_prefix=\"prediction_histogram\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot loss history of the model as function of epoch\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ "    history: `dict`\n",
+ "        lists of losses at each epoch under the keys 'loss' and 'val_loss'\n",
+ " figsize: `tuple`, optional\n",
+ " figure size\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=figsize)\n",
+ "\n",
+ " loss_tra = np.array(history['loss'])\n",
+ " loss_val = np.array(history['val_loss'])\n",
+ " loss_dif = loss_val - loss_tra\n",
+ "\n",
+ " ax1.plot(loss_tra, label='Training')\n",
+ " ax1.plot(loss_val, label='Validation')\n",
+ " ax1.legend()\n",
+ "\n",
+ " ax2.plot(loss_dif, color='red', label='residual')\n",
+ " ax2.axhline(y=0, color='grey', linestyle='dashed', label='zero bias')\n",
+ " ax2.sharex(ax1)\n",
+ " ax2.legend()\n",
+ "\n",
+ " ax1.set_title('Loss History')\n",
+ " ax1.set_ylabel('Loss')\n",
+ " ax2.set_ylabel('Loss Residual')\n",
+ " ax2.set_xlabel('Epoch')\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ "\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotArrayImageConfusion(x_tra, y_tra, y_pred_tra_topchoice,\n",
+ " title_main=None, num=10,\n",
+ " save_file=False,\n",
+ " file_prefix=\"prediction_histogram\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ "    \"\"\"Plot images of example objects that are misclassified.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " x_tra: `numpy.ndarray`\n",
+ " training image data\n",
+ " y_tra: `numpy.ndarray`\n",
+ " training label data\n",
+ " y_pred_tra_topchoice: `numpy.ndarray`\n",
+ " top choice of the predicted labels\n",
+ " title_main: `string`, optional\n",
+ " title for the plot\n",
+ " num: `int`, optional\n",
+ " number of examples\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " num_row = 2\n",
+ " num_col = 5\n",
+ " images = x_tra[:num]\n",
+ " labels_true = y_tra[:num]\n",
+ " labels_pred = y_pred_tra_topchoice[:num]\n",
+ "\n",
+ " fig, axes = plt.subplots(num_row, num_col,\n",
+ " figsize=(1.5*num_col, 2*num_row))\n",
+ "\n",
+ " fig.patch.set_linewidth(5)\n",
+ " fig.patch.set_edgecolor('cornflowerblue')\n",
+ "\n",
+ " for i in range(num):\n",
+ " ax = axes[i//num_col, i%num_col]\n",
+ " ax.imshow(images[i], cmap='gray')\n",
+ " ax.set_title(r'True: {}'.format(labels_true[i]) + '\\n'\n",
+ " + 'Pred: {}'.format(labels_pred[i]))\n",
+ "\n",
+ " fig.suptitle(title_main)\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ "\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotROCMulticlassOnevsrest(y_tra, y_tes, y_pred_tes,\n",
+ " figsize=(7, 7),\n",
+ " save_file=False,\n",
+ " file_prefix=\"ImageExamples\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot the one-vs-rest ROC curve\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " y_tra: `numpy.ndarray`\n",
+ " training data true labels\n",
+ " y_tes: `numpy.ndarray`\n",
+ " testing data true labels\n",
+ "    y_pred_tes: `numpy.ndarray`\n",
+ "        testing data predicted probabilities for each class\n",
+ " figsize: `tuple`, optional\n",
+ " size of figure\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " fig, ax = plt.subplots(figsize=figsize)\n",
+ "    n_classes = len(y_pred_tes[0])\n",
+ " label_target_list = np.linspace(0, n_classes-1, num=n_classes)\n",
+ " color_list = cm.rainbow(np.linspace(0, 0.5, n_classes))\n",
+ "\n",
+ " for label_target, color in zip(label_target_list, color_list): \n",
+ " label_binarizer = LabelBinarizer().fit(y_tra)\n",
+ " y_onehot_tes = label_binarizer.transform(y_tes)\n",
+ " class_id = np.flatnonzero(label_binarizer.classes_ == label_target)[0]\n",
+ " display = RocCurveDisplay.from_predictions(\n",
+ " y_onehot_tes[:, class_id],\n",
+ " y_pred_tes[:, class_id],\n",
+ " name=f\"{int(label_target)} vs the rest\",\n",
+ " color=color,\n",
+ " ax=ax,\n",
+ "            plot_chance_level=(class_id == n_classes - 1)\n",
+ " )\n",
+ " \n",
+ " _ = display.ax_.set(\n",
+ " xlabel=\"False Positive Rate\",\n",
+ " ylabel=\"True Positive Rate\",\n",
+ " title=\"ROC: One-vs-Rest\",\n",
+ " )\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ " plt.savefig(file_final, bbox_inches='tight')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plotArrayHistogramConfusion(x_tra, y_tra, y_pred_tra_topchoice,\n",
+ " title_main=None, num=10,\n",
+ " save_file=False,\n",
+ " file_prefix=\"prediction_histogram\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot histograms of pixel values for images that are misclassified.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " x_tra: `numpy.ndarray`\n",
+ " training image data\n",
+ " y_tra: `numpy.ndarray`\n",
+ " training label data\n",
+ " y_pred_tra_topchoice: `numpy.ndarray`\n",
+ " top choice of the predicted labels\n",
+ " title_main: `string`, optional\n",
+ " title of plot\n",
+ " num: `int`, optional\n",
+ " number of examples\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " n_bins = 10\n",
+ " num_row = 2\n",
+ " num_col = 5\n",
+ " images = x_tra[:num]\n",
+ " labels_true = y_tra[:num]\n",
+ " labels_pred = y_pred_tra_topchoice[:num]\n",
+ "\n",
+ " fig, axes = plt.subplots(num_row, num_col,\n",
+ " figsize=(1.5*num_col, 2*num_row))\n",
+ "\n",
+ " fig.patch.set_linewidth(5)\n",
+ " fig.patch.set_edgecolor('cornflowerblue')\n",
+ "\n",
+ " for i in range(num):\n",
+ " ax = axes[i//num_col, i%num_col]\n",
+ " images_temp = images[i, :, :].flat\n",
+ " ax.hist(images_temp, bins=n_bins, color='gray')\n",
+ " ax.set_title(r'True: {}'.format(labels_true[i]) + '\\n'\n",
+ " + 'Pred: {}'.format(labels_pred[i]))\n",
+ " ax.set_xlabel('Pixel Values')\n",
+ "\n",
+ " fig.suptitle(title_main)\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ "\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.6. Define paths for data and figures"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Neural network training (i.e., model fitting) typically requires many numerical experiments to achieve an ideal model. To facilitate the comparison of these experiments/models, it is helpful to organize data carefully. \n",
+ "\n",
+ "Set the variable `run_label` for each training run. \n",
+ "Set paths for the model weight parameters and diagnostic figures, and\n",
+ "save these paths in the dictionary `path_dict` to facilitate passing information to plotting functions.\n",
+ "Check whether the paths exist, and if not, create them."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "run_label = \"Run000\"\n",
+ "\n",
+ "path_temp = os.getenv(\"HOME\") + '/dp02_16a_temp'\n",
+ "path_dict = {'run_label': run_label,\n",
+ " 'dir_data_model': path_temp + \"/Models/\",\n",
+ " 'dir_data_figures': path_temp + \"/Figures/\",\n",
+ " 'dir_data_data': path_temp + \"/Data/\",\n",
+ " 'file_model_prefix': \"Model\",\n",
+ " 'file_figure_prefix': \"Figure\",\n",
+ " 'file_figure_suffix': \".png\",\n",
+ " 'file_model_suffix': \".pt\"\n",
+ " }\n",
+ "del path_temp\n",
+ "\n",
+ "if not os.path.exists(path_dict['dir_data_model']):\n",
+ " os.makedirs(path_dict['dir_data_model'])\n",
+ "\n",
+ "if not os.path.exists(path_dict['dir_data_figures']):\n",
+ " os.makedirs(path_dict['dir_data_figures'])\n",
+ "\n",
+ "if not os.path.exists(path_dict['dir_data_data']):\n",
+ " os.makedirs(path_dict['dir_data_data'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Load and Prepare data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "In this section, you will encounter the following steps for preparing the dataset for which we will build a classification model:\n",
+    "\n",
+    "1. obtaining the data: defining the normalization and downloading the data;\n",
+    "2. splitting the data into different sets;\n",
+    "3. creating subsets of each of those with a smaller number of objects;\n",
+    "4. examining the raw data;\n",
+    "5. creating dataloaders for model training."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.1. Obtain the dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "`torchvision`, the image-handling companion package to `pytorch`, provides a dataset class that downloads the MNIST data to your local server at no cost. A transform from the `torchvision.transforms` module normalizes the data as they are loaded; this makes the model training more efficient."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 2.1.1. Normalize data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create a `transform` that is used to convert the data sets to tensors that can be used in `pytorch`. This also transforms all the data from the range $[0,255]$ to $[0.,1.]$."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "kMzao03zcF5i",
+ "outputId": "0401d0af-d607-436a-908f-385ffc85812c"
+ },
+ "outputs": [],
+ "source": [
+ "transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])"
+ ]
+ },
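+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick check of the rescaling described above, the following sketch applies `ToTensor` to a small synthetic array instead of an MNIST image. The names `demo_pixels` and `demo_tensor` are illustrative and are not used elsewhere in this tutorial."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import torchvision\n",
+    "\n",
+    "# Hypothetical 1x3 'image' holding the minimum, a middle, and the\n",
+    "# maximum 8-bit pixel value.\n",
+    "demo_pixels = np.array([[0, 128, 255]], dtype=np.uint8)\n",
+    "\n",
+    "# ToTensor adds a channel axis and rescales uint8 values to [0, 1].\n",
+    "demo_tensor = torchvision.transforms.ToTensor()(demo_pixels)\n",
+    "print(demo_tensor.shape)\n",
+    "print(demo_tensor.min().item(), demo_tensor.max().item())"
+   ]
+  },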
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 2.1.2. Download data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Download the data from a remote server."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%time\n",
+ "dataset = torchvision.datasets.MNIST(root=path_dict['dir_data_data'], train=True,\n",
+ " download=True, transform=transform)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.2. Split data into training, validation, and testing"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Split the data to enable a proper 'blind' analysis and optimization of an AI model.\n",
+ "\n",
+ "There are three primary data sets used in model development and optimization:\n",
+ "\n",
+    "* `Training` (with filename tag `_tra`) data is used directly by the algorithm to update the parameters of the AI model -- e.g., the weights on the connections (edges) between neurons in a neural network.\n",
+    "* `Validation` (`_val`) data is used indirectly to update the hyperparameters of the AI model -- e.g., the batch size (`batchsize`), the learning rate, or the layers in the architecture of a neural network. Each time the neural network has completed training with the training data, a human compares the diagnostics (e.g., the loss) computed on the training data with those computed on the validation data and adjusts the hyperparameters accordingly.\n",
+    "* `Test(ing)` (`_tes`) data is only used when the model is trained and validated and will no longer be updated or further trained. This is the data that you would consider to be the new data that has not been examined before -- e.g., newly observed data that has not been previously characterized."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We use most of the data for training to maximize the accuracy and generalization of the model. We use a small amount of validation data, because only a little bit is needed to check the model optimization during training. We also use only a small amount of testing data, assuming that new data sets to examine are smaller than existing training data sets."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Set the fractions for the training, validation, and test sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fraction_tra = 0.8\n",
+ "fraction_val = 0.1\n",
+ "fraction_tes = 0.1"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Split the data according to those fractions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fraction_list = [fraction_tra, fraction_val, fraction_tes]\n",
+ "\n",
+ "data_tra_full, data_val_full, data_tes_full = \\\n",
+ " torch.utils.data.random_split(dataset, fraction_list)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Print the data set sizes to check how much data you're using to optimize the model. Each problem has a unique requirement regarding the amount of training data. It typically depends on the complexity of shapes in the images and the variation across classes. Without enough data in each class, the model will likely be at least somewhat predictive but have a sub-optimal final loss (i.e., average error). If a problem requires neural networks, then it will likely require at least hundreds of images. Also, there is an interplay between the amount of data, the complexity of the images, and the complexity and size of the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"training:\", len(data_tra_full.indices))\n",
+ "print(\"validation:\", len(data_val_full.indices))\n",
+    "print(\"test:\", len(data_tes_full.indices))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "### 2.3. Create subsets of the data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Use only a subset of each data set to make the training faster for this tutorial. Consider increasing the sizes of these subsets in your exploration of the tutorial elements."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "subset_size = 5000\n",
+ "\n",
+ "subset_indices_tra = np.arange(subset_size)\n",
+ "subset_indices_val = np.arange(subset_size)\n",
+ "subset_indices_tes = np.arange(subset_size)\n",
+ "\n",
+ "data_tra = Subset(data_tra_full, subset_indices_tra)\n",
+ "data_val = Subset(data_val_full, subset_indices_val)\n",
+ "data_tes = Subset(data_tes_full, subset_indices_tes)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.4. Examine raw data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Review the raw data shapes by looking at a single datum in the training set. Each datum in the training set is an image-label pair. The image is a tensor, and the label is an integer. The image size in part determines the depth (number of layers) of the neural network. This will be discussed in a later section on model training."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sample_index = 0\n",
+ "image, label = data_tra[sample_index]\n",
+ "print(f\"The image shape is {image.shape}.\")\n",
+ "print(f\"The label for this image is {label}.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Print the image labels and their corresponding indices within a dataset object. This is useful for verifying that you understand your data set."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "{v: k for k, v in dataset.class_to_idx.items()}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot examples of the raw data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Example_Image_Array\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotArrayImageExamples(data_tra,\n",
+ " num_row=3, num_col=3,\n",
+ " save_file=False,\n",
+ " object_index_start=0,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 1: Three rows of three images, each a handwritten number in white on a black background."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot the distributions of pixel values to understand the data further. All pixel data has been normalized, and pixels have values between 0 and 1 only.\n",
+ "\n",
+ "The distribution of pixel values matches the images shown above: mostly black (values near 0) pixels, with some white (values near 1), and a few grey (values in between 0 and 1)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Example_Histogram_Array\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotArrayHistogramExamples(data_tra,\n",
+ " num_row=3, num_col=3,\n",
+ " save_file=False,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 2: Three rows of three histograms, each showing the distribution of pixel values (number of pixels of a given value) for the handwritten digit images shown in Figure 1."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2024-12-15T20:26:20.181221Z",
+ "iopub.status.busy": "2024-12-15T20:26:20.180895Z",
+ "iopub.status.idle": "2024-12-15T20:26:20.186166Z",
+ "shell.execute_reply": "2024-12-15T20:26:20.185452Z",
+ "shell.execute_reply.started": "2024-12-15T20:26:20.181199Z"
+ }
+ },
+ "source": [
+ "Examine the balance of classes in the data sets. If one class has more objects in the training set than other classes, the model will tend to be biased toward predicting that class. The classes don't have to be distributed perfectly uniformly, but a high degree of uniformity is preferable."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use the `plotPredictionHistogram` function to plot histograms of true label distributions by class."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "y_tra = []\n",
+ "y_val = []\n",
+ "y_tes = []\n",
+ "\n",
+ "for i in np.arange(subset_size):\n",
+ " image, label_tra = data_tra[i]\n",
+ " image, label_val = data_val[i]\n",
+ " image, label_tes = data_tes[i]\n",
+ " y_tra.append(label_tra)\n",
+ " y_val.append(label_val)\n",
+ " y_tes.append(label_tes)\n",
+ "\n",
+ "y_tra = np.array(y_tra)\n",
+ "y_val = np.array(y_val)\n",
+ "y_tes = np.array(y_tes)\n",
+ "\n",
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Histograms_true_class\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotPredictionHistogram(y_tra,\n",
+ " y_prediction_b=y_val,\n",
+ " y_prediction_c=y_tes,\n",
+ " label_a=\"Training Set\",\n",
+ " label_b=\"Validation Set\",\n",
+ " label_c=\"Testing Set\",\n",
+ " figsize=(12, 5),\n",
+ " alpha=0.5,\n",
+ " xlabel_plot=\"True class label\",\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "> Figure 3: The histograms of true class labels for each data set. Each histogram is for a different data set used during model training: training data, validation data, and test data. Note that these are overlapping histograms, not stacked. Please compare to Figure 4, which shows histograms of the predicted class labels."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "### 2.5. Create dataloaders for training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Set the batch size, which will be used in dataloaders and in training. To simplify data-handling, we set the batch size to the size of the subset that we select. This means that each training epoch processes a single batch and therefore performs a single weight update."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "batch_size = subset_size"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Create dataloaders for the training, validation, and test sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainloader = torch.utils.data.DataLoader(data_tra, batch_size=batch_size,\n",
+ " shuffle=False)\n",
+ "\n",
+ "validloader = torch.utils.data.DataLoader(data_val, batch_size=batch_size,\n",
+ " shuffle=False)\n",
+ "\n",
+ "testloader = torch.utils.data.DataLoader(data_tes, batch_size=batch_size,\n",
+ " shuffle=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Train the model: Convolutional Neural Network"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this section, you will encounter the following major items for training a neural network:\n",
+ "\n",
+ "1. basics of `pytorch` model-building;\n",
+ "2. defining the neural network model -- setting model hyperparameters, a glossary of terms of model architecture elements, glossary of terms for a network model class, creating a model class;\n",
+ "3. training the model."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3.1. Basics of pytorch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2024-12-15T21:16:11.080821Z",
+ "iopub.status.busy": "2024-12-15T21:16:11.079938Z",
+ "iopub.status.idle": "2024-12-15T21:16:11.085835Z",
+ "shell.execute_reply": "2024-12-15T21:16:11.085097Z",
+ "shell.execute_reply.started": "2024-12-15T21:16:11.080794Z"
+ }
+ },
+ "source": [
+    "In `pytorch`, neural network models are defined as classes. This is slightly different from typical `tensorflow` usage, in which people build a `sequential` model or use a pre-built `model` class and add layers to that model.\n",
+    "\n",
+    "The other major difference between `pytorch` and `tensorflow` is the shape of the tensors and the inputs for each layer.\n",
+    "\n",
+    "In particular, in `pytorch` one has to explicitly match the output size of one layer to the input size of the next layer. Often, this can be done with a calculation based on the layer settings. Otherwise, one can pass a sample through the layers and inspect the output shape."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3.2. Define the neural network model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.1. Define model hyperparameters"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Set the seed for the random initial model weight values."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Zsc_pqZtWftJ",
+ "outputId": "8238aa09-2afd-47ef-c93e-3910b342fe38"
+ },
+ "outputs": [],
+ "source": [
+ "seed = 1729\n",
+ "new = torch.manual_seed(seed)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Define the loss function. We choose cross-entropy, which is the standard loss function for classification of discrete labels."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loss_fn = nn.CrossEntropyLoss()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.2. Glossary: model architecture elements"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Below is a list of important terms for elements of a neural network architecture.\n",
+ "\n",
+    " * `activation function`: a function within a neuron that takes the combined inputs from a previous layer and produces an output -- usually, this function is non-linear. Examples include \"sigmoid,\" \"softmax,\" and \"ReLU\" (rectified linear unit).\n",
+    " * `sigmoid`: an activation function that takes points from the Real line and maps them to the range (0,1).\n",
+    " * `softmax`: an activation function that takes a vector of Real numbers and maps it to a vector of values in the range [0,1] that sum to 1. This can be used to obtain a 'probability score.'\n",
+    " * `ReLU (rectified linear unit)`: an activation function that returns 0 for negative input values and returns the input value itself for positive input values; it is non-smooth at the threshold of 0.\n",
+    " * `weight`: a multiplicative factor applied to an input of a neuron before the activation function.\n",
+    " * `bias`: an additive term combined with the weighted inputs of a neuron before the activation function.\n",
+    " * `layer`: one set of nodes/neurons that receive input data simultaneously.\n",
+    " * `linear (Dense) layer`: a layer in which every neuron is connected to every element of its input. It is typically applied after \"flattening\" a higher-dimensional data array, like an image, and it computes a weighted sum plus a bias -- as opposed to a convolutional layer, which performs a convolution operation.\n",
+    " * `convolutional layer`: a layer that applies a convolution operation to an input sample."
+ ]
+ },
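+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To see the output ranges of the activation functions listed above, the following sketch evaluates them on three illustrative input values. The `demo_` variables are hypothetical names that are not used elsewhere in this tutorial."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn.functional as F\n",
+    "\n",
+    "# Illustrative inputs: one negative value, zero, and one positive value.\n",
+    "demo_values = torch.tensor([-2.0, 0.0, 3.0])\n",
+    "\n",
+    "demo_sigmoid = torch.sigmoid(demo_values)     # each value lies in (0, 1)\n",
+    "demo_softmax = F.softmax(demo_values, dim=0)  # values sum to 1\n",
+    "demo_relu = F.relu(demo_values)               # negative values become 0\n",
+    "\n",
+    "print('sigmoid:', demo_sigmoid)\n",
+    "print('softmax:', demo_softmax, 'sum:', demo_softmax.sum().item())\n",
+    "print('relu:', demo_relu)"
+   ]
+  },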
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.3. Glossary: class that defines a network model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Typically, there are two functions within this class.\n",
+ "\n",
+ "The constructor function (`__init__`) defines the available layers of the network. These layers require specific settings related to the data set shapes. \n",
+ "* `Conv2d`: defines a two-dimensional convolutional layer. Four inputs are considered here:\n",
+ " * `in_channels` (required): the number of input channels.\n",
+ " * `out_channels` (required): the number of output channels.\n",
+ " * `kernel_size` (required): the size on one dimension of the convolutional kernel.\n",
+    "    * `stride` (optional): the stride of the convolution.\n",
+    "* `Dropout`: defines a dropout layer. Its input is the probability with which each element of the layer's input is set to zero during training. Typically, this is used when the model is overfitting -- i.e., when the validation loss as a function of epoch is consistently higher than the training loss as a function of epoch (Please see Figure 6 as an example).\n",
+    "* `Linear`: defines a linear layer. Two inputs are considered here:\n",
+    "    * `in_features` (required): the size of the sample input to the layer.\n",
+    "    * `out_features` (required): the size of the sample output from the layer.\n",
+    "\n",
+    "The function `forward` defines the order of operations during a forward pass of the model. During training, the `forward` function is applied to the input data to make predictions. After each round of predictions (each batch), the loss function quantifies the difference between the true labels and the predicted labels, and the optimizer uses the gradient of that loss to update the model weights.\n",
+    "\n",
+    "The `forward` function uses the layers defined in the constructor, as well as other layers that don't require inputs that depend on the data. These layers are defined in the `torch.nn.functional` submodule, which contains predefined functions for layers that operate directly on the data and don't require an instance of that layer. The functions used in this tutorial are:\n",
+    "* `relu`: applies the `ReLU` activation function. It requires one input, the sample from the previous layer.\n",
+    "* `max_pool2d`: the max pooling function is applied to the sample from the previous layer. It requires the input sample and the size of the kernel of the pooling.\n",
+    "* `flatten`: reshapes each sample into a one-dimensional tensor. In the model below, it is called as `torch.flatten`."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.4. Define the object class that represents the model "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class ConvNet(nn.Module):\n",
+ " def __init__(self):\n",
+ " super(ConvNet, self).__init__()\n",
+ " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
+ " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
+ " self.dropout1 = nn.Dropout(0.25)\n",
+ " self.dropout2 = nn.Dropout(0.5)\n",
+ " self.fc1 = nn.Linear(9216, 128)\n",
+ " self.fc2 = nn.Linear(128, 10)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = self.conv1(x)\n",
+ " x = F.relu(x)\n",
+ " x = self.conv2(x)\n",
+ " x = F.relu(x)\n",
+ " x = F.max_pool2d(x, 2)\n",
+ " x = self.dropout1(x)\n",
+ " x = torch.flatten(x, 1)\n",
+ " x = self.fc1(x)\n",
+ " x = F.relu(x)\n",
+ " x = self.dropout2(x)\n",
+ " x = self.fc2(x)\n",
+ " output = x\n",
+ " return output"
+ ]
+ },
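+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The value `9216` passed to `fc1` above is the number of features that remain after the two convolutions and the pooling step. A minimal sketch of that calculation follows, assuming the 28-pixel side of an MNIST image and no padding (the `Conv2d` default). The helper `conv_output_side` is an illustrative name, not a `pytorch` function."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def conv_output_side(side, kernel_size, stride=1, padding=0):\n",
+    "    \"\"\"Side length of a square image after a convolution (illustrative).\"\"\"\n",
+    "    return (side + 2 * padding - kernel_size) // stride + 1\n",
+    "\n",
+    "\n",
+    "demo_side = 28                               # MNIST image side in pixels\n",
+    "demo_side = conv_output_side(demo_side, 3)   # after conv1: 26\n",
+    "demo_side = conv_output_side(demo_side, 3)   # after conv2: 24\n",
+    "demo_side = demo_side // 2                   # after max_pool2d(x, 2): 12\n",
+    "n_features_flat = 64 * demo_side * demo_side  # 64 channels from conv2\n",
+    "print(n_features_flat)"
+   ]
+  },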
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Instantiate a neural network model object."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = ConvNet()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Define the device. This is important especially when a GPU is available. This is necessary because the model and the data get moved to that device.\n",
+ "\n",
+ "In our case, the device is a CPU because that's what is currently available on the RSP."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "device = torch.device('cuda') if torch.cuda.is_available()\\\n",
+ " else torch.device('cpu')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Put the model on the device where the computations will be performed.\n",
+ "\n",
+    "Because `model.to(device)` returns the model, the notebook also displays a summary of the network architecture. Examine the layers and their shapes, from which the number of parameters in each layer follows. Too few parameters may prevent the model from being flexible enough to model the data. Too many parameters could lead to overfitting of the model (e.g., 'memorizing' the training data) and an unnecessarily high computational cost. Typically, models range from thousands to millions of parameters. There is an interplay between model size (number of parameters), the amount of training data, and the complexity of the images in the training data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"Model Summary:\\n\")\n",
+ "model.to(device)"
+ ]
+ },
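+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The printed summary lists the layers but not the number of parameters. One way to count them is sketched below on a stand-alone layer. The names `count_parameters` and `demo_layer` are illustrative, and the same function can be applied to `model`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch.nn as nn\n",
+    "\n",
+    "\n",
+    "def count_parameters(module):\n",
+    "    \"\"\"Count the trainable parameters of a module (illustrative).\"\"\"\n",
+    "    return sum(p.numel() for p in module.parameters() if p.requires_grad)\n",
+    "\n",
+    "\n",
+    "# A stand-alone layer with the same shape as `fc2`:\n",
+    "# 128 * 10 weights plus 10 biases.\n",
+    "demo_layer = nn.Linear(128, 10)\n",
+    "print(count_parameters(demo_layer))"
+   ]
+  },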
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.5. Glossary: hyperparameters for the training schedule"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "* `learning rate`: A multiplicative factor defining the amount that the model weights will change in response to the gradient of the error (loss) in the model prediction. The lower the learning rate, the smaller the change to the weights, and usually the longer it will take to train the model. However, if the learning rate is too high, the weights can change too quickly: then, the loss can fluctuate significantly and not decrease quickly.\n",
+    "* `momentum`: A multiplicative factor on the aggregate of previous gradients. This aggregate term is combined with the weight gradient term to define the total change in the value of the weights. The smaller the momentum, the lesser the influence of the aggregated gradients, and usually the longer it will take to train the model.\n",
+    "* `optimizer`: The method/algorithm used to update the network weights. Stochastic Gradient Descent (SGD) is one of the most commonly used methods.\n",
+ "* `epoch`: One loop of training the model. Each loop includes the entire data set (all the batches) once, and it includes at least one round of weight updates.\n",
+ "* `n_epochs`: The number of epochs to train the network."
+ ]
+ },
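+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To make the roles of the learning rate and the momentum concrete, the sketch below applies the momentum update by hand to a single weight $w$ with loss $w^2$. It assumes the basic form of the update used by `optim.SGD` (no dampening, weight decay, or Nesterov correction). All names are illustrative and are not used elsewhere in this tutorial."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def sgd_momentum_step(weight, velocity, gradient, learning_rate, momentum):\n",
+    "    \"\"\"One SGD-with-momentum update for a single weight (illustrative).\"\"\"\n",
+    "    velocity = momentum * velocity + gradient  # aggregate of gradients\n",
+    "    weight = weight - learning_rate * velocity\n",
+    "    return weight, velocity\n",
+    "\n",
+    "\n",
+    "demo_weight, demo_velocity = 1.0, 0.0\n",
+    "demo_loss_history = []\n",
+    "for demo_step in range(50):\n",
+    "    demo_gradient = 2.0 * demo_weight  # derivative of the loss w**2\n",
+    "    demo_weight, demo_velocity = sgd_momentum_step(\n",
+    "        demo_weight, demo_velocity, demo_gradient,\n",
+    "        learning_rate=0.01, momentum=0.9)\n",
+    "    demo_loss_history.append(demo_weight ** 2)\n",
+    "\n",
+    "print(demo_loss_history[0], demo_loss_history[-1])"
+   ]
+  },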
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.6. Assign hyperparameters for the training schedule"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "learning_rate = 0.01\n",
+ "momentum = 0.9\n",
+ "optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)\n",
+ "n_epochs = 50"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3.3. Train the model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Set the model into a training context. This ensures that layers like \"batchnorm\" and \"dropout\" will be activated. In contrast, when the model is set to an evaluation context (later in this tutorial), those layers will be deactivated."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "model.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use a loop over epochs to incrementally optimize the network weight parameters. \n",
+ "\n",
+    "Use lists to track the loss values on the training data, the loss values on the validation data, and the accuracy values on the validation data. Define the \"history\" dictionary to hold those lists. We will visualize these later to study the fitting efficacy and generalization capacity of the network. The test data are held back until training is complete.\n",
+    "\n",
+    "For this tutorial, 50 epochs on the medium-memory Rubin server required ~9 min of wall time and ~15 min of CPU time."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "zB1OY3o8VWF_",
+    "outputId": "7bdd7e85-cef2-47af-f16b-5f8bbd66f67f"
+   },
+   "outputs": [],
+   "source": [
+    "%%time\n",
+    "time_start = time.time()\n",
+    "\n",
+    "loss_train_list = []\n",
+    "loss_test_list = []\n",
+    "accuracy_test_list = []\n",
+    "\n",
+    "for epoch in np.arange(n_epochs):\n",
+    "\n",
+    "    # Training pass: dropout is active and the weights are updated.\n",
+    "    model.train()\n",
+    "    loss_train = 0\n",
+    "    for inputs, labels in trainloader:\n",
+    "        inputs = inputs.to(device)\n",
+    "        labels = labels.to(device)\n",
+    "        y_pred = model(inputs)\n",
+    "        loss = loss_fn(y_pred, labels)\n",
+    "        optimizer.zero_grad()\n",
+    "        loss.backward()\n",
+    "        optimizer.step()\n",
+    "        loss_train += loss.item()\n",
+    "\n",
+    "    # The loss function returns a batch mean, so average over batches.\n",
+    "    loss_train /= len(trainloader)\n",
+    "\n",
+    "    # Validation pass: dropout is inactive and no gradients are needed.\n",
+    "    # The `_test` variable names are kept for the history dictionary.\n",
+    "    model.eval()\n",
+    "    accuracy_test = 0\n",
+    "    loss_test = 0\n",
+    "    count_test = 0\n",
+    "    with torch.no_grad():\n",
+    "        for inputs, labels in validloader:\n",
+    "            inputs = inputs.to(device)\n",
+    "            labels = labels.to(device)\n",
+    "            y_pred = model(inputs)\n",
+    "            loss = loss_fn(y_pred, labels)\n",
+    "            accuracy_test += (torch.argmax(y_pred, 1) == labels).float().sum()\n",
+    "            loss_test += loss.item()\n",
+    "            count_test += len(labels)\n",
+    "\n",
+    "    accuracy_test /= count_test\n",
+    "    loss_test /= len(validloader)\n",
+    "\n",
+    "    loss_train_list.append(loss_train)\n",
+    "    loss_test_list.append(loss_test)\n",
+    "    accuracy_test_list.append(accuracy_test)\n",
+    "\n",
+    "    output1 = f\"Epoch ({epoch:2d}): accuracy ({accuracy_test*100:.2f} %), \"\n",
+    "    output2 = f\"train loss ({loss_train:.4f}), valid loss ({loss_test:.4f})\"\n",
+    "    output = output1 + output2\n",
+    "    print(output)\n",
+ "\n",
+ "time_end = time.time()\n",
+ "\n",
+ "time_difference = time_end - time_start\n",
+ "\n",
+ "history = {\"loss\": loss_train_list,\n",
+ " \"val_loss\": loss_test_list,\n",
+ " \"accuracy_test\": accuracy_test_list}\n",
+ "\n",
+ "print(f\"Total training time: {time_difference/60.:2.4f} min\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
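+    "Save the model to file: \"pt\" is the common suffix used for pytorch model files. Because `model.state_dict()` is passed to `torch.save`, this saves the weights of the model but not its architecture; the `ConvNet` class definition is needed to load them again."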
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_model_prefix'] + \"_\" + path_dict['run_label']\n",
+ "file_name_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_model'],\n",
+ " file_suffix=path_dict['file_model_suffix'],\n",
+ " useuid=True,\n",
+ " verbose=True)\n",
+ "\n",
+ "torch.save(model.state_dict(), file_name_final)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Load the saved weights from the \"pt\" file into the model object."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+    "model.load_state_dict(torch.load(file_name_final, weights_only=True))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-01-14T23:05:09.388295Z",
+ "iopub.status.busy": "2025-01-14T23:05:09.387717Z",
+ "iopub.status.idle": "2025-01-14T23:05:09.481703Z",
+ "shell.execute_reply": "2025-01-14T23:05:09.481106Z",
+ "shell.execute_reply.started": "2025-01-14T23:05:09.388255Z"
+ }
+ },
+ "source": [
+ "Set the model to evaluation mode so that the \"batchnorm\" and \"dropout\" layers are deactivated. This is necessary for consistent inference."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "model.eval()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Diagnosing the Results of Model Training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this section, you will encounter the following items:\n",
+ "\n",
+ "1. a glossary for diagnostics for model evaluation;\n",
+ "2. how to make predictions with the trained model;\n",
+ "3. basic bulk diagnostic: loss history;\n",
+ "4. basic bulk diagnostic: confusion matrix (CM);\n",
+    "5. basic bulk diagnostic: receiver operating characteristic (ROC) curve;\n",
+ "6. investigating predictions of individual images."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.1. Glossary: diagnostics and model evaluation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use the following diagnostics to assess the status of the network optimization and efficacy. The [scikit-learn page on metrics and scoring](https://scikit-learn.org/stable/modules/model_evaluation.html) provides a good in-depth reference for these terms.\n",
+ "\n",
+ "Model Predictions:\n",
+ " * `Classification threshold`: The user-chosen value $[0,1]$ that sets the threshold for a positive classification. \n",
+ " * `Probability score`: The output from the classifier neural network. One typically uses the `softmax` activation function in the last layer of the NN to provide an output in the range $[0,1]$.\n",
+ " * `Classification score`: The predicted class label is the class that received the highest probability score.\n",
+ "\n",
+ "Metrics:\n",
+ " * `Loss`: A function of the difference between the true labels and predicted labels.\n",
+ " * `Accuracy`: A rough indicator of model training progress/convergence for balanced datasets. For model performance, use only in combination with other metrics. Avoid this metric when you have unbalanced training datasets. Consider using another metric.\n",
+ " * `True Positive Rate (TPR; \"Recall\")`: Use when false negatives are more expensive than false positives.\n",
+ " * `False Positive Rate (FPR)`: Use when false positives are more expensive than false negatives.\n",
+ " * `Precision`: Use when positive predictions need to be accurate.\n",
+ "\n",
+ "The `Generalization Error` (GE) is the difference in loss when the model is applied to training data versus when applied to validation and test data.\n",
+ "\n",
+ "The `Confusion Matrix` (CM) is a visual representation of the classification accuracy. Each row is the set of predictions for each true value (with one true value per column). Values along the diagonal indicate true positives (correct predictions). Values below the diagonal indicate false positives. Values above the diagonal indicate false negatives. The optimal scenario is one in which the off-diagonal values are all zero.\n",
+ "\n",
+ "The `Receiver Operating Characteristic (ROC) Curve` compares the true positive rate (y-axis) with the false positive rate (x-axis): it shows the true positive rate achieved at each false positive rate. Each point on the curve corresponds to a distinct choice of the classification threshold on the probability score in $[0,1]$; the choice of threshold determines which objects are classified as positive. The optimal scenario is one in which the ROC curve is constant at a true positive rate $=1$. If the curve lies along the diagonal (lower left to upper right), the model performance is equivalent to a random (50-50) guess."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.2. Predict classifications with the trained model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Predict classification probabilities on the training, validation, and test sets. Produce both the probabilities of each digit and the top choice for each prediction."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "y_prob_tes, y_choice_tes, y_true_tes, x_tes = predict(testloader, model,\n",
+ " \"test\")\n",
+ "y_prob_val, y_choice_val, y_true_val, x_val = predict(validloader, model,\n",
+ " \"validation\")\n",
+ "y_prob_tra, y_choice_tra, y_true_tra, x_tra = predict(trainloader, model,\n",
+ " \"training\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Print the shapes and verify that the shape of `y_prob_tes` matches the length of the input data `x_tes` and the number of classes (as in Section 2.4)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(f\"The input data has the shape {np.shape(x_tes)}: there are \"\n",
+ "      f\"{subset_size} images, and each image has 28 pixels on a side.\")\n",
+ "print(\"The predicted probability score array has the shape \"\n",
+ "      f\"{np.shape(y_prob_tes)}: there are {subset_size} predictions, \"\n",
+ "      \"with 10 probability scores predicted for each input image.\")\n",
+ "print(f\"The predicted classes array has the shape {np.shape(y_choice_tes)}: \"\n",
+ "      \"there is one top choice (highest probability score) for each prediction.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use the `plotPredictionHistogram` function to plot histograms of prediction distributions by class."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Histograms_top_choice\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotPredictionHistogram(y_choice_tra,\n",
+ " y_prediction_b=y_choice_val,\n",
+ " y_prediction_c=y_choice_tes,\n",
+ " label_a=\"Training Set\",\n",
+ " label_b=\"Validation Set\",\n",
+ " label_c=\"Testing Set\",\n",
+ " figsize=(12, 5),\n",
+ " alpha=0.5,\n",
+ " xlabel_plot=\"Predicted class label\",\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 4: Histograms of the number of images for which the top-choice class was each number 0 through 9. Each histogram is for a different data set used during model training --- training data, validation data, and test data. Note that these are overlapping histograms, not stacked. Please compare to Figure 3, which shows the distributions of true class labels for each data set. Consider which classes are represented differently between the true labels and the predicted labels."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Histograms_class_probabilities\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotPredictionHistogram(y_prob_tra,\n",
+ " y_prediction_b=y_prob_val,\n",
+ " y_prediction_c=y_prob_tes,\n",
+ " title_a='Training Set',\n",
+ " title_b='Validation Set',\n",
+ " title_c='Testing Set',\n",
+ " figsize=(15, 4),\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 5: Histograms of the number of images (y-axis) that had a probability (x-axis) of being each class 0 through 9 (light to dark shades)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In both Figures 3 and 4, the histograms show very similar shapes across the classification categories.\n",
+ "This is a good sign because it indicates the model is not heavily biased toward a particular class."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.3. Loss History: History of Loss and Accuracy during Training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The primary task in optimizing a network is to minimize the Generalization Error. Plot the loss history for the validation and training sets. We reserve the test set for a 'blind' analysis."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"LossHistory\"\\\n",
+ " + \"_\"\\\n",
+ " + path_dict['run_label']\n",
+ "\n",
+ "plotLossHistory(history,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 6: The loss histories for the training and validation data sets and the loss residual history between the training and the validation set. Top panel: the loss as a function of epoch for the training and validation sets decreases with time, as it should while the model fit improves. There is a slight difference between the validation and the training losses. Bottom panel: the loss residual (validation loss minus training loss) shows a minimum near epoch 34, indicating that the model fitting diverged between the two data sets but that this was rectified in later epochs. For epochs 10 to 44, the difference between the validation and training losses indicates that the validation data may not have been representative of the training data."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Another pattern you may encounter is a validation loss that is consistently higher than the training loss. This typically indicates overfitting, which occurs when the model is complex enough to fit (sometimes the term 'memorize' is used) all the details of the training data but does not generalize well enough to also fit the validation data."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2024-10-26T18:50:41.751869Z",
+ "iopub.status.busy": "2024-10-26T18:50:41.751239Z",
+ "iopub.status.idle": "2024-10-26T18:50:41.757103Z",
+ "shell.execute_reply": "2024-10-26T18:50:41.756503Z",
+ "shell.execute_reply.started": "2024-10-26T18:50:41.751843Z"
+ }
+ },
+ "source": [
+ "### 4.4. Confusion Matrix: Bias in Trained Model?"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compute and plot the confusion matrices for the training, validation, and test samples (left, middle, right)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')\n",
+ "figsize = (15, 3)\n",
+ "linewidths = 0.01\n",
+ "linecolor = 'white'\n",
+ "ylabel = \"True Label\"\n",
+ "xlabel = \"Predicted Label\"\n",
+ "\n",
+ "cm_tra = confusion_matrix(y_true_tra, y_choice_tra)\n",
+ "cm_val = confusion_matrix(y_true_val, y_choice_val)\n",
+ "cm_tes = confusion_matrix(y_true_tes, y_choice_tes)\n",
+ "\n",
+ "df_cm_tra = pd.DataFrame(cm_tra / np.sum(cm_tra, axis=1)[:, None],\n",
+ " index=[i for i in classes],\n",
+ " columns=[i for i in classes])\n",
+ "df_cm_val = pd.DataFrame(cm_val / np.sum(cm_val, axis=1)[:, None],\n",
+ " index=[i for i in classes],\n",
+ " columns=[i for i in classes])\n",
+ "df_cm_tes = pd.DataFrame(cm_tes / np.sum(cm_tes, axis=1)[:, None],\n",
+ " index=[i for i in classes],\n",
+ " columns=[i for i in classes])\n",
+ "\n",
+ "fig, (axa, axb, axc) = plt.subplots(1, 3, figsize=figsize)\n",
+ "fig.subplots_adjust(wspace=0.5)\n",
+ "\n",
+ "ax1 = sns.heatmap(df_cm_tra, annot=False, linewidths=linewidths,\n",
+ " linecolor=linecolor, square=True, ax=axa)\n",
+ "_ = ax1.set(xlabel=xlabel, ylabel=ylabel, title=\"Training Data\")\n",
+ "\n",
+ "ax2 = sns.heatmap(df_cm_val, annot=False, linewidths=linewidths,\n",
+ " linecolor=linecolor, square=True, ax=axb)\n",
+ "_ = ax2.set(xlabel=xlabel, ylabel=ylabel, title=\"Validation Data\")\n",
+ "\n",
+ "ax3 = sns.heatmap(df_cm_tes, annot=False, linewidths=linewidths,\n",
+ " linecolor=linecolor, square=True, ax=axc)\n",
+ "_ = ax3.set(xlabel=xlabel, ylabel=ylabel, title=\"Test Data\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 7: Confusion matrices for the training, validation, and test data sets (left to right). Each cell in the matrix shows the fraction (see the color bars) of the objects with a given true label that were predicted to have a given label. The diagonal represents images that were correctly classified. Every off-diagonal cell represents a misclassification: it is a false negative for the true class and a false positive for the predicted class. Consider some examples. First, the class \"0\" is almost always predicted to be \"0\", as shown by the lighter color in the top-most, left-most cell; all the other cells for that true label are completely dark. Second, consider the cell that represents a prediction of \"4\" when the true label is \"9\": that cell is not completely dark, because a \"4\" has a similar morphology (shape) to a \"9\"."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.5. Receiver Operating Characteristic (ROC) Curve: Trade-offs between Completeness and Purity"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The Receiver Operating Characteristic (ROC) Curve is a plot of the True Positive Rate (TPR; y-axis) versus the False Positive Rate (FPR; x-axis).\n",
+ "\n",
+ "As the classification threshold on the probability (canonically, 0.5) is varied, the balance of true positives and false positives changes."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ROCCurve\"\\\n",
+ " + \"_\"\\\n",
+ " + path_dict['run_label']\n",
+ "\n",
+ "plotROCMulticlassOnevsrest(y_tra, y_tes, y_prob_tes,\n",
+ " save_file=True,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 8: ROC curves for the classification of each digit (class) against the remaining nine digits (classes). All of the curves show a very high area under the curve (AUC), which means near-perfect classification. The black dashed line shows the ROC curve for a random (50-50) guess."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vvAqrZwjVYBt"
+ },
+ "source": [
+ "### 4.6. Investigating predictions in detail"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 4.6.1. Glossary: classification metrics"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-01-26T21:19:35.897282Z",
+ "iopub.status.busy": "2025-01-26T21:19:35.896488Z",
+ "iopub.status.idle": "2025-01-26T21:19:36.705672Z",
+ "shell.execute_reply": "2025-01-26T21:19:36.705022Z",
+ "shell.execute_reply.started": "2025-01-26T21:19:35.897242Z"
+ }
+ },
+ "source": [
+ "Consider the example in which the class of interest is \"2\". Then, we define the following outcomes.\n",
+ "\n",
+ "* `True Positive (TP)`: correctly classified input digit image --- e.g., a \"2\" classified as \"2\".\n",
+ "* `False Positive (FP)`: another digit classified as \"2\".\n",
+ "* `True Negative (TN)`: another digit classified as another digit.\n",
+ "* `False Negative (FN)`: a \"2\" classified as something other than \"2\"."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 4.6.2. Explore the classification of the training data for an example class value"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Investigate the case in which the true digit label is \"2\"."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class_value = 2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Find all objects that have that class value. Obtain indices for the TPs, FPs, TNs, and FNs. Create subsets of the data according to those indices."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ind_class_tp_tra = np.where((y_true_tra == class_value)\n",
+ " & (y_choice_tra == class_value))[0]\n",
+ "\n",
+ "ind_class_fp_tra = np.where((y_true_tra != class_value)\n",
+ " & (y_choice_tra == class_value))[0]\n",
+ "\n",
+ "ind_class_tn_tra = np.where((y_true_tra != class_value)\n",
+ " & (y_choice_tra != class_value))[0]\n",
+ "\n",
+ "ind_class_fn_tra = np.where((y_true_tra == class_value)\n",
+ " & (y_choice_tra != class_value))[0]\n",
+ "\n",
+ "x_tra_tp = x_tra[ind_class_tp_tra]\n",
+ "y_true_tra_tp = y_true_tra[ind_class_tp_tra]\n",
+ "y_choice_tra_tp = y_choice_tra[ind_class_tp_tra]\n",
+ "\n",
+ "x_tra_fp = x_tra[ind_class_fp_tra]\n",
+ "y_true_tra_fp = y_true_tra[ind_class_fp_tra]\n",
+ "y_choice_tra_fp = y_choice_tra[ind_class_fp_tra]\n",
+ "\n",
+ "x_tra_tn = x_tra[ind_class_tn_tra]\n",
+ "y_true_tra_tn = y_true_tra[ind_class_tn_tra]\n",
+ "y_choice_tra_tn = y_choice_tra[ind_class_tn_tra]\n",
+ "\n",
+ "x_tra_fn = x_tra[ind_class_fn_tra]\n",
+ "y_true_tra_fn = y_true_tra[ind_class_fn_tra]\n",
+ "y_choice_tra_fn = y_choice_tra[ind_class_fn_tra]\n",
+ "\n",
+ "n_tp = len(ind_class_tp_tra)\n",
+ "n_fp = len(ind_class_fp_tra)\n",
+ "n_tn = len(ind_class_tn_tra)\n",
+ "n_fn = len(ind_class_fn_tra)\n",
+ "\n",
+ "print(f\"TP count: {n_tp:4d}\")\n",
+ "print(f\"FP count: {n_fp:4d}\")\n",
+ "print(f\"TN count: {n_tn:4d}\")\n",
+ "print(f\"FN count: {n_fn:4d}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if n_tp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_TruePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_tp,\n",
+ " y_true_tra_tp,\n",
+ " y_choice_tra_tp,\n",
+ " title_main=\"True Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_FalsePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_fp,\n",
+ " y_true_tra_fp,\n",
+ " y_choice_tra_fp,\n",
+ " title_main=\"False Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_tn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_TrueNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_tn,\n",
+ " y_true_tra_tn,\n",
+ " y_choice_tra_tn,\n",
+ " title_main=\"True Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_FalseNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_fn,\n",
+ " y_true_tra_fn,\n",
+ " y_choice_tra_fn,\n",
+ " title_main=\"False Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 9: Four panels of 10 images each, representing true positives (top), false positives (second), true negatives (third), and false negatives (bottom), for classification category 2.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot histograms of the image pixels of true positives, false positives, true negatives, and false negatives."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if n_tp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleImages_TruePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_tp,\n",
+ " y_true_tra_tp,\n",
+ " y_choice_tra_tp,\n",
+ " title_main=\"True Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleImages_FalsePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_fp,\n",
+ " y_true_tra_fp,\n",
+ " y_choice_tra_fp,\n",
+ " title_main=\"False Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_tn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleImages_TrueNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_tn,\n",
+ " y_true_tra_tn,\n",
+ " y_choice_tra_tn,\n",
+ " title_main=\"True Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleImages_FalseNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_fn,\n",
+ " y_true_tra_fn,\n",
+ " y_choice_tra_fn,\n",
+ " title_main=\"False Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 10: Histograms of the pixel flux values for the images shown in Figure 9. Here, it is difficult to infer reasons for network classification errors like false positives and false negatives. If a particular digit is expected to have a particular distribution of pixel brightnesses, but the histogram for the predicted digit has a different distribution, that may help you identify a pattern. This may be a more useful diagnostic for physics-related data, like galaxy light profiles."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 4.6.3. Investigate morphological features of images with feature maps"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The last diagnostic of this tutorial is the `feature map`, which shows how one layer processes the input image after the weights have been set during model training. The morphological elements in the feature maps are the image regions that are highlighted by the earliest layers of convolution and activation functions.\n",
+ "\n",
+ "The feature map is obtained by passing an input image through one layer of the trained neural network model. More specifically, we first choose an input datum (image) by setting the array index of interest. Then, we make the image a `tensor` with the appropriate number of dimensions using `unsqueeze`, and we transfer it to the `device`. Next, we pass it through the `conv1` layer, which was defined in the section above where we defined the `model`. Finally, we plot one feature map for each of the 32 convolutional kernels in the `conv1` layer of the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "index_input_image = 1\n",
+ "number_conv_kernels = 32\n",
+ "\n",
+ "input_image = dataset.data[index_input_image].type(torch.float32)\n",
+ "input_image = input_image.clone().detach()\n",
+ "input_image = input_image.unsqueeze(0)\n",
+ "input_image = input_image.to(device)\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " feature_maps = model.conv1(input_image).cpu()\n",
+ "\n",
+ "fig, ax = plt.subplots(4, 8, sharex=True, sharey=True, figsize=(16, 8))\n",
+ "\n",
+ "for i in range(number_conv_kernels):\n",
+ "    row, col = i // 8, i % 8\n",
+ "    ax[row][col].imshow(feature_maps[i])\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 11: `Feature maps` of one input from the `conv1` layer of the `model`. The units on the x- and y-axes are the pixel indices. There are 32 images because there are 32 convolutional kernels in the `conv1` layer.\n",
+ ">\n",
+ "> The feature maps derived from the trained model show which morphological features, e.g., lines and edges, are favored by the model. When the model is accurate, these features and feature maps will closely reflect the input image. In this example, the handwritten digit is \"0,\" and the feature maps all appear as \"0's\". Some maps have more clearly defined edges and \"0\"-like features than others, because each convolutional kernel has its own distinct set of weights."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Exercises for the Learner"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Each time you train a new model, re-run all the diagnostic plots.\n",
+ "\n",
+ "1. How do the loss and accuracy histories change when batch size is small or large? Why?\n",
+ "2. Does the NN take more or less time (more or fewer epochs) to converge if the input image data are normalized or not normalized? Why?\n",
+ "3. How does the size of the training set affect the model's accuracy and loss, keeping the number of epochs the same? Why?\n",
+ "4. How does the random seed for the weight initialization affect the model's accuracy and loss, keeping the number of epochs the same?\n",
+ "5. Increase and then decrease the number of weights in the NN by an order of magnitude. Train the NN for each of those models and record the times. How does the number of weights in the neural network affect the training time and model loss and accuracy?\n",
+ "6. Increase and then decrease the number of layers in the NN. Train the NN for each of those models and record the times. How does the number of layers in the neural network affect the training time and model loss and accuracy?\n",
+ "7. Make new ROC curves using the validation data and the training data. Are the results consistent with those from the test data?\n",
+ "8. Change the PyTorch random seed, and re-run the training and model evaluation. Have the results changed?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "machine_shape": "hm",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "LSST",
+ "language": "python",
+ "name": "lsst"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}