"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3.2. Define the neural network model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.1. Define model hyperparameters"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Set the seed for the random initialization of the model weights."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Zsc_pqZtWftJ",
+ "outputId": "8238aa09-2afd-47ef-c93e-3910b342fe38"
+ },
+ "outputs": [],
+ "source": [
+ "seed = 1729\n",
+ "new = torch.manual_seed(seed)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Define the loss function. We choose cross-entropy, which is the standard loss function for classification of discrete labels."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loss_fn = nn.CrossEntropyLoss()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.2. Glossary: model architecture elements"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Below is a list of important terms for elements of a neural network architecture.\n",
+ "\n",
+ " * `activation function`: a function within a neuron that takes the weighted sum of the inputs from the previous layer and produces an output; usually, this function is non-linear. Examples include \"sigmoid,\" \"softmax,\" and \"ReLU\" (rectified linear unit).\n",
+ " * `sigmoid`: an activation function that takes points from the Real line and maps them to the range (0,1).\n",
+ " * `softmax`: an activation function that takes a vector of Real values and maps it to values in the range [0,1] that sum to 1. This can be used to obtain a 'probability score' for each class.\n",
+ " * `ReLU (rectified linear unit)`: an activation function that returns 0 for negative input values and returns the input value itself for positive input values, so it is non-smooth at 0.\n",
+ " * `weight`: the multiplicative factor applied to an input of a neuron before the activation function.\n",
+ " * `bias`: the additive term combined with the weighted sum of the inputs before the activation function.\n",
+ " * `layer`: one set of nodes/neurons that receive input data simultaneously.\n",
+ " * `linear (Dense) layer`: a layer in which every output is a weighted sum of all the inputs plus a bias. In a convolutional network, it typically follows the \"flattening\" of a higher-dimensional tensor, like an image, into a vector. In contrast, a convolutional layer applies a convolution operation.\n",
+ " * `convolutional layer`: a layer that applies a convolution operation to an input sample."
+ ]
+ },
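+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The three activation functions in the glossary can be compared on a few numbers. The cell below is a minimal sketch and not part of the original workflow: the values in `z` are hypothetical and chosen only for illustration. It should show that `sigmoid` maps each value independently into the interval (0,1), that `softmax` returns values that sum to 1, and that `ReLU` replaces negative values with 0."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "# Hypothetical example values, chosen only for illustration.\n",
+ "z = torch.tensor([-2.0, 0.0, 3.0])\n",
+ "\n",
+ "sig = torch.sigmoid(z)      # each value mapped independently into (0, 1)\n",
+ "soft = F.softmax(z, dim=0)  # values in [0, 1] that sum to 1\n",
+ "rel = F.relu(z)             # negative values become 0\n",
+ "\n",
+ "print(\"sigmoid:\", sig)\n",
+ "print(\"softmax:\", soft, \"sum:\", soft.sum().item())\n",
+ "print(\"relu:   \", rel)"
+ ]
+ },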
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.3. Glossary: class that defines a network model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Typically, there are two functions within this class.\n",
+ "\n",
+ "The constructor function (`__init__`) defines the available layers of the network. These layers require specific settings related to the data set shapes. \n",
+ "* `Conv2d`: defines a two-dimensional convolutional layer. Four inputs are considered here:\n",
+ " * `in_channels` (required): the number of input channels.\n",
+ " * `out_channels` (required): the number of output channels.\n",
+ " * `kernel_size` (required): the size of the convolutional kernel along one dimension.\n",
+ " * `stride` (optional): the stride of the convolution.\n",
+ "* `Dropout`: defines a dropout layer. Its input is the probability with which each element of the layer input is set to zero during training. Typically, this is used when the model is overfitting -- i.e., when the validation loss as a function of epoch is consistently higher than the training loss as a function of epoch (Please see Figure 6 as an example).\n",
+ "* `Linear`: defines a linear layer. Two inputs are considered here:\n",
+ " * `in_features` (required): the size of the sample input to the layer.\n",
+ " * `out_features` (required): the size of the sample output from the layer.\n",
+ "\n",
+ "The function `forward` defines the order of operations during a forward pass of the model. During training, the `forward` function is applied to the input data to make predictions. After each batch of predictions, the loss function measures the difference between the true labels and the predicted labels, and the optimizer uses the gradient of that loss to update the model weights. \n",
+ "\n",
+ "The `forward` function uses the layers defined in the constructor, as well as other operations that don't require inputs that depend on the data. Most of these are defined in the `torch.nn.functional` submodule, which contains predefined functions that operate directly on the data and don't require an instance of a layer. The functions used here are:\n",
+ "* `relu`: applies the `ReLU` activation function. It requires one input, the sample from the previous layer.\n",
+ "* `max_pool2d`: the max pooling function is applied to the sample from the previous layer. It requires the input sample and the size of the kernel of the pooling.\n",
+ "* `flatten`: reshapes each input sample into a one-dimensional tensor. It is called here as `torch.flatten`, with a second input that leaves the batch dimension intact."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.4. Define the object class that represents the model "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class ConvNet(nn.Module):\n",
+ "    def __init__(self):\n",
+ "        super(ConvNet, self).__init__()\n",
+ "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
+ "        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
+ "        self.dropout1 = nn.Dropout(0.25)\n",
+ "        self.dropout2 = nn.Dropout(0.5)\n",
+ "        self.fc1 = nn.Linear(9216, 128)\n",
+ "        self.fc2 = nn.Linear(128, 10)\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        x = self.conv1(x)\n",
+ "        x = F.relu(x)\n",
+ "        x = self.conv2(x)\n",
+ "        x = F.relu(x)\n",
+ "        x = F.max_pool2d(x, 2)\n",
+ "        x = self.dropout1(x)\n",
+ "        x = torch.flatten(x, 1)\n",
+ "        x = self.fc1(x)\n",
+ "        x = F.relu(x)\n",
+ "        x = self.dropout2(x)\n",
+ "        x = self.fc2(x)\n",
+ "        output = x\n",
+ "        return output"
+ ]
+ },
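+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The value 9216 for `in_features` of `fc1` follows from the shapes produced by the earlier layers. The cell below is a minimal sketch, assuming single-channel input images with 28 pixels on a side: each 3x3 convolution with stride 1 and no padding trims one pixel from every edge (28 to 26 to 24), the 2x2 max pooling halves each side (24 to 12), and 64 channels times 12 x 12 pixels gives 9216. The tensor `x_dummy` and the stand-alone layers are hypothetical and only trace the shapes."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "# Hypothetical dummy batch: 1 image, 1 channel, 28 x 28 pixels (assumed size).\n",
+ "x_dummy = torch.zeros(1, 1, 28, 28)\n",
+ "\n",
+ "# Stand-alone copies of the two convolutional layers, for shape tracing only.\n",
+ "conv1 = nn.Conv2d(1, 32, 3, 1)\n",
+ "conv2 = nn.Conv2d(32, 64, 3, 1)\n",
+ "\n",
+ "x_dummy = F.relu(conv1(x_dummy))      # shape (1, 32, 26, 26)\n",
+ "x_dummy = F.relu(conv2(x_dummy))      # shape (1, 64, 24, 24)\n",
+ "x_dummy = F.max_pool2d(x_dummy, 2)    # shape (1, 64, 12, 12)\n",
+ "x_dummy = torch.flatten(x_dummy, 1)   # shape (1, 9216)\n",
+ "\n",
+ "n_features = x_dummy.shape[1]\n",
+ "print(f\"Features entering fc1: {n_features}\")"
+ ]
+ },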
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Instantiate a neural network model object."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = ConvNet()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Define the device on which the computations will run. The model and the data must both be moved to that device, which matters especially when a GPU is available.\n",
+ "\n",
+ "In our case, the device is a CPU because that's what is currently available on the RSP."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "device = torch.device('cuda') if torch.cuda.is_available()\\\n",
+ " else torch.device('cpu')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Put the model on the device where the computations will be performed.\n",
+ "\n",
+ "Because `model.to(device)` is the last statement in the cell and returns the model, the notebook also displays a summary of the network architecture. Examine the shapes of the layers, which determine the number of parameters in each layer. Too few parameters may prevent the model from being flexible enough to model the data. Too many parameters could lead to overfitting of the model (e.g., 'memorizing' the training data) and an unnecessarily high computational cost. Typically, models range from thousands to millions of parameters. There is an interplay between model size (number of parameters), the amount of training data, and the complexity of the images in the training data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"Model Summary:\\n\")\n",
+ "model.to(device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.5. Glossary: hyperparameters for the training schedule"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "* `learning rate`: A multiplicative factor defining the amount that the model weights will change in response to the gradient of the error (loss) at each weight update. The lower the learning rate, the smaller the change to the weights, and usually the longer it will take to train the model. However, if the learning rate is too high, the weights can change too quickly: the loss can then fluctuate significantly and fail to decrease.\n",
+ "* `momentum`: A multiplicative factor on the aggregate of previous gradients. This aggregate term is combined with the current weight gradient to define the total change in the value of the weights. The smaller the momentum, the lesser the influence of the aggregated gradients, and usually the longer it will take to train the model.\n",
+ "* `optimizer`: The method/algorithm used to update the network weights. Stochastic Gradient Descent (SGD) is a commonly used method and the one used in this tutorial.\n",
+ "* `epoch`: One loop of training the model. Each loop includes the entire data set (all the batches) once, and it includes at least one round of weight updates.\n",
+ "* `n_epochs`: The number of epochs to train the network."
+ ]
+ },
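+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To make the roles of the learning rate and the momentum concrete, the cell below is a minimal sketch of the SGD-with-momentum update for a single weight. It assumes the formulation in which a running `velocity` accumulates past gradients (`velocity = momentum * velocity + grad`) and the weight moves by `learning_rate * velocity`, which is how we understand the default behavior of `optim.SGD` with no dampening. The toy loss $(w-3)^2$ and the names `w`, `velocity`, and `n_steps` are hypothetical and chosen only for illustration."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Toy problem (hypothetical): minimize loss(w) = (w - 3)^2 for one weight.\n",
+ "learning_rate = 0.01\n",
+ "momentum = 0.9\n",
+ "n_steps = 200\n",
+ "\n",
+ "w = 0.0         # initial weight value\n",
+ "velocity = 0.0  # running aggregate of previous gradients\n",
+ "\n",
+ "for step in range(n_steps):\n",
+ "    grad = 2.0 * (w - 3.0)                 # gradient of the toy loss\n",
+ "    velocity = momentum * velocity + grad  # combine with past gradients\n",
+ "    w = w - learning_rate * velocity       # weight update\n",
+ "\n",
+ "print(f\"w after {n_steps} steps: {w:.4f} (the minimum is at w = 3)\")"
+ ]
+ },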
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.6. Assign hyperparameters for the training schedule"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "learning_rate = 0.01\n",
+ "momentum = 0.9\n",
+ "optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)\n",
+ "n_epochs = 50"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3.3. Train the model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Set the model into a training context. This ensures that layers like \"batchnorm\" and \"dropout\" will be activated. In contrast, when the model is set to an evaluation context (later in this tutorial), those layers will be deactivated."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "model.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use a loop over epochs to incrementally optimize the network weight parameters. \n",
+ "\n",
+ "Use lists to track the loss values on the training data, the loss values on the testing data, and the accuracy values on the testing data. Define the \"history\" dictionary to hold those lists. We will visualize these later to study the fitting efficacy and generalization capacity of the network.\n",
+ "\n",
+ "For this tutorial, 50 epochs on the medium-memory Rubin server required ~9 min of wall time and ~15 min of CPU time."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "zB1OY3o8VWF_",
+ "outputId": "7bdd7e85-cef2-47af-f16b-5f8bbd66f67f"
+ },
+ "outputs": [],
+ "source": [
+ "%%time\n",
+ "time_start = time.time()\n",
+ "\n",
+ "loss_train_list = []\n",
+ "loss_test_list = []\n",
+ "accuracy_test_list = []\n",
+ "\n",
+ "for epoch in np.arange(n_epochs):\n",
+ "\n",
+ "    loss_train = 0\n",
+ "    for inputs, labels in trainloader:\n",
+ "        inputs = inputs.to(device)\n",
+ "        labels = labels.to(device)\n",
+ "        y_pred = model(inputs)\n",
+ "        loss = loss_fn(y_pred, labels)\n",
+ "        optimizer.zero_grad()\n",
+ "        loss.backward()\n",
+ "        optimizer.step()\n",
+ "        loss_train += loss.item()\n",
+ "\n",
+ "    # Each loss value is a mean over one batch: average over the batches.\n",
+ "    loss_train /= len(trainloader)\n",
+ "\n",
+ "    accuracy_test = 0\n",
+ "    loss_test = 0\n",
+ "    count_test = 0\n",
+ "    # Gradients are not needed when scoring the test data.\n",
+ "    with torch.no_grad():\n",
+ "        for inputs, labels in testloader:\n",
+ "            inputs = inputs.to(device)\n",
+ "            labels = labels.to(device)\n",
+ "            y_pred = model(inputs)\n",
+ "            loss = loss_fn(y_pred, labels)\n",
+ "            n_correct = (torch.argmax(y_pred, 1) == labels).sum().item()\n",
+ "            accuracy_test += n_correct\n",
+ "            loss_test += loss.item()\n",
+ "            count_test += len(labels)\n",
+ "\n",
+ "    accuracy_test /= count_test\n",
+ "    loss_test /= len(testloader)\n",
+ "\n",
+ "    loss_train_list.append(loss_train)\n",
+ "    loss_test_list.append(loss_test)\n",
+ "    accuracy_test_list.append(accuracy_test)\n",
+ "\n",
+ "    output1 = f\"Epoch ({epoch:2d}): accuracy ({accuracy_test*100:.2f} %), \"\n",
+ "    output2 = f\"train loss ({loss_train:.4f}), valid loss ({loss_test:.4f})\"\n",
+ "    output = output1 + output2\n",
+ "    print(output)\n",
+ "\n",
+ "time_end = time.time()\n",
+ "\n",
+ "time_difference = time_end - time_start\n",
+ "\n",
+ "history = {\"loss\": loss_train_list,\n",
+ " \"val_loss\": loss_test_list,\n",
+ " \"accuracy_test\": accuracy_test_list}\n",
+ "\n",
+ "print(f\"Total training time: {time_difference/60.:2.4f} min\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
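+ "Save the model to a file: \"pt\" is the common suffix used for PyTorch model files. Because `torch.save` is given `model.state_dict()`, the file stores the weights of the model (its parameters and buffers) but not the architecture, so loading it later requires an instance of the `ConvNet` class."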
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_model_prefix'] + \"_\" + path_dict['run_label']\n",
+ "file_name_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_model'],\n",
+ " file_suffix=path_dict['file_model_suffix'],\n",
+ " useuid=True,\n",
+ " verbose=True)\n",
+ "\n",
+ "torch.save(model.state_dict(), file_name_final)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Load the model from a \"pt\" file."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "model.load_state_dict(torch.load(file_name_final, weights_only=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-01-14T23:05:09.388295Z",
+ "iopub.status.busy": "2025-01-14T23:05:09.387717Z",
+ "iopub.status.idle": "2025-01-14T23:05:09.481703Z",
+ "shell.execute_reply": "2025-01-14T23:05:09.481106Z",
+ "shell.execute_reply.started": "2025-01-14T23:05:09.388255Z"
+ }
+ },
+ "source": [
+ "Set the model to evaluation mode so that the \"batchnorm\" and \"dropout\" layers are deactivated. This is necessary for consistent inference."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "model.eval()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Diagnosing the Results of Model Training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this section, you will encounter the following items:\n",
+ "\n",
+ "1. a glossary for diagnostics for model evaluation;\n",
+ "2. how to make predictions with the trained model;\n",
+ "3. basic bulk diagnostic: loss history;\n",
+ "4. basic bulk diagnostic: confusion matrix (CM);\n",
+ "5. basic bulk diagnostic: receiver operator characteristic (ROC) curve;\n",
+ "6. investigating predictions of individual images."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.1. Glossary: diagnostics and model evaluation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use the following diagnostics to assess the status of the network optimization and efficacy. The [scikit-learn page on metrics and scoring](https://scikit-learn.org/stable/modules/model_evaluation.html) provides a good in-depth reference for these terms.\n",
+ "\n",
+ "Model Predictions:\n",
+ " * `Classification threshold`: A user-chosen value in the range $[0,1]$ that sets the threshold for a positive classification. \n",
+ " * `Probability score`: The output from the classifier neural network. One typically uses the `softmax` activation function in the last layer of the NN to provide an output in the range $[0,1]$.\n",
+ " * `Classification score`: The predicted class label is the class that received the highest probability score.\n",
+ "\n",
+ "Metrics:\n",
+ " * `Loss`: A function of the difference between the true labels and predicted labels.\n",
+ " * `Accuracy`: The fraction of predictions that are correct. It is a rough indicator of model training progress/convergence for balanced datasets. For model performance, use it only in combination with other metrics, and prefer another metric when the training dataset is unbalanced.\n",
+ " * `True Positive Rate (TPR; \"Recall\")`: The fraction of actual positives that are predicted as positive, TP/(TP+FN). Use when false negatives are more expensive than false positives.\n",
+ " * `False Positive Rate (FPR)`: The fraction of actual negatives that are predicted as positive, FP/(FP+TN). Use when false positives are more expensive than false negatives.\n",
+ " * `Precision`: The fraction of positive predictions that are correct, TP/(TP+FP). Use when positive predictions need to be accurate.\n",
+ "\n",
+ "The `Generalization Error` (GE) is the difference in loss when the model is applied to training data versus when applied to validation and test data.\n",
+ "\n",
+ "The `Confusion Matrix` (CM) is a visual representation of the classification accuracy. Each cell counts the objects with a given true label that received a given predicted label, with one axis for the true labels and the other for the predicted labels. Values along the diagonal indicate correct predictions. Each off-diagonal value is simultaneously a false negative for its true class and a false positive for its predicted class. The optimal scenario is one in which the off-diagonal values are all zero.\n",
+ "\n",
+ "The `Receiver Operator Characteristic (ROC) Curve` presents a comparison between the true positive rate (y-axis) and the false positive rate (x-axis): for a given false positive rate, it shows the fraction of true positives that are recovered. Each point on the curve is for a distinct choice of the classification threshold for the probability score in the range $[0,1]$: the choice of classification threshold determines which objects are classified as positive. The optimal scenario is where the ROC curve is constant at a true positive rate $=1$. If the curve is along the diagonal (lower left to upper right), it indicates that the model performance is equivalent to random guessing."
+ ]
+ },
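+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The rates in the glossary can be computed directly from a confusion matrix by treating one class as \"positive\" and all other classes as \"negative\" (one-vs-rest). The cell below is a minimal sketch with a hypothetical 3-class matrix `cm_example`, assuming the scikit-learn convention in which rows are true labels and columns are predicted labels. The counts are invented for illustration and are not results from this tutorial."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "# Hypothetical counts: rows are true labels, columns are predicted labels.\n",
+ "cm_example = np.array([[8, 1, 1],\n",
+ "                       [2, 7, 1],\n",
+ "                       [0, 2, 8]])\n",
+ "\n",
+ "k = 0  # class treated as 'positive'\n",
+ "tp = cm_example[k, k]              # true k, predicted k\n",
+ "fn = cm_example[k, :].sum() - tp   # true k, predicted something else\n",
+ "fp = cm_example[:, k].sum() - tp   # true something else, predicted k\n",
+ "tn = cm_example.sum() - tp - fn - fp\n",
+ "\n",
+ "tpr = tp / (tp + fn)\n",
+ "fpr = fp / (fp + tn)\n",
+ "precision = tp / (tp + fp)\n",
+ "\n",
+ "print(f\"TP={tp}, FN={fn}, FP={fp}, TN={tn}\")\n",
+ "print(f\"TPR={tpr:.2f}, FPR={fpr:.2f}, precision={precision:.2f}\")"
+ ]
+ },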
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.2. Predict classifications with the trained model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Predict classification probabilities on the training, validation, and test sets. Produce both the probabilities of each digit and the top choice for each prediction."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "y_prob_tes, y_choice_tes, y_true_tes, x_tes = predict(testloader, model,\n",
+ " \"test\")\n",
+ "y_prob_val, y_choice_val, y_true_val, x_val = predict(validloader, model,\n",
+ " \"validation\")\n",
+ "y_prob_tra, y_choice_tra, y_true_tra, x_tra = predict(trainloader, model,\n",
+ " \"training\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Print the shapes and verify that the shape of `y_prob_tes` matches the length of the input data `x_tes` and the number of classes (as in Section 2.4)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(f\"The input data has the shape {np.shape(x_tes)}: there are {subset_size}\\\n",
+ " images, and each image has 28 pixels on a side.\")\n",
+ "print(f\"The predicted probability score array has the shape\\\n",
+ " {np.shape(y_prob_tes)}: there are {subset_size} predictions,\\\n",
+ " with 10 probability scores predicted for each input image.\")\n",
+ "print(f\"The predicted classes array has the shape {np.shape(y_choice_tes)}:\\\n",
+ " there is one top choice (highest probability score) for each prediction.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use the `plotPredictionHistogram` function to plot histograms of prediction distributions by class."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Histograms_top_choice\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotPredictionHistogram(y_choice_tra,\n",
+ " y_prediction_b=y_choice_val,\n",
+ " y_prediction_c=y_choice_tes,\n",
+ " label_a=\"Training Set\",\n",
+ " label_b=\"Validation Set\",\n",
+ " label_c=\"Testing Set\",\n",
+ " figsize=(12, 5),\n",
+ " alpha=0.5,\n",
+ " xlabel_plot=\"Predicted class label\",\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 4: Histograms of the number of images for which the top-choice class was each number 0 through 9. Each histogram is for a different data set used during model training --- training data, validation data, and test data. Note that these are overlapping histograms, not stacked. Please compare to Figure 3, which shows the distributions of true class labels for each data set. Consider which classes are represented differently between the true labels and the predicted labels."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Histograms_class_probabilities\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotPredictionHistogram(y_prob_tra,\n",
+ " y_prediction_b=y_prob_val,\n",
+ " y_prediction_c=y_prob_tes,\n",
+ " title_a='Training Set',\n",
+ " title_b='Validation Set',\n",
+ " title_c='Testing Set',\n",
+ " figsize=(15, 4),\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 5: Histograms of the number of images (y-axis) that had a probability (x-axis) of being each class 0 through 9 (light to dark shades)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In both Figures 3 and 4, the histograms show very similar shapes across the classification categories.\n",
+ "This is a good sign because it indicates the model is not heavily biased toward a particular class."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.3. Loss History: History of Loss and Accuracy during Training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The primary task in optimizing a network is to minimize the Generalization Error. Plot the loss history for the validation and training sets. We reserve the test set for a 'blind' analysis."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"LossHistory\"\\\n",
+ " + \"_\"\\\n",
+ " + path_dict['run_label']\n",
+ "\n",
+ "plotLossHistory(history,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 6: The loss histories for the training and validation data sets and the loss residual history between the training and the validation set. Top panel: the loss history as a function of epoch for the training and validation sets decreases with time, as it should as the model fit improves. There is a slight difference between the validation and the training losses. Bottom panel: the loss residual (validation loss minus training loss) shows a minimum near epoch 34, indicating that the model fitting underwent a divergence between the classifications, but that this was rectified in later epochs. For epochs 10 to 44, the difference between validation and training data indicates that the validation data may not have been representative of the training data."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Another pattern you may encounter is when the validation loss is consistently higher than the training loss. This typically indicates overfitting, which is when the model is complex enough to fit (sometimes, the term 'memorize' is used) all the details of the training data but does not generalize well enough to also fit the validation data."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2024-10-26T18:50:41.751869Z",
+ "iopub.status.busy": "2024-10-26T18:50:41.751239Z",
+ "iopub.status.idle": "2024-10-26T18:50:41.757103Z",
+ "shell.execute_reply": "2024-10-26T18:50:41.756503Z",
+ "shell.execute_reply.started": "2024-10-26T18:50:41.751843Z"
+ }
+ },
+ "source": [
+ "### 4.4. Confusion Matrix: Bias in Trained Model?"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compute and plot the confusion matrices for the training, validation, and test samples (left, middle, right)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')\n",
+ "figsize = (15, 3)\n",
+ "linewidths = 0.01\n",
+ "linecolor = 'white'\n",
+ "# confusion_matrix(y_true, y_pred) puts true labels on rows (the y-axis).\n",
+ "ylabel = \"True Label\"\n",
+ "xlabel = \"Predicted Label\"\n",
+ "\n",
+ "cm_tra = confusion_matrix(y_true_tra, y_choice_tra)\n",
+ "cm_val = confusion_matrix(y_true_val, y_choice_val)\n",
+ "cm_tes = confusion_matrix(y_true_tes, y_choice_tes)\n",
+ "\n",
+ "df_cm_tra = pd.DataFrame(cm_tra / np.sum(cm_tra, axis=1)[:, None],\n",
+ " index=[i for i in classes],\n",
+ " columns=[i for i in classes])\n",
+ "df_cm_val = pd.DataFrame(cm_val / np.sum(cm_val, axis=1)[:, None],\n",
+ " index=[i for i in classes],\n",
+ " columns=[i for i in classes])\n",
+ "df_cm_tes = pd.DataFrame(cm_tes / np.sum(cm_tes, axis=1)[:, None],\n",
+ " index=[i for i in classes],\n",
+ " columns=[i for i in classes])\n",
+ "\n",
+ "fig, (axa, axb, axc) = plt.subplots(1, 3, figsize=figsize)\n",
+ "fig.subplots_adjust(wspace=0.5)\n",
+ "\n",
+ "ax1 = sns.heatmap(df_cm_tra, annot=False, linewidths=linewidths,\n",
+ " linecolor=linecolor, square=True, ax=axa)\n",
+ "_ = ax1.set(xlabel=xlabel, ylabel=ylabel, title=\"Training Data\")\n",
+ "\n",
+ "ax2 = sns.heatmap(df_cm_val, annot=False, linewidths=linewidths,\n",
+ " linecolor=linecolor, square=True, ax=axb)\n",
+ "_ = ax2.set(xlabel=xlabel, ylabel=ylabel, title=\"Validation Data\")\n",
+ "\n",
+ "ax3 = sns.heatmap(df_cm_tes, annot=False, linewidths=linewidths,\n",
+ " linecolor=linecolor, square=True, ax=axc)\n",
+ "_ = ax3.set(xlabel=xlabel, ylabel=ylabel, title=\"Test Data\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 7: Confusion matrices for the training, validation, and test data sets (left to right). Each cell in the matrix shows the fraction (see the color bars) of the total objects with that true label that have been predicted to be a given label. The diagonal represents images that were correctly classified. Every off-diagonal cell represents a misclassification: a false negative for the true class and a false positive for the predicted class. Consider some examples. First, the class \"0\" is almost always predicted to be \"0\" with a lighter color in the top-most, left-most cell; all the other cells for the true label \"0\" are completely dark. Second, consider the cell that represents a prediction \"4\", when the true label is \"9\": that cell is not completely dark. A \"4\" has a similar morphology or shape as a \"9.\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.5. Receiver Operator Characteristic (ROC) Curve: Trade-offs between Completeness and Purity"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The Receiver Operator Characteristic (ROC) Curve is a plot of the True Positive Rate (TPR; y axis) versus the False Positive Rate (FPR; x axis).\n",
+ "\n",
+ "As the classification threshold on the probability (canonically, 0.5) is varied, the balance of true positives and false positives changes; each threshold corresponds to one point on the curve."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ROCCurve\"\\\n",
+ " + \"_\"\\\n",
+ " + path_dict['run_label']\n",
+ "\n",
+ "plotROCMulticlassOnevsrest(y_tra, y_tes, y_prob_tes,\n",
+ " save_file=True,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 8: ROC curve for classification of one digit (class) against the remaining nine digits (classes). All of the curves show a very high area under the curve (AUC), which indicates near-perfect classification. The black dashed line shows the ROC curve expected for random guessing."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vvAqrZwjVYBt"
+ },
+ "source": [
+ "### 4.6. Investigating predictions in detail"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 4.6.1. Glossary: classification metrics"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-01-26T21:19:35.897282Z",
+ "iopub.status.busy": "2025-01-26T21:19:35.896488Z",
+ "iopub.status.idle": "2025-01-26T21:19:36.705672Z",
+ "shell.execute_reply": "2025-01-26T21:19:36.705022Z",
+ "shell.execute_reply.started": "2025-01-26T21:19:35.897242Z"
+ }
+ },
+ "source": [
+ "Consider the example of true class label being \"2\". Then, we define the following metrics.\n",
+ "\n",
+ "* `True Positive (TP)`: correctly classified input digit image --- e.g., a \"2\" classified as \"2\".\n",
+ "* `False Positive (FP)`: another digit classified as \"2\".\n",
+ "* `True Negative (TN)`: another digit classified as another digit.\n",
+ "* `False Negative (FN)`: a \"2\" classified as something other than \"2\"."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 4.6.2. Explore the classification of the training data for an example class value"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Investigate the case in which the true digit label is \"2\"."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class_value = 2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Find all objects that have that class value. Obtain indices for the TP's, FP's, TN's, and FN's. Create subsets of the data according to those indices."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ind_class_tp_tra = np.where((y_true_tra == class_value)\n",
+ " & (y_choice_tra == class_value))[0]\n",
+ "\n",
+ "ind_class_fp_tra = np.where((y_true_tra != class_value)\n",
+ " & (y_choice_tra == class_value))[0]\n",
+ "\n",
+ "ind_class_tn_tra = np.where((y_true_tra != class_value)\n",
+ " & (y_choice_tra != class_value))[0]\n",
+ "\n",
+ "ind_class_fn_tra = np.where((y_true_tra == class_value)\n",
+ " & (y_choice_tra != class_value))[0]\n",
+ "\n",
+ "x_tra_tp = x_tra[ind_class_tp_tra]\n",
+ "y_true_tra_tp = y_true_tra[ind_class_tp_tra]\n",
+ "y_choice_tra_tp = y_choice_tra[ind_class_tp_tra]\n",
+ "\n",
+ "x_tra_fp = x_tra[ind_class_fp_tra]\n",
+ "y_true_tra_fp = y_true_tra[ind_class_fp_tra]\n",
+ "y_choice_tra_fp = y_choice_tra[ind_class_fp_tra]\n",
+ "\n",
+ "x_tra_tn = x_tra[ind_class_tn_tra]\n",
+ "y_true_tra_tn = y_true_tra[ind_class_tn_tra]\n",
+ "y_choice_tra_tn = y_choice_tra[ind_class_tn_tra]\n",
+ "\n",
+ "x_tra_fn = x_tra[ind_class_fn_tra]\n",
+ "y_true_tra_fn = y_true_tra[ind_class_fn_tra]\n",
+ "y_choice_tra_fn = y_choice_tra[ind_class_fn_tra]\n",
+ "\n",
+ "n_tp = len(ind_class_tp_tra)\n",
+ "n_fp = len(ind_class_fp_tra)\n",
+ "n_tn = len(ind_class_tn_tra)\n",
+ "n_fn = len(ind_class_fn_tra)\n",
+ "\n",
+ "print(f\"TP count: {n_tp:4d}\")\n",
+ "print(f\"FP count: {n_fp:4d}\")\n",
+ "print(f\"TN count: {n_tn:4d}\")\n",
+ "print(f\"FN count: {n_fn:4d}\")"
+ ]
+ },
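+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The four counts printed above can be condensed into summary metrics for the chosen class. The cell below is a minimal sketch and not part of the original analysis: `summarizeClassCounts` is a hypothetical helper name, and it assumes the standard definitions of precision, TP/(TP+FP), and recall, TP/(TP+FN)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def summarizeClassCounts(n_tp, n_fp, n_fn):\n",
+ "    \"\"\"Hypothetical helper: precision and recall from confusion counts.\"\"\"\n",
+ "    # Guard against empty denominators, e.g., a class that is never predicted.\n",
+ "    precision = n_tp / (n_tp + n_fp) if (n_tp + n_fp) > 0 else np.nan\n",
+ "    recall = n_tp / (n_tp + n_fn) if (n_tp + n_fn) > 0 else np.nan\n",
+ "    return precision, recall\n",
+ "\n",
+ "\n",
+ "precision, recall = summarizeClassCounts(n_tp, n_fp, n_fn)\n",
+ "print(f\"Precision for class {class_value}: {precision:.3f}\")\n",
+ "print(f\"Recall for class {class_value}: {recall:.3f}\")"
+ ]
+ },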
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if n_tp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_TruePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_tp,\n",
+ " y_true_tra_tp,\n",
+ " y_choice_tra_tp,\n",
+ " title_main=\"True Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_FalsePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_fp,\n",
+ " y_true_tra_fp,\n",
+ " y_choice_tra_fp,\n",
+ " title_main=\"False Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_tn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_TrueNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_tn,\n",
+ " y_true_tra_tn,\n",
+ " y_choice_tra_tn,\n",
+ " title_main=\"True Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_FalseNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_fn,\n",
+ " y_true_tra_fn,\n",
+ " y_choice_tra_fn,\n",
+ " title_main=\"False Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 9: Four panels of 10 images each, representing true positives (top), false positives (second), true negatives (third), and false negatives (bottom), for classification category 2.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot histograms of the image pixel values for the true positives, false positives, true negatives, and false negatives."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if n_tp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleImages_TruePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_tp,\n",
+ " y_true_tra_tp,\n",
+ " y_choice_tra_tp,\n",
+ " title_main=\"True Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleImages_FalsePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_fp,\n",
+ " y_true_tra_fp,\n",
+ " y_choice_tra_fp,\n",
+ " title_main=\"False Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_tn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleImages_TrueNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_tn,\n",
+ " y_true_tra_tn,\n",
+ " y_choice_tra_tn,\n",
+ " title_main=\"True Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleImages_FalseNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_fn,\n",
+ " y_true_tra_fn,\n",
+ " y_choice_tra_fn,\n",
+ " title_main=\"False Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 10: Histograms of the pixel flux values for the images shown in Figure 9. Here, it is difficult to infer reasons for network classification errors like false positives and false negatives. If a particular digit is expected to have a particular distribution of pixel brightnesses, but the histogram for the predicted digit has a different distribution, that may help you identify a pattern. This may be a more useful diagnostic for physics-related data, like galaxy light profiles."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 4.5.3. Investigate morphological features of images with feature maps"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The last diagnostic of this tutorial is the `feature map`, which shows how one layer processes the input image after the weights have been set during model training. The morphological elements in the feature maps are the image regions that are highlighted by the earliest layers of convolution and activation functions.\n",
+ "\n",
+ "The feature map is obtained by passing an input image through one layer of the trained neural network model. More specifically, we first choose an input datum (image) by setting the array index of interest. Then, we make the image a `tensor` with the appropriate number of dimensions using `unsqueeze`, and we transfer it to the `device`. Next, we pass it through the `conv1` layer defined in the section above where we defined the `model`. Finally, we plot one feature map for each of the convolutional kernels (32) that exist in the `conv1` layer of the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "index_input_image = 1\n",
+ "number_conv_kernels = 32\n",
+ "\n",
+ "input_image = dataset.data[index_input_image].type(torch.float32)\n",
+ "input_image = input_image.clone().detach()\n",
+ "input_image = input_image.unsqueeze(0)\n",
+ "input_image = input_image.to(device)\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " feature_maps = model.conv1(input_image).cpu()\n",
+ "\n",
+ "fig, ax = plt.subplots(4, 8, sharex=True, sharey=True, figsize=(16, 8))\n",
+ "\n",
+ "for i in range(0, number_conv_kernels):\n",
+ " row, col = i//8, i%8\n",
+ " ax[row][col].imshow(feature_maps[i])\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 11: `Feature maps` of one input from the `conv1` layer of the `model`. The units on the x- and y-axes are the pixel indices. There are 32 images because there are 32 convolutional kernels in the `conv1` layer.\n",
+ ">\n",
+ "> The feature maps derived from the trained model show which morphological features (e.g., lines and edges) are favored by the model. When the model is accurate, these features and feature maps will accurately reflect the input image. In this example, the handwritten digit is \"0,\" and the feature maps all appear as \"0's\". Some maps have more clearly defined edges and \"0\"-like features than others, because each convolutional kernel has its own set of weights."
+ ]
+ },
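+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The shape bookkeeping in the feature-map cell can be checked in isolation. The sketch below is illustrative only: `example_conv` is a randomly initialized stand-in for `model.conv1`, and its 3x3 kernel size is an assumption rather than the value used in this notebook's model. It relies on `Conv2d` accepting an unbatched input with dimensions (channels, height, width)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "# Stand-in layer (assumed 3x3 kernels); the trained model.conv1 may differ.\n",
+ "example_conv = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)\n",
+ "\n",
+ "# One random MNIST-sized image, with a channel axis added: (1, 28, 28).\n",
+ "example_image = torch.rand(28, 28)\n",
+ "example_input = example_image.unsqueeze(0)\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    example_maps = example_conv(example_input)\n",
+ "\n",
+ "# One map per kernel; 3x3 kernels without padding give (32, 26, 26).\n",
+ "print(example_input.shape)\n",
+ "print(example_maps.shape)"
+ ]
+ },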
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Exercises for the Learner"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Each time you train a new model, re-run all the diagnostic plots.\n",
+ "\n",
+ "1. How do the loss and accuracy histories change when batch size is small or large? Why?\n",
+ "2. Does the NN take more or less time (more or fewer epochs) to converge if the input image data are normalized or not normalized? Why?\n",
+ "3. How does the size of the training set affect the model's accuracy and loss -- keeping the number of epochs the same? Why?\n",
+ "4. How does the random seed for the weight initialization affect the model's accuracy and loss -- keeping the number of epochs the same?\n",
+ "5. Increase and then decrease the number of weights in the NN by an order of magnitude. Train the NN for each of those models and record the times. How does the number of weights in the neural network affect the training time and model loss and accuracy?\n",
+ "6. Increase and then decrease the number of layers in the NN. Train the NN for each of those models and record the times. How does the number of layers in the neural network affect the training time and model loss and accuracy?\n",
+ "7. Make new ROC curves using the validation data and the training data. Are the results consistent with those from the test data?\n",
+ "8. Change the pytorch random seed, and re-run the training and model evaluation. Have the results changed?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "machine_shape": "hm",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "LSST",
+ "language": "python",
+ "name": "lsst"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/DP02_16a_Introduction_to_AI.ipynb.ipynb b/DP02_16a_Introduction_to_AI.ipynb.ipynb
new file mode 100644
index 0000000..bceac82
--- /dev/null
+++ b/DP02_16a_Introduction_to_AI.ipynb.ipynb
@@ -0,0 +1,3212 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Introduction to AI-based Image Classification with Pytorch** <br>\n",
+ "Contact author: Brian Nord <br>\n",
+ "Last verified to run: 2024-07-01 <br>\n",
+ "LSST Science Pipelines version: Weekly 2024_16 <br>\n",
+ "Container size: medium <br>\n",
+ "Targeted learning level: beginner <br>"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Description:** An introduction to the classification of images with AI-based classification algorithms."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Skills:** Examine AI training data, prepare it for a classification task, perform classification with a neural network, and examine the diagnostics of the classification task."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**LSST Data Products:** None; MNIST data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Packages:** numpy, matplotlib, sklearn, pytorch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Credits and Acknowledgments:** We thank Ryan Lau and Melissa Graham for feedback on the notebook."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Get Support:**\n",
+ "Find DP0-related documentation and resources at dp0.lsst.io. Questions are welcome as new topics in the Support - Data Preview 0 Category of the Rubin Community Forum. Rubin staff will respond to all questions posted there."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Introduction\n",
+ "\n",
+ "This Jupyter Notebook introduces artificial intelligence (AI)-based image classification. It demonstrates how to perform a few key steps:\n",
+ "1. examine and prepare data for classification;\n",
+ "2. train an AI algorithm;\n",
+ "3. plot diagnostics of the training performance;\n",
+ "4. initially assess those diagnostics. \n",
+ "\n",
+ "AI is a class of algorithms for building statistical models. These algorithms primarily use data for training, as opposed to models that use analytic formulae or models that are based on physical reasoning. Machine learning is a subclass of AI algorithms (e.g., random forests). Deep learning is a subclass of machine learning algorithms (e.g., neural networks). \n",
+ "\n",
+ "In this notebook, we use `pytorch`, which is currently the library most often used in deep learning studies. `Pytorch` is more complicated to use in some ways: it requires acclimating to how it handles tensors, data sets, and model building. At the same time, it offers more flexibility for overall model development: you can more easily get creative with model structures. \n",
+ "\n",
+ "Instead of using DP0 data, this tutorial uses [MNIST AI benchmarking data](https://en.wikipedia.org/wiki/MNIST_database), a large database of handwritten digits that is commonly used for training and testing machine learning algorithms. It is simple to understand, so that users who are new to deep learning can focus on learning AI. Later tutorials in this series will use stars and galaxies drawn from DP0 data."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.1. AI is math, not magic.\n",
+ "\n",
+ "AI is firmly based in math, computer science, and statistics. Additionally, some of the approaches are inspired by concepts or notions in biology (e.g., the computational neuron) and in physics (e.g., the Restricted Boltzmann Machine). \n",
+ "\n",
+ "Much of the jargon in AI is anthropomorphic, which can make it appear that something other than math is happening. For example, consider the following list of terms that are very often used in AI -- and what these terms actually mean mathematically.\n",
+ "\n",
+ "1. `learn` $\\rightarrow$ fit\n",
+ "2. `hallucinate`/`lie` $\\rightarrow$ predict incorrectly\n",
+ "3. `understand` $\\rightarrow$ model fit has converged\n",
+ "4. `cheat` $\\rightarrow$ more efficiently guesses the best weight parameters of the model\n",
+ "5. `believe` $\\rightarrow$ predict/infer based on statistical priors\n",
+ "\n",
+ "When we over-anthropomorphize these mathematical concepts, we obfuscate how they actually work. That makes it harder to build and refine models. That is, AI models are not 'learning' or 'understanding'; they are merely large-parameter models that are being fit to data, and AI includes novel methods for that fitting process. The only learning that's happening is what humans do with these models.\n",
+ "\n",
+ "Many of the most useful AI-related terms are defined throughout this tutorial."
+ ]
+ },
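+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To make the first translation in the list above concrete, the cell below 'learns' a straight line by ordinary least squares. It is a minimal sketch on synthetic, noise-free data that is not part of the MNIST workflow; the variable names are illustrative."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "# Synthetic data drawn from y = 2x + 1.\n",
+ "x_example = np.linspace(0, 1, 20)\n",
+ "y_example = 2.0 * x_example + 1.0\n",
+ "\n",
+ "# 'Learning' here is a least-squares fit of the two parameters of a line.\n",
+ "slope, intercept = np.polyfit(x_example, y_example, deg=1)\n",
+ "print(f\"slope = {slope:.2f}, intercept = {intercept:.2f}\")"
+ ]
+ },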
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.2. Import packages\n",
+ "\n",
+ "[`numpy`](https://numpy.org/) is used for computations and mathematical operations on multi-dimensional arrays.\n",
+ "\n",
+ "[`matplotlib`](https://matplotlib.org/) is a plotting library. \n",
+ "\n",
+ "[`sklearn`](https://scikit-learn.org/stable/) is a library for machine learning.\n",
+ "\n",
+ "[`torch`](https://www.pytorch.org) is used for fast tensor operations --- often used for building neural network models."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:40.895553Z",
+ "iopub.status.busy": "2025-02-09T23:00:40.895367Z",
+ "iopub.status.idle": "2025-02-09T23:00:49.712635Z",
+ "shell.execute_reply": "2025-02-09T23:00:49.711587Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:40.895535Z"
+ },
+ "id": "JyaaGkFE8VOl"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import time\n",
+ "import datetime\n",
+ "import os\n",
+ "\n",
+ "import matplotlib\n",
+ "import matplotlib.pyplot as plt\n",
+ "from matplotlib.pyplot import cm\n",
+ "from matplotlib.colors import LogNorm\n",
+ "import seaborn as sns\n",
+ "\n",
+ "from sklearn.metrics import confusion_matrix\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "import torch.nn.functional as F\n",
+ "import torchvision\n",
+ "from torch.utils.data import Dataset, DataLoader, Subset, random_split"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.3. Define functions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The following functions are defined and used throughout this notebook. It is not necessary to understand exactly what every function does to proceed with this tutorial. Execute all cells and move on to Section 2."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:49.714597Z",
+ "iopub.status.busy": "2025-02-09T23:00:49.713821Z",
+ "iopub.status.idle": "2025-02-09T23:00:49.863235Z",
+ "shell.execute_reply": "2025-02-09T23:00:49.862202Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:49.714555Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def normalizeInputs(x_temp, input_minimum, input_maximum):\n",
+ " \"\"\"Normalize a datum that is an input to the neural network\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " x_temp: `numpy.array`\n",
+ " image data\n",
+ " input_minimum: `float`\n",
+ " minimum value for normalization\n",
+ " input_maximum: `float`\n",
+ " maximum value for normalization\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " x_temp_norm: `numpy.array`\n",
+ " normalized image data\n",
+ " \"\"\"\n",
+ " x_temp_norm = (x_temp - input_minimum)/(input_maximum - input_minimum)\n",
+ " return x_temp_norm"
+ ]
+ },
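+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick check of `normalizeInputs`, the sketch below rescales three sample pixel values. It assumes 8-bit images with a minimum of 0 and a maximum of 255; these values are illustrative and are not taken from the data loaded later in the notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "# Illustrative pixel values spanning the assumed 0-255 range.\n",
+ "example_pixels = np.array([0.0, 127.5, 255.0])\n",
+ "example_norm = normalizeInputs(example_pixels, 0.0, 255.0)\n",
+ "\n",
+ "# The minimum maps to 0 and the maximum maps to 1.\n",
+ "print(example_norm)"
+ ]
+ },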
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:49.866230Z",
+ "iopub.status.busy": "2025-02-09T23:00:49.865802Z",
+ "iopub.status.idle": "2025-02-09T23:00:49.998210Z",
+ "shell.execute_reply": "2025-02-09T23:00:49.997647Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:49.866185Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def createFileUidTimestamp():\n",
+ " \"\"\"Create a timestamp for a filename.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " None\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " file_uid_timestamp : `string`\n",
+ " String from date and time.\n",
+ " \"\"\"\n",
+ " file_uid_timestamp = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
+ " return file_uid_timestamp"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:49.998965Z",
+ "iopub.status.busy": "2025-02-09T23:00:49.998779Z",
+ "iopub.status.idle": "2025-02-09T23:00:50.155942Z",
+ "shell.execute_reply": "2025-02-09T23:00:50.154901Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:49.998948Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def createFileName(file_prefix=\"\", file_location=\"Data/Sandbox/\",\n",
+ " file_suffix=\"\", useuid=True, verbose=True):\n",
+ " \"\"\"Create a file name.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " file_prefix: `string`\n",
+ " prefix of file name\n",
+ " file_location: `string`\n",
+ " path to file\n",
+ " file_suffix: `string`\n",
+ " suffix/extension of file name\n",
+ " useuid: `bool`\n",
+ " choose to use a unique id\n",
+ " verbose: `bool`\n",
+ " choose to print the file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " file_final: `string`\n",
+ " filename used for saving\n",
+ " \"\"\"\n",
+ " if useuid:\n",
+ " file_uid = createFileUidTimestamp()\n",
+ " else:\n",
+ " file_uid = \"\"\n",
+ "\n",
+ " file_final = file_location + file_prefix + \"_\" + file_uid + file_suffix\n",
+ "\n",
+ " if verbose:\n",
+ " print(file_final)\n",
+ "\n",
+ " return file_final"
+ ]
+ },
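+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The sketch below shows how the pieces of a file name are assembled. The prefix, location, and suffix are illustrative values. Note that the underscore separator is always added, so disabling the unique ID leaves a trailing underscore before the suffix."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Illustrative arguments; no file is written by this call.\n",
+ "example_name = createFileName(file_prefix=\"example_figure\",\n",
+ "                              file_location=\"./\",\n",
+ "                              file_suffix=\".png\",\n",
+ "                              useuid=False, verbose=False)\n",
+ "print(example_name)  # ./example_figure_.png"
+ ]
+ },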
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:50.157469Z",
+ "iopub.status.busy": "2025-02-09T23:00:50.157104Z",
+ "iopub.status.idle": "2025-02-09T23:00:50.298215Z",
+ "shell.execute_reply": "2025-02-09T23:00:50.297620Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:50.157432Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def plotArrayImageExamples(subset_train,\n",
+ " num_row=3, num_col=3,\n",
+ " object_index_start=0,\n",
+ " save_file=False,\n",
+ " file_prefix=\"ImageExamples\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot an array of examples of images and labels\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " subset_train: `numpy.ndarray`\n",
+ " training data images\n",
+ " num_row: `int`, optional\n",
+ " number of rows to plot\n",
+ " num_col: `int`, optional\n",
+ " number of columns to plot\n",
+ " object_index_start: `int`, optional\n",
+ " starting index for set of images to plot\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " From: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " labels_map = {\n",
+ " 0: \"0\",\n",
+ " 1: \"1\",\n",
+ " 2: \"2\",\n",
+ " 3: \"3\",\n",
+ " 4: \"4\",\n",
+ " 5: \"5\",\n",
+ " 6: \"6\",\n",
+ " 7: \"7\",\n",
+ " 8: \"8\",\n",
+ " 9: \"9\",\n",
+ " }\n",
+ "\n",
+ " figure = plt.figure(figsize=(8, 8))\n",
+ "\n",
+ " for i in range(1, num_row * num_col + 1):\n",
+ " sample_idx = object_index_start + i - 1\n",
+ " img, label = subset_train[sample_idx]\n",
+ " figure.add_subplot(num_row, num_col, i)\n",
+ " plt.title(labels_map[label])\n",
+ " plt.axis(\"off\")\n",
+ " plt.imshow(img.squeeze(), cmap=\"gray\")\n",
+ "\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:50.299193Z",
+ "iopub.status.busy": "2025-02-09T23:00:50.298861Z",
+ "iopub.status.idle": "2025-02-09T23:00:50.463522Z",
+ "shell.execute_reply": "2025-02-09T23:00:50.462425Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:50.299174Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def plotArrayHistogramExamples(subset_train,\n",
+ " num_row=3, num_col=3,\n",
+ " n_bins=10,\n",
+ " object_index_start=0,\n",
+ " save_file=False,\n",
+ " file_prefix=\"HistogramExamples\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot histograms of image pixel values\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " subset_train: `numpy.ndarray`\n",
+ " training data images\n",
+ " num_row: `int`, optional\n",
+ " number of rows to plot\n",
+ " num_col: `int`, optional\n",
+ " number of columns to plot\n",
+ " n_bins: `int`, optional\n",
+ " number of bins in histogram\n",
+ " object_index_start: `int`, optional\n",
+ " starting index for set of images to plot\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " labels_map = {\n",
+ " 0: \"0\",\n",
+ " 1: \"1\",\n",
+ " 2: \"2\",\n",
+ " 3: \"3\",\n",
+ " 4: \"4\",\n",
+ " 5: \"5\",\n",
+ " 6: \"6\",\n",
+ " 7: \"7\",\n",
+ " 8: \"8\",\n",
+ " 9: \"9\",\n",
+ " }\n",
+ "\n",
+ " fig, axes = plt.subplots(num_row, num_col,\n",
+ " figsize=(1.5*num_col, 2*num_row))\n",
+ "\n",
+ " for i in range(num_row * num_col):\n",
+ " sample_idx = object_index_start + i\n",
+ " img, label = subset_train[sample_idx]\n",
+ " ax = axes[i//num_col, i%num_col]\n",
+ " img_temp = img[0, :, :]\n",
+ " img_temp = np.array(img_temp).flat\n",
+ " ax.hist(img_temp, bins=n_bins, color='gray')\n",
+ " ax.set_title(labels_map[label])\n",
+ " ax.set_xlabel(\"Pixel Values\")\n",
+ "\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:50.465449Z",
+ "iopub.status.busy": "2025-02-09T23:00:50.465057Z",
+ "iopub.status.idle": "2025-02-09T23:00:50.612203Z",
+ "shell.execute_reply": "2025-02-09T23:00:50.611165Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:50.465413Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def predict(dataloader, model, dataset_type):\n",
+ " \"\"\"Predict labels of inputs\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " dataloader: `torch.utils.data.DataLoader`\n",
+ " data loader for the data set to predict on\n",
+ " model: `torch.nn.Module`\n",
+ " trained neural network model\n",
+ " dataset_type: `string`\n",
+ " name of the data set, used in the printed summary\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " y_prob_list: `numpy.ndarray`\n",
+ " probabilities for each class for each input\n",
+ " y_choice_list: `numpy.ndarray`\n",
+ " highest-probability class for each input\n",
+ " y_true_list: `numpy.ndarray`\n",
+ " true class for each input\n",
+ " x_list: `numpy.ndarray`\n",
+ " input\n",
+ " \"\"\"\n",
+ " size = len(dataloader.dataset)\n",
+ " num_batches = len(dataloader)\n",
+ " model.eval()\n",
+ "\n",
+ " y_prob_list = []\n",
+ " y_choice_list = []\n",
+ " y_true_list = []\n",
+ " x_list = []\n",
+ "\n",
+ " i = 0\n",
+ " loss, accuracy = 0, 0\n",
+ " with torch.no_grad():\n",
+ " for inputs, labels in dataloader:\n",
+ " inputs = inputs.to(device)\n",
+ " labels = labels.to(device)\n",
+ "\n",
+ " y = model(inputs)\n",
+ " y_prob = torch.softmax(y, dim=1)\n",
+ " y_choice = (torch.max(torch.exp(y), 1)[1]).data.cpu().numpy()\n",
+ "\n",
+ " loss += loss_fn(y, labels).item() / num_batches\n",
+ " accuracy_temp = y.argmax(1) == labels\n",
+ " accuracy += accuracy_temp.type(torch.float).sum().item() / size\n",
+ "\n",
+ " y_prob_list.append(y_prob.cpu().numpy())\n",
+ " y_choice_list.append(y_choice)\n",
+ " y_true_list.append(labels.cpu().numpy())\n",
+ " x_list.append(inputs.cpu().numpy())\n",
+ "\n",
+ " i += 1\n",
+ "\n",
+ " y_prob_list = np.array(y_prob_list)\n",
+ " y_choice_list = np.array(y_choice_list)\n",
+ " y_true_list = np.array(y_true_list)\n",
+ " x_list = np.array(x_list)\n",
+ "\n",
+ " y_prob_list = np.squeeze(y_prob_list)\n",
+ " y_choice_list = np.squeeze(y_choice_list)\n",
+ " y_true_list = np.squeeze(y_true_list)\n",
+ " x_list = np.squeeze(x_list)\n",
+ "\n",
+ " print(f\"{dataset_type : <10} data set ...\\\n",
+ " Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {loss:>8f}\")\n",
+ "\n",
+ " return y_prob_list, y_choice_list, y_true_list, x_list"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:50.614078Z",
+ "iopub.status.busy": "2025-02-09T23:00:50.613693Z",
+ "iopub.status.idle": "2025-02-09T23:00:50.840445Z",
+ "shell.execute_reply": "2025-02-09T23:00:50.839384Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:50.614040Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def plotPredictionHistogram(y_prediction_a, y_prediction_b=None,\n",
+ " y_prediction_c=None, n_classes=None,\n",
+ " n_objects_a=None, n_colors=None,\n",
+ " title_a=None, title_b=None,\n",
+ " title_c=None, label_a=None,\n",
+ " label_b=None, label_c=None,\n",
+ " alpha=1.0, figsize=(12, 5),\n",
+ " save_file=False,\n",
+ " file_prefix=\"prediction_histogram\",\n",
+ " file_location=\"./\",\n",
+ " xlabel_plot=\"Class label\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot histogram of predicted labels\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " y_prediction_a: `numpy.ndarray`\n",
+ " y_prediction_b: `numpy.ndarray`, optional\n",
+ " y_prediction_c: `numpy.ndarray`, optional\n",
+ " n_classes: `int`, optional\n",
+ " n_objects_a: `int`, optional\n",
+ " n_colors: `int`, optional\n",
+ " title_a: `string`, optional\n",
+ " title_b: `string`, optional\n",
+ " title_c: `string`, optional\n",
+ " label_a: `string`, optional\n",
+ " label_b: `string`, optional\n",
+ " label_c: `string`, optional\n",
+ " alpha: `float`, optional\n",
+ " transparency\n",
+ " figsize: `tuple`, optional\n",
+ " figure size\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " ndim = y_prediction_a.ndim\n",
+ "\n",
+ " if ndim == 2:\n",
+ " fig, (axa, axb, axc) = plt.subplots(1, 3, figsize=figsize)\n",
+ " fig.subplots_adjust(wspace=0.35)\n",
+ " elif ndim == 1:\n",
+ " fig, ax = plt.subplots(figsize=figsize)\n",
+ "\n",
+ " shape_a = np.shape(y_prediction_a)\n",
+ "\n",
+ " if n_objects_a is None:\n",
+ " n_objects_a = shape_a[0]\n",
+ "\n",
+ " if ndim == 2:\n",
+ " if n_classes is None:\n",
+ " n_classes = shape_a[1]\n",
+ " if n_colors is None:\n",
+ " n_colors = n_classes\n",
+ " elif ndim == 1:\n",
+ " if n_colors is None:\n",
+ " n_colors = 1\n",
+ "\n",
+ " if ndim == 2:\n",
+ " colors = cm.Purples(np.linspace(0, 1, n_colors))\n",
+ " xlabel = \"Probability for Each Class\"\n",
+ "\n",
+ " axa.set_ylim(0, n_objects_a)\n",
+ " axa.set_xlabel(xlabel)\n",
+ " axa.set_title(title_a)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axa.hist(y_prediction_a[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " if y_prediction_b is not None:\n",
+ " shape_b = np.shape(y_prediction_b)\n",
+ " axb.set_ylim(0, shape_b[0])\n",
+ " axb.set_xlabel(xlabel)\n",
+ " axb.set_title(title_b)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axb.hist(y_prediction_b[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " if y_prediction_c is not None:\n",
+ " shape_c = np.shape(y_prediction_c)\n",
+ " axc.set_ylim(0, shape_c[0])\n",
+ " axc.set_xlabel(xlabel)\n",
+ " axc.set_title(title_c)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axc.hist(y_prediction_c[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " elif ndim == 1:\n",
+ " ya, xa, _ = plt.hist(y_prediction_a, alpha=alpha, color='orange',\n",
+ " label=label_a)\n",
+ " y_max_list = [max(ya)]\n",
+ "\n",
+ " if y_prediction_b is not None:\n",
+ " yb, xb, _ = plt.hist(y_prediction_b, alpha=alpha, color='green',\n",
+ " label=label_b)\n",
+ " y_max_list.append(max(yb))\n",
+ "\n",
+ " if y_prediction_c is not None:\n",
+ " yc, xc, _ = plt.hist(y_prediction_c, alpha=alpha, color='purple',\n",
+ " label=label_c)\n",
+ " y_max_list.append(max(yc))\n",
+ "\n",
+ " plt.ylim(0, np.max(y_max_list)*1.1)\n",
+ " plt.xlabel(xlabel_plot)\n",
+ "\n",
+ " plt.legend(loc='upper right')\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:50.842666Z",
+ "iopub.status.busy": "2025-02-09T23:00:50.842278Z",
+ "iopub.status.idle": "2025-02-09T23:00:51.020147Z",
+ "shell.execute_reply": "2025-02-09T23:00:51.019064Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:50.842629Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def plotTrueLabelHistogram(y_prediction_a, y_prediction_b=None,\n",
+ " y_prediction_c=None, n_classes=None,\n",
+ " n_objects_a=None, n_colors=None,\n",
+ " title_a=None, title_b=None,\n",
+ " title_c=None, label_a=None,\n",
+ " label_b=None, label_c=None,\n",
+ " alpha=1.0, figsize=(12, 5),\n",
+ " save_file=False,\n",
+ " file_prefix=\"prediction_histogram\",\n",
+ " file_location=\"./\",\n",
+ " xlabel_plot=\"Class label\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot histogram of predicted labels\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " y_prediction_a: `numpy.ndarray`\n",
+ " y_prediction_b: `numpy.ndarray`, optional\n",
+ " y_prediction_c: `numpy.ndarray`, optional\n",
+ " n_classes: `int`, optional\n",
+ " n_objects_a: `int`, optional\n",
+ " n_colors: `int`, optional\n",
+ " title_a: `string`, optional\n",
+ " title_b: `string`, optional\n",
+ " title_c: `string`, optional\n",
+ " label_a: `string`, optional\n",
+ " label_b: `string`, optional\n",
+ " label_c: `string`, optional\n",
+ " alpha: `float`, optional\n",
+ " transparency\n",
+ " figsize: `tuple`, optional\n",
+ " figure size\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ "\n",
+ " ndim = y_prediction_a.ndim\n",
+ "\n",
+ " if ndim == 2:\n",
+ " fig, (axa, axb, axc) = plt.subplots(1, 3, figsize=figsize)\n",
+ " fig.subplots_adjust(wspace=0.35)\n",
+ " elif ndim == 1:\n",
+ " fig, ax = plt.subplots(figsize=figsize)\n",
+ "\n",
+ " shape_a = np.shape(y_prediction_a)\n",
+ "\n",
+ " if n_objects_a is None:\n",
+ " n_objects_a = shape_a[0]\n",
+ "\n",
+ " if ndim == 2:\n",
+ " if n_classes is None:\n",
+ " n_classes = shape_a[1]\n",
+ " if n_colors is None:\n",
+ " n_colors = n_classes\n",
+ " elif ndim == 1:\n",
+ " if n_colors is None:\n",
+ " n_colors = 1\n",
+ "\n",
+ " if ndim == 2:\n",
+ " colors = cm.Purples(np.linspace(0, 1, n_colors))\n",
+ " xlabel = \"Probability for Each Class\"\n",
+ "\n",
+ " axa.set_ylim(0, n_objects_a)\n",
+ " axa.set_xlabel(xlabel)\n",
+ " axa.set_title(title_a)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axa.hist(y_prediction_a[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " if y_prediction_b is not None:\n",
+ " shape_b = np.shape(y_prediction_b)\n",
+ " axb.set_ylim(0, shape_b[0])\n",
+ " axb.set_xlabel(xlabel)\n",
+ " axb.set_title(title_b)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axb.hist(y_prediction_b[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " if y_prediction_c is not None:\n",
+ " shape_c = np.shape(y_prediction_c)\n",
+ " axc.set_ylim(0, shape_c[0])\n",
+ " axc.set_xlabel(xlabel)\n",
+ " axc.set_title(title_c)\n",
+ "\n",
+ " for i in np.arange(n_classes):\n",
+ " axc.hist(y_prediction_c[:, i], alpha=alpha,\n",
+ " color=colors[i], label=\"'\" + str(i) + \"'\")\n",
+ "\n",
+ " elif ndim == 1:\n",
+ " ya, xa, _ = plt.hist(y_prediction_a, alpha=alpha, color='orange',\n",
+ " label=label_a)\n",
+ " y_max_list = [max(ya)]\n",
+ "\n",
+ " if y_prediction_b is not None:\n",
+ " yb, xb, _ = plt.hist(y_prediction_b, alpha=alpha, color='green',\n",
+ " label=label_b)\n",
+ " y_max_list.append(max(yb))\n",
+ "\n",
+ " if y_prediction_c is not None:\n",
+ " yc, xc, _ = plt.hist(y_prediction_c, alpha=alpha, color='purple',\n",
+ " label=label_c)\n",
+ " y_max_list.append(max(yc))\n",
+ "\n",
+ " plt.ylim(0, np.max(y_max_list)*1.1)\n",
+ " plt.xlabel(xlabel_plot)\n",
+ "\n",
+ " plt.legend(loc='upper right')\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:51.021678Z",
+ "iopub.status.busy": "2025-02-09T23:00:51.021279Z",
+ "iopub.status.idle": "2025-02-09T23:00:51.185198Z",
+ "shell.execute_reply": "2025-02-09T23:00:51.184585Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:51.021643Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def plotLossHistory(history, figsize=(8, 5),\n",
+ " save_file=False,\n",
+ " file_prefix=\"prediction_histogram\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot loss history of the model as function of epoch\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " history: `dict`\n",
+ " dictionary with keys 'loss' and 'val_loss', each holding the losses at each epoch\n",
+ " figsize: `tuple`, optional\n",
+ " figure size\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=figsize)\n",
+ "\n",
+ " loss_tra = np.array(history['loss'])\n",
+ " loss_val = np.array(history['val_loss'])\n",
+ " loss_dif = loss_val - loss_tra\n",
+ "\n",
+ " ax1.plot(loss_tra, label='Training')\n",
+ " ax1.plot(loss_val, label='Validation')\n",
+ " ax1.legend()\n",
+ "\n",
+ " ax2.plot(loss_dif, color='red', label='residual')\n",
+ " ax2.axhline(y=0, color='grey', linestyle='dashed', label='zero bias')\n",
+ " ax2.sharex(ax1)\n",
+ " ax2.legend()\n",
+ "\n",
+ " ax1.set_title('Loss History')\n",
+ " ax1.set_ylabel('Loss')\n",
+ " ax2.set_ylabel('Loss Residual')\n",
+ " ax2.set_xlabel('Epoch')\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ "\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:51.186307Z",
+ "iopub.status.busy": "2025-02-09T23:00:51.186054Z",
+ "iopub.status.idle": "2025-02-09T23:00:51.343743Z",
+ "shell.execute_reply": "2025-02-09T23:00:51.342674Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:51.186287Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def plotArrayImageConfusion(x_tra, y_tra, y_pred_tra_topchoice,\n",
+ " title_main=None, num=10,\n",
+ " save_file=False,\n",
+ " file_prefix=\"prediction_histogram\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot images of example objects that are misclassified.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " x_tra: `numpy.ndarray`\n",
+ " training image data\n",
+ " y_tra: `numpy.ndarray`\n",
+ " training label data\n",
+ " y_pred_tra_topchoice: `numpy.ndarray`\n",
+ " top choice of the predicted labels\n",
+ " title_main: `string`, optional\n",
+ " title for the plot\n",
+ " num: `int`, optional\n",
+ " number of examples\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " num_row = 2\n",
+ " num_col = 5\n",
+ " images = x_tra[:num]\n",
+ " labels_true = y_tra[:num]\n",
+ " labels_pred = y_pred_tra_topchoice[:num]\n",
+ "\n",
+ " fig, axes = plt.subplots(num_row, num_col,\n",
+ " figsize=(1.5*num_col, 2*num_row))\n",
+ "\n",
+ " fig.patch.set_linewidth(5)\n",
+ " fig.patch.set_edgecolor('cornflowerblue')\n",
+ "\n",
+ " for i in range(num):\n",
+ " ax = axes[i//num_col, i%num_col]\n",
+ " ax.imshow(images[i], cmap='gray')\n",
+ " ax.set_title(r'True: {}'.format(labels_true[i]) + '\\n'\n",
+ " + 'Pred: {}'.format(labels_pred[i]))\n",
+ "\n",
+ " fig.suptitle(title_main)\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ "\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:51.348809Z",
+ "iopub.status.busy": "2025-02-09T23:00:51.348383Z",
+ "iopub.status.idle": "2025-02-09T23:00:51.591750Z",
+ "shell.execute_reply": "2025-02-09T23:00:51.591165Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:51.348771Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def plotArrayHistogramConfusion(x_tra, y_tra, y_pred_tra_topchoice,\n",
+ " title_main=None, num=10,\n",
+ " save_file=False,\n",
+ " file_prefix=\"prediction_histogram\",\n",
+ " file_location=\"./\",\n",
+ " file_suffix=\".png\"):\n",
+ " \"\"\"Plot histograms of pixel values for images that are misclassified.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " x_tra: `numpy.ndarray`\n",
+ " training image data\n",
+ " y_tra: `numpy.ndarray`\n",
+ " training label data\n",
+ " y_pred_tra_topchoice: `numpy.ndarray`\n",
+ " top choice of the predicted labels\n",
+ " title_main: `string`, optional\n",
+ " title of plot\n",
+ " num: `int`, optional\n",
+ " number of examples\n",
+ " file_prefix: `string`, optional\n",
+ " prefix of file name\n",
+ " file_location: `string`, optional\n",
+ " path to file\n",
+ " file_suffix: `string`, optional\n",
+ " suffix/extension of file name\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " None\n",
+ " \"\"\"\n",
+ " n_bins = 10\n",
+ " num_row = 2\n",
+ " num_col = 5\n",
+ " images = x_tra[:num]\n",
+ " labels_true = y_tra[:num]\n",
+ " labels_pred = y_pred_tra_topchoice[:num]\n",
+ "\n",
+ " fig, axes = plt.subplots(num_row, num_col,\n",
+ " figsize=(1.5*num_col, 2*num_row))\n",
+ "\n",
+ " fig.patch.set_linewidth(5)\n",
+ " fig.patch.set_edgecolor('cornflowerblue')\n",
+ "\n",
+ " for i in range(num):\n",
+ " ax = axes[i//num_col, i%num_col]\n",
+ " images_temp = images[i, :, :].flat\n",
+ " ax.hist(images_temp, bins=n_bins, color='gray')\n",
+ " ax.set_title(r'True: {}'.format(labels_true[i]) + '\\n'\n",
+ " + 'Pred: {}'.format(labels_pred[i]))\n",
+ " ax.set_xlabel('Pixel Values')\n",
+ "\n",
+ " fig.suptitle(title_main)\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " if save_file:\n",
+ " file_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=file_location,\n",
+ " file_suffix=file_suffix,\n",
+ " useuid=True)\n",
+ "\n",
+ " plt.savefig(file_final, bbox_inches='tight')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.4. Define paths for data and figures"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Neural network training (i.e., model fitting) typically requires many numerical experiments to achieve an ideal model. To facilitate the comparison of these experiments/models, it is helpful to organize data carefully. \n",
+ "\n",
+ "Set the variable `run_label` for each training run. \n",
+ "Set paths for the model weight parameters and diagnostic figures, and\n",
+ "save these paths in the dictionary `path_dict` to facilitate passing information to plotting functions.\n",
+ "Check whether the paths exist, and if not, create them."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:51.592497Z",
+ "iopub.status.busy": "2025-02-09T23:00:51.592305Z",
+ "iopub.status.idle": "2025-02-09T23:00:51.743517Z",
+ "shell.execute_reply": "2025-02-09T23:00:51.742419Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:51.592481Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "run_label = \"Run000\"\n",
+ "\n",
+ "path_temp = os.getenv(\"HOME\") + '/dp02_16a_temp/'\n",
+ "path_dict = {'run_label': run_label,\n",
+ " 'dir_data_model': path_temp + \"Data/Models/\",\n",
+ " 'dir_data_figures': path_temp + \"Data/Figures/\",\n",
+ " 'file_model_prefix': \"Model\",\n",
+ " 'file_figure_prefix': \"Figure\",\n",
+ " 'file_figure_suffix': \".png\",\n",
+ " 'file_model_suffix': \".pt\"\n",
+ " }\n",
+ "del path_temp\n",
+ "\n",
+ "if not os.path.exists(path_dict['dir_data_model']):\n",
+ " os.makedirs(path_dict['dir_data_model'])\n",
+ "\n",
+ "if not os.path.exists(path_dict['dir_data_figures']):\n",
+ " os.makedirs(path_dict['dir_data_figures'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Load and Prepare data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.1. Data Set: MNIST Handwritten Digits"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The [`MNIST handwritten digits dataset`](https://ieeexplore.ieee.org/document/6296535) comprises images of handwritten digits in 10 classes, one for each digit. It is a useful dataset for learning the basics of neural networks and other AI algorithms, and it is one of a few canonical AI benchmark data sets for image classification."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.2. Obtain the dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "`torchvision` provides a dataset class that downloads the MNIST data to your local server for free. When loading the data, we apply a transform from `torchvision.transforms` to normalize the data; this makes the model training more efficient."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 2.2.1. Normalize data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create a `transform` that is used to convert the data sets to tensors that can be used in `pytorch`. This also transforms all the data from the range $[0,255]$ to $[0.,1.]$."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:51.745068Z",
+ "iopub.status.busy": "2025-02-09T23:00:51.744707Z",
+ "iopub.status.idle": "2025-02-09T23:00:51.884959Z",
+ "shell.execute_reply": "2025-02-09T23:00:51.884286Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:51.745032Z"
+ },
+ "id": "kMzao03zcF5i",
+ "outputId": "0401d0af-d607-436a-908f-385ffc85812c"
+ },
+ "outputs": [],
+ "source": [
+ "transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])"
+ ]
+ },
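+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick check of the rescaling described above, the following sketch applies `ToTensor` to a small made-up image. The array `toy_image` and the other `toy_` names are hypothetical and are not part of the MNIST data; the sketch only assumes that `numpy` and `torchvision` are available. The resulting tensor should have the channel axis first and values between 0 and 1."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import torchvision\n",
+ "\n",
+ "# Hypothetical 2x2 grayscale image with 8-bit pixel values,\n",
+ "# stored as (height, width, channel).\n",
+ "toy_image = np.array([[[0], [64]], [[128], [255]]], dtype=np.uint8)\n",
+ "\n",
+ "# ToTensor converts to a float tensor, moves the channel axis first,\n",
+ "# and divides 8-bit pixel values by 255.\n",
+ "toy_to_tensor = torchvision.transforms.ToTensor()\n",
+ "toy_tensor = toy_to_tensor(toy_image)\n",
+ "\n",
+ "print(toy_tensor.shape)\n",
+ "print(toy_tensor.min().item(), toy_tensor.max().item())"
+ ]
+ },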
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 2.2.2. Download data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Download the data from a remote server."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:51.886408Z",
+ "iopub.status.busy": "2025-02-09T23:00:51.885839Z",
+ "iopub.status.idle": "2025-02-09T23:00:52.170282Z",
+ "shell.execute_reply": "2025-02-09T23:00:52.169224Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:51.886378Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "dataset = torchvision.datasets.MNIST(root='./newdata', train=True,\n",
+ " download=True, transform=transform)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.3. Split data into training, validation, and testing"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Split the data to enable a proper 'blind' analysis and optimization of an AI model.\n",
+ "\n",
+ "There are three primary data sets used in model development and optimization:\n",
+ "\n",
+ "* `Training` (with filename tag `_tra`) data is used directly by the algorithm to update the parameters of the AI model -- e.g., the weights on the edges between neurons in a neural network.\n",
+ "* `Validation` (`_val`) data is used indirectly to update the hyperparameters of the AI model -- e.g., the batch size (`batch_size`), the learning rate, or the layers in the architecture of a neural network. Each time the neural network has completed training with the training data, the user compares the diagnostics computed on the training data and on the validation data.\n",
+ "* `Test(ing)` (`_tes`) data is only used when the model is trained and validated and will no longer be updated or further trained. This is the data that you would consider to be the new data that has not been examined before -- e.g., newly observed data that has not been previously characterized."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We use most of the data for training to maximize the accuracy and generalization of the model. We use a small amount of validation data, because only a little bit is needed to check the model optimization during training. We also use only a small amount of testing data, assuming that new data sets to examine are smaller than existing training data sets."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Set the fractions for the training, validation, and test sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:52.172134Z",
+ "iopub.status.busy": "2025-02-09T23:00:52.171340Z",
+ "iopub.status.idle": "2025-02-09T23:00:52.297172Z",
+ "shell.execute_reply": "2025-02-09T23:00:52.296634Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:52.172083Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "fraction_tra = 0.8\n",
+ "fraction_val = 0.1\n",
+ "fraction_tes = 0.1"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Split the data according to those fractions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:52.297980Z",
+ "iopub.status.busy": "2025-02-09T23:00:52.297800Z",
+ "iopub.status.idle": "2025-02-09T23:00:52.444477Z",
+ "shell.execute_reply": "2025-02-09T23:00:52.443440Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:52.297963Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "fraction_list = [fraction_tra, fraction_val, fraction_tes]\n",
+ "\n",
+ "data_tra_full, data_val_full, data_tes_full = \\\n",
+ " torch.utils.data.random_split(dataset, fraction_list)"
+ ]
+ },
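+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "`random_split` assigns objects to each set at random, so the membership of each set can change from one session to the next. One possible way to make the split repeatable is to pass a seeded generator, as in the sketch below. The data set `toy_dataset` is a hypothetical stand-in with 100 items, and the seed value is arbitrary; neither is part of this tutorial's workflow."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch.utils.data import TensorDataset, random_split\n",
+ "\n",
+ "# Hypothetical stand-in data set of 100 items.\n",
+ "toy_dataset = TensorDataset(torch.arange(100))\n",
+ "toy_fractions = [0.8, 0.1, 0.1]\n",
+ "\n",
+ "# Two generators with the same (arbitrary) seed.\n",
+ "toy_generator_a = torch.Generator().manual_seed(0)\n",
+ "toy_generator_b = torch.Generator().manual_seed(0)\n",
+ "\n",
+ "toy_split_a = random_split(toy_dataset, toy_fractions,\n",
+ "                           generator=toy_generator_a)\n",
+ "toy_split_b = random_split(toy_dataset, toy_fractions,\n",
+ "                           generator=toy_generator_b)\n",
+ "\n",
+ "# The set sizes follow the fractions, and the same seed\n",
+ "# selects the same objects for the training set.\n",
+ "print([len(s) for s in toy_split_a])\n",
+ "print(toy_split_a[0].indices == toy_split_b[0].indices)"
+ ]
+ },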
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Print the data set sizes to make sure there's enough data for optimizing the model. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:52.445854Z",
+ "iopub.status.busy": "2025-02-09T23:00:52.445490Z",
+ "iopub.status.idle": "2025-02-09T23:00:52.586869Z",
+ "shell.execute_reply": "2025-02-09T23:00:52.586295Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:52.445819Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "training: 48000\n",
+ "validation: 6000\n",
+ "test 6000\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"training:\", len(data_tra_full.indices))\n",
+ "print(\"validation:\", len(data_val_full.indices))\n",
+ "print(\"test\", len(data_tes_full.indices))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.4. Create subsets of the data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use only a subset of each data set to make the training faster for this tutorial. Consider increasing the sizes of these data sets in your exploration of the tutorial elements."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:52.587896Z",
+ "iopub.status.busy": "2025-02-09T23:00:52.587639Z",
+ "iopub.status.idle": "2025-02-09T23:00:52.723891Z",
+ "shell.execute_reply": "2025-02-09T23:00:52.722872Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:52.587879Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "subset_size = 5000\n",
+ "\n",
+ "subset_indices_tra = np.arange(subset_size)\n",
+ "subset_indices_val = np.arange(subset_size)\n",
+ "subset_indices_tes = np.arange(subset_size)\n",
+ "\n",
+ "data_tra = Subset(data_tra_full, subset_indices_tra)\n",
+ "data_val = Subset(data_val_full, subset_indices_val)\n",
+ "data_tes = Subset(data_tes_full, subset_indices_tes)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.5. Examine raw data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Review the raw data shapes by looking at a single datum in the training set. Each datum in the training set is an image-label pair. The image is a tensor, and the label is an integer. The image size in part determines the depth (number of layers) of the neural network. This will be discussed in a later section on model training."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:52.725316Z",
+ "iopub.status.busy": "2025-02-09T23:00:52.724942Z",
+ "iopub.status.idle": "2025-02-09T23:00:52.875460Z",
+ "shell.execute_reply": "2025-02-09T23:00:52.874383Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:52.725278Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The image shape is torch.Size([1, 28, 28]).\n",
+ "The label for this image is 6.\n"
+ ]
+ }
+ ],
+ "source": [
+ "sample_index = 0\n",
+ "image, label = data_tra[sample_index]\n",
+ "print(f\"The image shape is {image.shape}.\")\n",
+ "print(f\"The label for this image is {label}.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Print the image labels and their corresponding indices within a dataset object. This is useful for verifying that you understand your data set."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:52.876528Z",
+ "iopub.status.busy": "2025-02-09T23:00:52.876301Z",
+ "iopub.status.idle": "2025-02-09T23:00:53.000692Z",
+ "shell.execute_reply": "2025-02-09T23:00:53.000075Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:52.876509Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{0: '0 - zero',\n",
+ " 1: '1 - one',\n",
+ " 2: '2 - two',\n",
+ " 3: '3 - three',\n",
+ " 4: '4 - four',\n",
+ " 5: '5 - five',\n",
+ " 6: '6 - six',\n",
+ " 7: '7 - seven',\n",
+ " 8: '8 - eight',\n",
+ " 9: '9 - nine'}"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "{v: k for k, v in dataset.class_to_idx.items()}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot examples of the raw data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:53.001744Z",
+ "iopub.status.busy": "2025-02-09T23:00:53.001327Z",
+ "iopub.status.idle": "2025-02-09T23:00:53.833152Z",
+ "shell.execute_reply": "2025-02-09T23:00:53.832087Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:53.001716Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Example_Image_Array\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotArrayImageExamples(data_tra,\n",
+ " num_row=4, num_col=3,\n",
+ " save_file=False,\n",
+ " object_index_start=0,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 1: Four rows of three images, each a handwritten number in white on a black background."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot the distributions of pixel values to understand the data further. All pixel data has been normalized, and pixels have values between 0 and 1 only.\n",
+ "\n",
+ "The distribution of pixel values matches the images shown above: mostly black (values near 0) pixels, with some white (values near 1), and a few grey (values in between 0 and 1)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:53.834798Z",
+ "iopub.status.busy": "2025-02-09T23:00:53.834404Z",
+ "iopub.status.idle": "2025-02-09T23:00:55.262207Z",
+ "shell.execute_reply": "2025-02-09T23:00:55.261170Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:53.834762Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Example_Histogram_Array\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotArrayHistogramExamples(data_tra,\n",
+ " num_row=2, num_col=5,\n",
+ " save_file=False,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 2: Two rows of five plots, each showing the distribution of pixel values (number of pixels of a given value) for the handwritten digit images shown in Figure 1."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2024-12-15T20:26:20.181221Z",
+ "iopub.status.busy": "2024-12-15T20:26:20.180895Z",
+ "iopub.status.idle": "2024-12-15T20:26:20.186166Z",
+ "shell.execute_reply": "2024-12-15T20:26:20.185452Z",
+ "shell.execute_reply.started": "2024-12-15T20:26:20.181199Z"
+ }
+ },
+ "source": [
+ "Use the `plotPredictionHistogram` function to plot histograms of true label distributions by class."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:55.263752Z",
+ "iopub.status.busy": "2025-02-09T23:00:55.263376Z",
+ "iopub.status.idle": "2025-02-09T23:00:57.303748Z",
+ "shell.execute_reply": "2025-02-09T23:00:57.302743Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:55.263717Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "y_tra = []\n",
+ "y_val = []\n",
+ "y_tes = []\n",
+ "\n",
+ "for i in np.arange(subset_size):\n",
+ " image, label_tra = data_tra[i]\n",
+ " image, label_val = data_val[i]\n",
+ " image, label_tes = data_tes[i]\n",
+ " y_tra.append(label_tra)\n",
+ " y_val.append(label_val)\n",
+ " y_tes.append(label_tes)\n",
+ "\n",
+ "y_tra = np.array(y_tra)\n",
+ "y_val = np.array(y_val)\n",
+ "y_tes = np.array(y_tes)\n",
+ "\n",
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Histograms_true_class\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotPredictionHistogram(y_tra,\n",
+ " y_prediction_b=y_val,\n",
+ " y_prediction_c=y_tes,\n",
+ " label_a=\"Training Set\",\n",
+ " label_b=\"Validation Set\",\n",
+ " label_c=\"Testing Set\",\n",
+ " figsize=(12, 5),\n",
+ " alpha=0.5,\n",
+ " xlabel_plot=\"True class label\",\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 3: The histograms of true class labels for each data set. Each histogram is for a different data set used during model training --- training data, validation data, and test data. Note that these are overlapping histograms, not stacked. Please compare to Figure 4, which shows histograms of the predicted class labels."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.6. Create dataloaders for training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Set the batch size, which will be used in dataloaders and in training. To simplify data-handling, we set the batch size to the size of the subset that we select. This means that each epoch of training uses a single batch."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:57.305293Z",
+ "iopub.status.busy": "2025-02-09T23:00:57.304947Z",
+ "iopub.status.idle": "2025-02-09T23:00:57.453279Z",
+ "shell.execute_reply": "2025-02-09T23:00:57.452205Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:57.305259Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "batch_size = subset_size"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create dataloaders for the training, validation, and test sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:57.454746Z",
+ "iopub.status.busy": "2025-02-09T23:00:57.454386Z",
+ "iopub.status.idle": "2025-02-09T23:00:57.589881Z",
+ "shell.execute_reply": "2025-02-09T23:00:57.589215Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:57.454711Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "trainloader = torch.utils.data.DataLoader(data_tra, batch_size=batch_size,\n",
+ " shuffle=False)\n",
+ "\n",
+ "validloader = torch.utils.data.DataLoader(data_val, batch_size=batch_size,\n",
+ " shuffle=False)\n",
+ "\n",
+ "testloader = torch.utils.data.DataLoader(data_tes, batch_size=batch_size,\n",
+ " shuffle=False)"
+ ]
+ },
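+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To see what a dataloader returns, it can help to draw one batch and inspect its shape. The sketch below does this for a hypothetical stand-in data set of random images (`toy_images`, `toy_labels`), not for the MNIST subsets themselves. Because the batch size equals the data set size, the loader should contain exactly one batch."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch.utils.data import DataLoader, TensorDataset\n",
+ "\n",
+ "# Hypothetical stand-in for an MNIST subset:\n",
+ "# 20 random 1x28x28 images with integer labels.\n",
+ "toy_images = torch.rand(20, 1, 28, 28)\n",
+ "toy_labels = torch.randint(0, 10, (20,))\n",
+ "toy_loader = DataLoader(TensorDataset(toy_images, toy_labels),\n",
+ "                        batch_size=20, shuffle=False)\n",
+ "\n",
+ "# Draw the first (and here only) batch and inspect its shapes.\n",
+ "toy_batch_images, toy_batch_labels = next(iter(toy_loader))\n",
+ "print(toy_batch_images.shape)\n",
+ "print(toy_batch_labels.shape)\n",
+ "print('number of batches:', len(toy_loader))"
+ ]
+ },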
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Train the model: Convolutional Neural Network"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2024-12-15T21:16:11.080821Z",
+ "iopub.status.busy": "2024-12-15T21:16:11.079938Z",
+ "iopub.status.idle": "2024-12-15T21:16:11.085835Z",
+ "shell.execute_reply": "2024-12-15T21:16:11.085097Z",
+ "shell.execute_reply.started": "2024-12-15T21:16:11.080794Z"
+ }
+ },
+ "source": [
+ "In `pytorch`, neural network models are defined as classes. This is slightly different than typical `tensorflow` usage, in which people build a `sequential` model or use a pre-built `model` class and add layers to that model. \n",
+ "\n",
+ "The other major difference between `pytorch` and `tensorflow` is the shape of the tensors and the inputs for each layer. \n",
+ "\n",
+ "In particular, in `pytorch` one has to explicitly match the output size of one layer to the input size of the next layer. Sometimes this can be calculated from the layer parameters; otherwise, it can be found by trial and error."
+ ]
+ },
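+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The sketch below illustrates both ways of matching layer sizes for a 28x28 input image. The helper `conv_output_size` and the layers `toy_conv` and `toy_pool` are hypothetical and chosen only for illustration; they are not the architecture defined later in this tutorial. The calculation assumes a dilation of 1, and the check passes a dummy image through the layers to read off the flattened size directly."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch import nn\n",
+ "\n",
+ "\n",
+ "def conv_output_size(size_in, kernel_size, stride=1, padding=0):\n",
+ "    # Output length along one axis of a convolution or pooling\n",
+ "    # layer, assuming a dilation of 1.\n",
+ "    return (size_in + 2 * padding - kernel_size) // stride + 1\n",
+ "\n",
+ "\n",
+ "# Hypothetical layers, chosen only for illustration.\n",
+ "toy_conv = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3)\n",
+ "toy_pool = nn.MaxPool2d(kernel_size=2)\n",
+ "\n",
+ "# Option 1: calculate the flattened size from the layer parameters.\n",
+ "size_conv = conv_output_size(28, kernel_size=3)\n",
+ "size_pool = conv_output_size(size_conv, kernel_size=2, stride=2)\n",
+ "n_features = 8 * size_pool * size_pool\n",
+ "\n",
+ "# Option 2: pass a dummy image through the layers and read the size.\n",
+ "with torch.no_grad():\n",
+ "    toy_input = torch.zeros(1, 1, 28, 28)\n",
+ "    toy_output = toy_pool(toy_conv(toy_input)).flatten(start_dim=1)\n",
+ "n_features_check = toy_output.shape[1]\n",
+ "\n",
+ "print(n_features, n_features_check)"
+ ]
+ },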
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3.1. Define model training hyperparameters"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Set the seed for the random initial model weights."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:57.590774Z",
+ "iopub.status.busy": "2025-02-09T23:00:57.590515Z",
+ "iopub.status.idle": "2025-02-09T23:00:57.739275Z",
+ "shell.execute_reply": "2025-02-09T23:00:57.738219Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:57.590755Z"
+ },
+ "id": "Zsc_pqZtWftJ",
+ "outputId": "8238aa09-2afd-47ef-c93e-3910b342fe38"
+ },
+ "outputs": [],
+ "source": [
+ "seed = 1729\n",
+ "new = torch.manual_seed(seed)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
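+ "Define the loss function. We choose cross-entropy, which is the standard loss function for classification of discrete labels. In `pytorch`, `nn.CrossEntropyLoss` expects raw, unnormalized scores (logits) from the model and applies the log-softmax internally."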
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:57.741224Z",
+ "iopub.status.busy": "2025-02-09T23:00:57.740823Z",
+ "iopub.status.idle": "2025-02-09T23:00:57.892170Z",
+ "shell.execute_reply": "2025-02-09T23:00:57.891536Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:57.741190Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "loss_fn = nn.CrossEntropyLoss()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3.2. Define the neural network model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.1. Define model architecture elements"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Below is a list of important terms for elements of an architecture:\n",
+ "\n",
+ " * `activation function`: a function within a neuron that takes the weighted inputs from a previous layer and produces an output. Usually, this function is non-linear. Examples include \"sigmoid,\" \"softmax,\" and \"ReLU\" (rectified linear unit).\n",
+ " * `sigmoid`: an activation function that takes points from the Real line and maps them to the range (0,1).\n",
+ " * `softmax`: an activation function that takes a vector of Real numbers and maps it to a vector of values in the range (0,1) that sum to 1. This is often used to obtain a 'probability score.'\n",
+ " * `weights`: the multiplicative factors applied to the inputs of a neuron before the activation function. These are learned during training.\n",
+ " * `biases`: the additive terms added to the weighted sum of inputs before the activation function. These are also learned during training.\n",
+ " * `layer`: one set of nodes/neurons that receive input data simultaneously.\n",
+ " * `linear (Dense) layer`: a layer in which every output is a weighted sum of all inputs plus a bias. It expects a one-dimensional input per sample, so higher-dimensional data, like an image, must be \"flattened\" first. It performs no convolution operation, as opposed to a convolutional layer.\n",
+ " * `convolutional layer`: a layer that applies a convolution operation to an input sample."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.2. Implement model architecture elements in an object class"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Typically, there are two methods within this class.\n",
+ "\n",
+ "The constructor function (`__init__`) defines the available layers of the network. These layers require specific settings related to the data set shapes. \n",
+ "* `Conv2d`: defines a two-dimensional convolutional layer. Four inputs are considered here:\n",
+ " * `in_channels` (required): the number of input channels.\n",
+ " * `out_channels` (required): the number of output channels.\n",
+ " * `kernel_size` (required): the size of the convolutional kernel along one dimension. A single integer defines a square kernel.\n",
+ " * `stride` (optional): the stride of the convolution.\n",
+ "* `Dropout`: defines a dropout layer. Its input is the probability that each element of the layer's input is set to zero. Typically, this is used only during training.\n",
+ "* `Linear`: defines a linear layer. Two inputs are considered here:\n",
+ " * `in_features` (required): the size of the sample input to the layer.\n",
+ " * `out_features` (required): the size of the sample output from the layer.\n",
+ "\n",
+ "The function `forward` defines the order of operations during a forward pass of the model. During training, the `forward` function is applied to the input data to make predictions. After each batch of predictions, the loss function quantifies the difference between the true labels and the predicted labels, and the optimizer uses the gradients of that loss to update the model weights.\n",
+ "\n",
+ "The `forward` function uses the layers defined in the constructor, as well as other operations that don't require inputs that depend on the data. Most of these are defined in the `torch.nn.functional` submodule, which contains predefined functions that operate directly on the data and don't require an instance of a layer. The operations used here are:\n",
+ "* `relu`: applies the activation function, the `Rectified Linear Unit`. It requires one input, the sample from the previous layer.\n",
+ "* `max_pool2d`: applies the max pooling function to the sample from the previous layer. It requires the input sample and the size of the kernel of the pooling.\n",
+ "* `flatten`: reshapes each input sample into a one-dimensional tensor. This function is called from `torch` directly (`torch.flatten`); with a start dimension of 1, the batch dimension is preserved."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.3. Define the object class that represents the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:57.893245Z",
+ "iopub.status.busy": "2025-02-09T23:00:57.892802Z",
+ "iopub.status.idle": "2025-02-09T23:00:58.065150Z",
+ "shell.execute_reply": "2025-02-09T23:00:58.064028Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:57.893224Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "class ConvNet(nn.Module):\n",
+ " def __init__(self):\n",
+ " super(ConvNet, self).__init__()\n",
+ " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
+ " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
+ " self.dropout1 = nn.Dropout(0.25)\n",
+ " self.dropout2 = nn.Dropout(0.5)\n",
+ " self.fc1 = nn.Linear(9216, 128)\n",
+ " self.fc2 = nn.Linear(128, 10)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = self.conv1(x)\n",
+ " x = F.relu(x)\n",
+ " x = self.conv2(x)\n",
+ " x = F.relu(x)\n",
+ " x = F.max_pool2d(x, 2)\n",
+ " x = self.dropout1(x)\n",
+ " x = torch.flatten(x, 1)\n",
+ " x = self.fc1(x)\n",
+ " x = F.relu(x)\n",
+ " x = self.dropout2(x)\n",
+ " x = self.fc2(x)\n",
+ " output = x\n",
+ " return output"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Instantiate a neural network model object."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:58.066569Z",
+ "iopub.status.busy": "2025-02-09T23:00:58.066197Z",
+ "iopub.status.idle": "2025-02-09T23:00:58.205918Z",
+ "shell.execute_reply": "2025-02-09T23:00:58.204825Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:58.066533Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "model = ConvNet()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Define the device on which the computations will run. The model and the data must both be moved to that device, which is especially important when a GPU is available.\n",
+ "\n",
+ "In our case, the device is a CPU because that's what is currently available on the RSP."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:58.207635Z",
+ "iopub.status.busy": "2025-02-09T23:00:58.207117Z",
+ "iopub.status.idle": "2025-02-09T23:00:58.346979Z",
+ "shell.execute_reply": "2025-02-09T23:00:58.345787Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:58.207575Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "device = torch.device('cuda') if torch.cuda.is_available()\\\n",
+ " else torch.device('cpu')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Put the model on the device where the computations will be performed.\n",
+ "\n",
+ "Because `model.to(device)` returns the model, running it as the last line of the cell also displays a summary of the network architecture. Examine the shapes of the layers, which determine the number of parameters. Too few parameters may prevent the model from being flexible enough to model the data. Too many parameters could lead to overfitting of the model and a high computational cost."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:58.348374Z",
+ "iopub.status.busy": "2025-02-09T23:00:58.348049Z",
+ "iopub.status.idle": "2025-02-09T23:00:58.484854Z",
+ "shell.execute_reply": "2025-02-09T23:00:58.484286Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:58.348343Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model Summary:\n",
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "ConvNet(\n",
+ " (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n",
+ " (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n",
+ " (dropout1): Dropout(p=0.25, inplace=False)\n",
+ " (dropout2): Dropout(p=0.5, inplace=False)\n",
+ " (fc1): Linear(in_features=9216, out_features=128, bias=True)\n",
+ " (fc2): Linear(in_features=128, out_features=10, bias=True)\n",
+ ")"
+ ]
+ },
+ "execution_count": 34,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "print(\"Model Summary:\\n\")\n",
+ "model.to(device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.4. Define hyperparameter terms for the training schedule"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "* `learning rate`: A multiplicative factor that sets how much the model weights change in response to the gradient of the error (loss) for each batch. The lower the learning rate, the smaller the change to the weights, and usually the longer it will take to train the model. However, if the learning rate is too high, the weights can change too quickly: then, the loss can fluctuate significantly or fail to decrease.\n",
+ "* `momentum`: A multiplicative factor on the aggregate of previous gradients. This aggregate term is combined with the weight gradient term to define the total change in the value of the weights. The smaller the momentum, the lesser the influence of the aggregated gradients, and usually the longer it will take to train the model.\n",
+ "* `optimizer`: The method/algorithm used to update the network weights. Stochastic Gradient Descent (SGD) is one of the most commonly used methods.\n",
+ "* `epoch`: One loop of training the model. Each loop includes the entire data set (all the batches) once, with one round of weight updates per batch.\n",
+ "* `n_epochs`: The number of epochs for which to train the network."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 3.2.5. Assign hyperparameters for the training schedule"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:58.485804Z",
+ "iopub.status.busy": "2025-02-09T23:00:58.485463Z",
+ "iopub.status.idle": "2025-02-09T23:00:58.610914Z",
+ "shell.execute_reply": "2025-02-09T23:00:58.609882Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:58.485785Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "learning_rate = 0.01\n",
+ "momentum = 0.9\n",
+ "optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)\n",
+ "n_epochs = 50"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3.3. Train the model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
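+ "Set the model into a training context. This ensures that layers like \"batchnorm\" and \"dropout\" behave as intended for training: dropout randomly sets elements to zero, and batch normalization uses the statistics of each batch. In contrast, when the model is set to an evaluation context (later in this tutorial), dropout is disabled and batch normalization uses its stored running statistics."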
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:58.612344Z",
+ "iopub.status.busy": "2025-02-09T23:00:58.611972Z",
+ "iopub.status.idle": "2025-02-09T23:00:58.756735Z",
+ "shell.execute_reply": "2025-02-09T23:00:58.755676Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:58.612300Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "model.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Using a loop over epochs, optimize the network weight parameters.\n",
+ "\n",
+ "Use lists to track the loss values on the training data, the loss values on the testing data, and the accuracy values on the testing data. Define the \"history\" dictionary to hold those lists. We will visualize these later to study the fitting efficacy and generalization capacity of the network. Note that the model remains in the training context throughout this loop, so dropout is still active when the per-epoch testing metrics are computed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:00:58.758266Z",
+ "iopub.status.busy": "2025-02-09T23:00:58.757903Z",
+ "iopub.status.idle": "2025-02-09T23:13:23.377226Z",
+ "shell.execute_reply": "2025-02-09T23:13:23.376192Z",
+ "shell.execute_reply.started": "2025-02-09T23:00:58.758217Z"
+ },
+ "id": "zB1OY3o8VWF_",
+ "outputId": "7bdd7e85-cef2-47af-f16b-5f8bbd66f67f"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch ( 0): accuracy (9.78 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 1): accuracy (10.26 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 2): accuracy (10.44 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 3): accuracy (11.30 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 4): accuracy (12.58 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 5): accuracy (14.20 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 6): accuracy (16.96 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 7): accuracy (18.42 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 8): accuracy (19.68 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 9): accuracy (21.68 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 10): accuracy (23.16 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 11): accuracy (24.86 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 12): accuracy (24.88 %), train loss (0.0005), valid loss (0.0005) \n",
+ "Epoch ( 13): accuracy (26.24 %), train loss (0.0005), valid loss (0.0004) \n",
+ "Epoch ( 14): accuracy (28.46 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 15): accuracy (28.84 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 16): accuracy (29.86 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 17): accuracy (32.38 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 18): accuracy (33.24 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 19): accuracy (36.54 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 20): accuracy (38.86 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 21): accuracy (39.54 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 22): accuracy (42.22 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 23): accuracy (44.86 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 24): accuracy (48.12 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 25): accuracy (49.74 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 26): accuracy (53.24 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 27): accuracy (53.98 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 28): accuracy (56.04 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 29): accuracy (56.62 %), train loss (0.0004), valid loss (0.0004) \n",
+ "Epoch ( 30): accuracy (59.60 %), train loss (0.0004), valid loss (0.0003) \n",
+ "Epoch ( 31): accuracy (60.46 %), train loss (0.0003), valid loss (0.0003) \n",
+ "Epoch ( 32): accuracy (59.72 %), train loss (0.0003), valid loss (0.0003) \n",
+ "Epoch ( 33): accuracy (61.94 %), train loss (0.0003), valid loss (0.0003) \n",
+ "Epoch ( 34): accuracy (63.48 %), train loss (0.0003), valid loss (0.0003) \n",
+ "Epoch ( 35): accuracy (64.74 %), train loss (0.0003), valid loss (0.0002) \n",
+ "Epoch ( 36): accuracy (66.14 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 37): accuracy (67.70 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 38): accuracy (68.60 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 39): accuracy (69.16 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 40): accuracy (71.06 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 41): accuracy (72.12 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 42): accuracy (72.66 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 43): accuracy (74.16 %), train loss (0.0002), valid loss (0.0002) \n",
+ "Epoch ( 44): accuracy (75.68 %), train loss (0.0002), valid loss (0.0001) \n",
+ "Epoch ( 45): accuracy (77.04 %), train loss (0.0002), valid loss (0.0001) \n",
+ "Epoch ( 46): accuracy (77.06 %), train loss (0.0001), valid loss (0.0001) \n",
+ "Epoch ( 47): accuracy (78.44 %), train loss (0.0001), valid loss (0.0001) \n",
+ "Epoch ( 48): accuracy (79.26 %), train loss (0.0001), valid loss (0.0001) \n",
+ "Epoch ( 49): accuracy (79.56 %), train loss (0.0001), valid loss (0.0001) \n",
+ "Total training time: 744.4417\n"
+ ]
+ }
+ ],
+ "source": [
+ "time_start = time.time()\n",
+ "\n",
+ "loss_train_list = []\n",
+ "loss_test_list = []\n",
+ "accuracy_test_list = []\n",
+ "\n",
+ "for epoch in np.arange(n_epochs):\n",
+ "\n",
+ " loss_train = 0\n",
+ " for inputs, labels in trainloader:\n",
+ " inputs = inputs.to(device)\n",
+ " labels = labels.to(device)\n",
+ " y_pred = model(inputs)\n",
+ " loss = loss_fn(y_pred, labels)\n",
+ " optimizer.zero_grad()\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " loss_train += loss.item()\n",
+ "\n",
+ " loss_train /= batch_size\n",
+ "\n",
+ " accuracy_test = 0\n",
+ " loss_test = 0\n",
+ " count_test = 0\n",
+ " for inputs, labels in testloader:\n",
+ " inputs = inputs.to(device)\n",
+ " labels = labels.to(device)\n",
+ " y_pred = model(inputs)\n",
+ " loss = loss_fn(y_pred, labels)\n",
+ " accuracy_test += (torch.argmax(y_pred, 1) == labels).float().sum()\n",
+ " loss_test += loss.item()\n",
+ " count_test += len(labels)\n",
+ "\n",
+ " accuracy_test /= count_test\n",
+ " loss_test /= batch_size\n",
+ "\n",
+ " loss_train_list.append(loss_train)\n",
+ " loss_test_list.append(loss_test)\n",
+ " accuracy_test_list.append(accuracy_test)\n",
+ "\n",
+ " output = f\"Epoch ({epoch:3d}): accuracy ({accuracy_test*100:.2f} %),\\\n",
+ " train loss ({loss_train:.4f}), valid loss ({loss_test:.4f}) \"\n",
+ " print(output)\n",
+ "\n",
+ "time_end = time.time()\n",
+ "\n",
+ "time_difference = time_end - time_start\n",
+ "\n",
+ "history = {\"loss\": loss_train_list,\n",
+ " \"val_loss\": loss_test_list,\n",
+ " \"accuracy_test\": accuracy_test_list}\n",
+ "\n",
+ "print(f\"Total training time: {time_difference:2.4f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
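+ "Save the model to file: \"pt\" is the common suffix used for pytorch model files. Because we save the model's `state_dict`, the file contains the weights of the model but not the architecture; the `ConvNet` class definition is still needed to load the weights later."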
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:23.378889Z",
+ "iopub.status.busy": "2025-02-09T23:13:23.378481Z",
+ "iopub.status.idle": "2025-02-09T23:13:23.551547Z",
+ "shell.execute_reply": "2025-02-09T23:13:23.550467Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:23.378864Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/home/bnord/dp02_16a_temp/Data/Models/Model_Run000_20250209_231323.pt\n"
+ ]
+ }
+ ],
+ "source": [
+ "file_prefix = path_dict['file_model_prefix'] + \"_\" + path_dict['run_label']\n",
+ "file_name_final = createFileName(file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_model'],\n",
+ " file_suffix=path_dict['file_model_suffix'],\n",
+ " useuid=True,\n",
+ " verbose=True)\n",
+ "\n",
+ "torch.save(model.state_dict(), file_name_final)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Load the model weights from the \"pt\" file into the existing model object."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:23.553120Z",
+ "iopub.status.busy": "2025-02-09T23:13:23.552767Z",
+ "iopub.status.idle": "2025-02-09T23:13:23.711074Z",
+ "shell.execute_reply": "2025-02-09T23:13:23.709943Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:23.553086Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "model.load_state_dict(torch.load(file_name_final, weights_only=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-01-14T23:05:09.388295Z",
+ "iopub.status.busy": "2025-01-14T23:05:09.387717Z",
+ "iopub.status.idle": "2025-01-14T23:05:09.481703Z",
+ "shell.execute_reply": "2025-01-14T23:05:09.481106Z",
+ "shell.execute_reply.started": "2025-01-14T23:05:09.388255Z"
+ }
+ },
+ "source": [
+ "Set the model to evaluation mode so that \"dropout\" layers are disabled and any \"batchnorm\" layers use their stored running statistics. This is necessary for consistent inference."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:23.712652Z",
+ "iopub.status.busy": "2025-02-09T23:13:23.712268Z",
+ "iopub.status.idle": "2025-02-09T23:13:24.006950Z",
+ "shell.execute_reply": "2025-02-09T23:13:24.005916Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:23.712613Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "model.eval()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Diagnosing the Results of Model Training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.1. Key terms for diagnostics and model evaluation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use the following diagnostics to assess the status of the network optimization and efficacy. The [scikit-learn page on metrics and scoring](https://scikit-learn.org/stable/modules/model_evaluation.html) provides a good in-depth reference for these terms.\n",
+ "\n",
+ "Model Predictions:\n",
+ " * `Classification threshold`: A user-chosen value in the range $[0,1]$ that sets the minimum probability score for a positive classification.\n",
+ " * `Probability score`: The output from the classifier neural network. One typically applies the `softmax` function to the output of the last layer of the NN to provide values in the range $[0,1]$.\n",
+ " * `Classification score`: The predicted class label, which is the class that received the highest probability score.\n",
+ "\n",
+ "Metrics:\n",
+ " * `Loss`: A function of the difference between the true labels and predicted labels.\n",
+ " * `Accuracy`: The fraction of all predictions that are correct. It is a rough indicator of model training progress/convergence for balanced datasets. For model performance, use it only in combination with other metrics, and avoid it when the datasets are unbalanced.\n",
+ " * `True Positive Rate (TPR; \"Recall\")`: The fraction of true positives among all actual positives. Use when false negatives are more expensive than false positives.\n",
+ " * `False Positive Rate (FPR)`: The fraction of false positives among all actual negatives. Use when false positives are more expensive than false negatives.\n",
+ " * `Precision`: The fraction of true positives among all positive predictions. Use when positive predictions need to be accurate.\n",
+ "\n",
+ "The `Generalization Error` (GE) is the difference in loss when the model is applied to training data versus when applied to validation and test data.\n",
+ "\n",
+ "The `Confusion Matrix` is a tabular representation of the classification accuracy: one axis indexes the true class and the other indexes the predicted class. Values along the diagonal indicate correct predictions. Each off-diagonal value counts samples of one class that were misclassified as another class; such a sample is a false negative for its true class and a false positive for its predicted class. The optimal scenario is one in which the off-diagonal values are all zero.\n",
+ "\n",
+ "The `Receiver Operating Characteristic (ROC) Curve` presents a comparison between the true positive rate (y-axis) and the false positive rate (x-axis): it shows the true positive rate achieved at each false positive rate. Each point on the curve corresponds to a distinct choice of the classification threshold for the probability score in $[0,1]$: the choice of threshold determines which objects are classified as positive. The optimal scenario is where the ROC curve is constant at a true positive rate $=1$. If the curve is along the diagonal (lower left to upper right), it indicates that the model performance is equivalent to random guessing.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.2. Predict classifications with the trained model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Predict classification probabilities on the training, validation, and test sets. Produce both the probabilities of each digit and the top choice for each prediction."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:24.013950Z",
+ "iopub.status.busy": "2025-02-09T23:13:24.012874Z",
+ "iopub.status.idle": "2025-02-09T23:13:36.594110Z",
+ "shell.execute_reply": "2025-02-09T23:13:36.593536Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:24.013906Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "test data set ... Accuracy: 86.4%, Avg loss: 0.467137\n",
+ "validation data set ... Accuracy: 86.2%, Avg loss: 0.487235\n",
+ "training data set ... Accuracy: 86.0%, Avg loss: 0.487080\n"
+ ]
+ }
+ ],
+ "source": [
+ "y_prob_tes, y_choice_tes, y_true_tes, x_tes = predict(testloader, model,\n",
+ " \"test\")\n",
+ "y_prob_val, y_choice_val, y_true_val, x_val = predict(validloader, model,\n",
+ " \"validation\")\n",
+ "y_prob_tra, y_choice_tra, y_true_tra, x_tra = predict(trainloader, model,\n",
+ " \"training\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Print the shapes and verify that the shape of `y_prob_tes` matches the length of the input data `x_tes` and the number of classes (as in Section 2.4)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:36.595037Z",
+ "iopub.status.busy": "2025-02-09T23:13:36.594822Z",
+ "iopub.status.idle": "2025-02-09T23:13:36.743092Z",
+ "shell.execute_reply": "2025-02-09T23:13:36.742096Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:36.595018Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The input data has the shape (5000, 28, 28): there are 5000 images, and each image has 28 pixels on a side.\n",
+ "The predicted probability score array has the shape (5000, 10): there are 5000 predictions, with 10 probability scores predicted for each input image.\n",
+ "The predicted classes array has the shape (5000,): there is one top choice (highest probability score) for each prediction.\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"The input data has the shape {np.shape(x_tes)}: there are\\\n",
+ " {subset_size} images, and each image has 28 pixels on a side.\")\n",
+ "print(f\"The predicted probability score array has the shape\\\n",
+ " {np.shape(y_prob_tes)}: there are {subset_size} predictions,\\\n",
+ " with 10 probability scores predicted for each input image.\")\n",
+ "print(f\"The predicted classes array has the shape {np.shape(y_choice_tes)}:\\\n",
+ " there is one top choice (highest probability score) for each prediction.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use the `plotPredictionHistogram` function to plot histograms of prediction distributions by class."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:36.744923Z",
+ "iopub.status.busy": "2025-02-09T23:13:36.744545Z",
+ "iopub.status.idle": "2025-02-09T23:13:37.112883Z",
+ "shell.execute_reply": "2025-02-09T23:13:37.111880Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:36.744887Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAA9oAAAHACAYAAABdxRCTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABIR0lEQVR4nO3de1xVdb7/8feWy+YioKCwJVEhMW/kBY1ESxxQxvI2Vpq3MNFsUAsvo5lmaAlpIzrlxGjHEdMxbSo7aTfRAnU4lTFaakZWeGmEsA4HvCAg7N8fPty/trdAF27E1/PxWI+H+7s+a63P2u055/Hmuy4mq9VqFQAAAAAAMEQDRzcAAAAAAEB9QtAGAAAAAMBABG0AAAAAAAxE0AYAAAAAwEAEbQAAAAAADETQBgAAAADAQARtAAAAAAAMRNAGAAAAAMBAzo5u4FpUVVXp+PHj8vLykslkcnQ7AAAAAIB6zmq16uTJkwoMDFSDBlefs74pg/bx48cVFBTk6DYAAAAAALeYY8eOqXnz5letuSmDtpeXl6TzJ+jt7e3gbgAAAAAA9V1JSYmCgoJsefRqbsqgfeFycW9vb4I2AAAAAOCGqc7tyzwMDQAAAAAAAxG0AQAAAAAwEEEbAAAAAAAD3ZT3aAMAAADA5VitVp07d06VlZWObgU3IRcXFzk5OV33fgjaAAAAAOqF8vJy5efn68yZM45uBTcpk8mk5s2bq2HDhte1H4I2AAAAgJteVVWV8vLy5OTkpMDAQLm6ulbr6dDABVarVSdOnNCPP/6o0NDQ65rZJmgDAAAAuOmVl5erqqpKQUFB8vDwcHQ7uEk1bdpUhw8fVkVFxXUFbR6GBgAAAKDeaNCAiINrZ9RVEPwKAQAAAAAwEEEbAAAAAOqZqKgoJSYmVrv+8OHDMplM2rt3b631dCshaAMAAACAg5hMpqsuY8eOvab9vv3223ruueeqXR8UFKT8/Hx17Njxmo5XE2+99ZYiIiLk4+MjLy8vdejQQdOnT6/RPkwmk955553aadAAPAwNAAAAQP32VdKNPd6d1T9efn6+7d8bN27UvHnzlJubaxtzd3e3q6+oqJCLi8tv7tfX17faPUiSk5OTLBZLjba5Ftu2bdPDDz+s5ORkDRo0SCaTSV9//bW2b99e68e+kZjRBgAAAAAHsVgstsXHx0cmk8n2+ezZs2rUqJHeeOMNRUVFyc3NTevWrdMvv/yiESNGqHnz5vLw8FBYWJhef/11u/1efOl4q1atlJycrHHjxsnLy0stWrTQypUrbesvvnQ8MzNTJpNJ27dvV7du3eTh4aHIyEi7PwJI0vPPPy9/f395eXlp/Pjxeuqpp9S5c+crnu+WLVvUq1cv/elPf9Idd9yhNm3aaMiQIXr55Zft6jZv3qzw8HC5ubkpJCRE8+fP17lz52znIkl/+MMfZDKZbJ/rEoI2AAAAANRhs2bN0hNPPKGDBw8qNjZWZ8+eVXh4uLZs2aL9+/frscce05gxY/TZZ59ddT9LlixRt27dtGfPHiUkJOiPf/yjvvnmm6tuM2fOHC1ZskRffPGFnJ2dNW7cONu6f/zjH1q4cKEWLVqknJwctWjRQmlpaVfdn8Vi0YEDB7R///4r1nz00UcaPXq0nnjiCX399ddasWKF0tPTtXDhQknS7t27JUmrV69Wfn6+7XNdQtAGAAAAgDosMTFRQ4cOVXBwsAIDA3XbbbdpxowZ6ty5s0JCQjRlyhTFxsbqn//851X3c9999ykhIUGtW7fWrFmz1KRJE2VmZl51m4ULF6p3795q3769nnrqKWVnZ+vs2bOSpJdfflnx8fF69NFH1aZNG82bN09hYWFX3d+UKVPUvXt3hYWFqVWrVnr44Yf197//XWVlZXbHfOqppxQXF6eQkBD17dtXzz33nFasWCHp/LuuJalRo0ayWCy2z3UJQRsAAAAA6rBu3brZfa6srNTChQt15513ys/PTw0bNtTW
rVt19OjRq+7nzjvvtP37wiXqhYWF1d6mWbNmkmTbJjc3V3fddZdd/cWfL+bp6an33ntP3333nebOnauGDRtq+vTpuuuuu3TmzBlJUk5OjhYsWKCGDRvalgkTJig/P99WU9fxMDQAAAAAqMM8PT3tPi9ZskRLly7VsmXLFBYWJk9PTyUmJqq8vPyq+7n4IWomk0lVVVXV3sZkMkmS3TYXxi6wWq1X3d8Ft99+u26//XaNHz9ec+bMUZs2bbRx40Y9+uijqqqq0vz58zV06NBLtnNzc6vW/h2NoA0AAAAAN5GdO3dq8ODBGj16tKTzwffQoUNq167dDe3jjjvu0Oeff64xY8bYxr744osa76dVq1by8PDQ6dOnJUldu3ZVbm6uWrdufcVtXFxcVFlZWfOmb5AaXTreqlWry77bbdKkSZLO//UiKSlJgYGBcnd3V1RUlA4cOGC3j7KyMk2ZMkVNmjSRp6enBg0apB9//NG4MwIAAACAeqx169bKyMhQdna2Dh48qIkTJ6qgoOCG9zFlyhStWrVKa9as0aFDh/T888/rq6++umSW+9eSkpI0c+ZMZWZmKi8vT3v27NG4ceNUUVGhvn37SpLmzZun1157TUlJSTpw4IAOHjyojRs3au7cubb9tGrVStu3b1dBQYGKiopq/VxrqkZBe/fu3crPz7ctGRkZkqSHHnpIkrR48WKlpqZq+fLl2r17tywWi/r27auTJ0/a9pGYmKhNmzZpw4YN2rVrl06dOqUBAwbU6b9GAAAAAEBd8cwzz6hr166KjY1VVFSULBaLhgwZcsP7GDVqlGbPnq0ZM2aoa9euysvL09ixY696eXfv3r31ww8/6JFHHlHbtm3Vv39/FRQUaOvWrbrjjjskSbGxsdqyZYsyMjLUvXt33X333UpNTVXLli1t+1myZIkyMjIUFBSkLl261Pq51pTJWt2L6C8jMTFRW7Zs0aFDhyRJgYGBSkxM1KxZsySdn70OCAjQokWLNHHiRBUXF6tp06Zau3athg8fLkk6fvy4goKC9P777ys2NrZaxy0pKZGPj4+Ki4vl7e19re0DAAAAqCfOnj2rvLw8BQcH3zT38dZHffv2lcVi0dq1ax3dyjW52u+oJjn0mp86Xl5ernXr1mncuHEymUzKy8tTQUGB+vXrZ6sxm83q3bu3srOzJZ1/elxFRYVdTWBgoDp27GiruZyysjKVlJTYLQAAAAAAxzlz5oxSU1N14MABffPNN3r22We1bds2xcXFObo1h7vmoP3OO+/o//7v/zR27FhJst0TEBAQYFcXEBBgW1dQUCBXV1c1btz4ijWXk5KSIh8fH9sSFBR0rW0DAAAAAAxgMpn0/vvv65577lF4eLg2b96st956SzExMY5uzeGu+anjq1atUv/+/RUYGGg3frnHu1/tZvjq1MyePVvTpk2zfS4pKSFsAwAAAIADubu7a9u2bY5uo066phntI0eOaNu2bRo/frxtzGKxSNIlM9OFhYW2WW6LxaLy8vJLngr365rLMZvN8vb2tlsAAAAAAKiLrilor169Wv7+/rr//vttY8HBwbJYLLYnkUvn7+POyspSZGSkJCk8PFwuLi52Nfn5+dq/f7+tBgAAAACAm1mNLx2vqqrS6tWrFRcXJ2fn/7+5yWRSYmKikpOTFRoaqtDQUCUnJ8vDw0MjR46UJPn4+Cg+Pl7Tp0+Xn5+ffH19NWPGDIWFhXEdPwAAAACgXqhx0N62bZuOHj2qcePGXbJu5syZKi0tVUJCgoqKihQREaGtW7fKy8vLVrN06VI5Oztr2LBhKi0tVXR0tNLT0+Xk5HR9ZwIAAAAAQB1wXe/RdhTeow0AAADg13iPNozg8PdoAwAAAACASxG0AQAAAOAmFxUVpcTERNvnVq1aadmyZVfdxmQy6Z133rnuYxu1n/rkmt+jDQAAAAA3g6TMpBt7vKjqH2/gwIEqLS297Puo/+d//keRkZHKyclR165da9TD7t275enp
WaNtfktSUpLeeecd7d271248Pz9fjRs3NvRYF6usrNTixYu1Zs0aHTlyRO7u7mrTpo0mTpyoRx99tFr7yMzMVJ8+fVRUVKRGjRrVar8Ebdx0bvT/obwV1OT/GQAAAMA48fHxGjp0qI4cOaKWLVvarfv73/+uzp071zhkS1LTpk2NavE3WSyWWj9GUlKSVq5cqeXLl6tbt24qKSnRF198oaKiolo/9rXg0nEAAAAAcJABAwbI399f6enpduNnzpzRxo0bFR8fr19++UUjRoxQ8+bN5eHhobCwML3++utX3e/Fl44fOnRI9957r9zc3NS+fXtlZGRcss2sWbPUpk0beXh4KCQkRM8884wqKiokSenp6Zo/f76+/PJLmUwmmUwmW88XXzq+b98+/e53v5O7u7v8/Pz02GOP6dSpU7b1Y8eO1ZAhQ/TnP/9ZzZo1k5+fnyZNmmQ71uVs3rxZCQkJeuihhxQcHKxOnTopPj5e06ZNs9VYrVYtXrxYISEhcnd3V6dOnfTmm29Kkg4fPqw+ffpIkho3biyTyaSxY8de9Tu8HsxoAwAAAICDODs765FHHlF6errmzZsnk8kkSfrnP/+p8vJyjRo1SmfOnFF4eLhmzZolb29vvffeexozZoxCQkIUERHxm8eoqqrS0KFD1aRJE3366acqKSmxu5/7Ai8vL6WnpyswMFD79u3ThAkT5OXlpZkzZ2r48OHav3+/PvzwQ9tl7j4+Ppfs48yZM/r973+vu+++W7t371ZhYaHGjx+vyZMn2/0x4ZNPPlGzZs30ySef6LvvvtPw4cPVuXNnTZgw4bLnYLFY9PHHHyshIeGKs/Vz587V22+/rbS0NIWGhmrHjh0aPXq0mjZtql69eumtt97SAw88oNzcXHl7e8vd3f03v7trRdAGAAAAAAcaN26cXnzxRds9xNL5y8aHDh2qxo0bq3HjxpoxY4atfsqUKfrwww/1z3/+s1pBe9u2bTp48KAOHz6s5s2bS5KSk5PVv39/u7q5c+fa/t2qVStNnz5dGzdu1MyZM+Xu7q6GDRvK2dn5qpeK/+Mf/1Bpaalee+012z3iy5cv18CBA7Vo0SIFBARIOj+rvHz5cjk5Oalt27a6//77tX379isG7dTUVD344IOyWCzq0KGDIiMjNXjwYNs5nD59Wqmpqfr444/Vo0cPSVJISIh27dqlFStWqHfv3vL19ZUk+fv7c482AAAAANRnbdu2VWRkpP7+97+rT58++v7777Vz505t3bpV0vkHgb3wwgvauHGj/vOf/6isrExlZWXVftjZwYMH1aJFC1vIlmQLo7/25ptvatmyZfruu+906tQpnTt37jffF325Y3Xq1Mmut549e6qqqkq5ubm2oN2hQwc5OTnZapo1a6Z9+/Zdcb/t27fX/v37lZOTo127dmnHjh0aOHCgxo4dq//6r//S119/rbNnz6pv375225WXl6tLly41OgcjcI82AAAAADhYfHy83nrrLZWUlGj16tVq2bKloqOjJUlLlizR0qVLNXPmTH388cfau3evYmNjVV5eXq19W63WS8YuXKJ+waeffqqHH35Y/fv315YtW7Rnzx7NmTOn2sf49bEu3vfljuni4nLJuqqqqqvuu0GDBurevbumTp2qTZs2KT09XatWrVJeXp5t2/fee0979+61LV9//bXtPu0biRltAAAAAHCwYcOG6cknn9T69eu1Zs0aTZgwwRZMd+7cqcGDB2v06NGSzt9zfejQIbVr165a+27fvr2OHj2q48ePKzAwUNL5V4f92r/+9S+1bNlSc+bMsY0dOXLErsbV1VWVlZW/eaw1a9bo9OnTtlntf/3rX2rQoIHatGlTrX6rq3379pLOXzbevn17mc1mHT16VL17975svaurqyT95jkYgRltAAAAAHCwhg0bavjw4Xr66ad1/Phxuydit27dWhkZGcrOztbBgwc1ceJEFRQUVHvfMTExuuOOO/TII4/oyy+/1M6dO+0C
9YVjHD16VBs2bND333+vl156SZs2bbKradWqlfLy8rR37179/PPPKisru+RYo0aNkpubm+Li4rR//3598sknmjJlisaMGWO7bPxaPPjgg1q6dKk+++wzHTlyRJmZmZo0aZLatGmjtm3bysvLSzNmzNDUqVO1Zs0aff/999qzZ4/++te/as2aNZKkli1bymQyacuWLTpx4oTdk9CNRtAGAAAAgDogPj5eRUVFiomJUYsWLWzjzzzzjLp27arY2FhFRUXJYrFoyJAh1d5vgwYNtGnTJpWVlemuu+7S+PHjtXDhQruawYMHa+rUqZo8ebI6d+6s7OxsPfPMM3Y1DzzwgH7/+9+rT58+atq06WVfMebh4aGPPvpI//u//6vu3bvrwQcfVHR0tJYvX16zL+MisbGx2rx5swYOHKg2bdooLi5Obdu21datW+XsfP5C7eeee07z5s1TSkqK2rVrZ9smODhYknTbbbdp/vz5euqppxQQEKDJkydfV09XY7Je7oL9Oq6kpEQ+Pj4qLi6u8c35uPklZSY5uoV6JykqydEtAAAAXJezZ88qLy9PwcHBcnNzc3Q7uEld7XdUkxzKjDYAAAAAAAYiaAMAAAAAYCCCNgAAAAAABiJoAwAAAABgIII2AAAAAAAGImgDAAAAAGAggjYAAAAAAAYiaAMAAAAAYCCCNgAAAAAABiJoAwAAAMAtIj09XY0aNXJ0G/Wes6MbAAAAAIDalJmUeUOPF5UUVe1ak8l01fVxcXFKT0+/pj5atWqlxMREJSYm2saGDx+u++6775r2VxOVlZVavHix1qxZoyNHjsjd3V1t2rTRxIkT9eijj1ZrH5mZmerTp4+Kiopuuj8OELQBAAAAwEHy8/Nt/964caPmzZun3Nxc25i7u7uhx3N3dzd8n5eTlJSklStXavny5erWrZtKSkr0xRdfqKioqNaPXRdw6TgAAAAAOIjFYrEtPj4+MplMdmM7duxQeHi43NzcFBISovnz5+vcuXO27ZOSktSiRQuZzWYFBgbqiSeekCRFRUXpyJEjmjp1qkwmk23m/OJLx5OSktS5c2etXbtWrVq1ko+Pjx5++GGdPHnSVnPy5EmNGjVKnp6eatasmZYuXaqoqCi7mfKLbd68WQkJCXrooYcUHBysTp06KT4+XtOmTbPVWK1WLV68WCEhIXJ3d1enTp305ptvSpIOHz6sPn36SJIaN24sk8mksWPHXu/XfcMQtAEAAACgDvroo480evRoPfHEE/r666+1YsUKpaena+HChZKkN998U0uXLtWKFSt06NAhvfPOOwoLC5Mkvf3222revLkWLFig/Px8u5nzi33//fd65513tGXLFm3ZskVZWVl64YUXbOunTZumf/3rX3r33XeVkZGhnTt36t///vdVe7dYLPr444914sSJK9bMnTtXq1evVlpamg4cOKCpU6dq9OjRysrKUlBQkN566y1JUm5urvLz8/WXv/yl2t+do3HpOAAAAADUQQsXLtRTTz2luLg4SVJISIiee+45zZw5U88++6yOHj0qi8WimJgYubi4qEWLFrrrrrskSb6+vnJycpKXl5csFstVj1NVVaX09HR5eXlJksaMGaPt27dr4cKFOnnypNasWaP169crOjpakrR69WoFBgZedZ+pqal68MEHZbFY1KFDB0VGRmrw4MHq37+/JOn06dNKTU3Vxx9/rB49etjOb9euXVqxYoV69+4tX19fSZK/v/9Nd482M9oAAAAAUAfl5ORowYIFatiwoW2ZMGGC8vPzdebMGT300EMqLS1VSEiIJkyYoE2bNtldVl5drVq1soVsSWrWrJkKCwslST/88IMqKipsAV6SfHx8dMcdd1x1n+3bt9f+/fv16aef6tFHH9VPP/2kgQMHavz48ZKkr7/+WmfPnlXfvn3tzu+1117T999/X+NzqGuY0QYAAACAOqiqqkrz58/X0KFDL1nn5uamoKAg5ebmKiMjQ9u2bVNCQoJefPFF
ZWVlycXFpdrHubjWZDKpqqpK0vn7qC+M/dqF8atp0KCBunfvru7du2vq1Klat26dxowZozlz5tj2/9577+m2226z285sNle797qKoA0AAAAAdVDXrl2Vm5ur1q1bX7HG3d1dgwYN0qBBgzRp0iS1bdtW+/btU9euXeXq6qrKysrr6uH222+Xi4uLPv/8cwUFBUmSSkpKdOjQIfXu3btG+2rfvr2k85eNt2/fXmazWUePHr3iflxdXSXpus/BEQjaAAAAAFAHzZs3TwMGDFBQUJAeeughNWjQQF999ZX27dun559/Xunp6aqsrFRERIQ8PDy0du1aubu7q2XLlpLOXxK+Y8cOPfzwwzKbzWrSpEmNe/Dy8lJcXJz+9Kc/ydfXV/7+/nr22WfVoEGDq74D/MEHH1TPnj0VGRkpi8WivLw8zZ49W23atFHbtm3l7OysGTNmaOrUqaqqqlKvXr1UUlKi7OxsNWzYUHFxcWrZsqVMJpO2bNmi++67T+7u7mrYsOE1f583EvdoAwAAAEAdFBsbqy1btigjI0Pdu3fX3XffrdTUVFuQbtSokV599VX17NlTd955p7Zv367NmzfLz89PkrRgwQIdPnxYt99+u5o2bXrNfaSmpqpHjx4aMGCAYmJi1LNnT7Vr105ubm5X7X3z5s0aOHCg2rRpo7i4OLVt21Zbt26Vs/P5+d7nnntO8+bNU0pKitq1a2fbJjg4WJJ02223af78+XrqqacUEBCgyZMnX/M53Ggma3Uurq9jSkpK5OPjo+LiYnl7ezu6HdxgSZlJjm6h3kmKSnJ0CwAAANfl7NmzysvLU3Bw8FUDIK7f6dOnddttt2nJkiWKj493dDuGutrvqCY5lEvHAQAAAABXtGfPHn3zzTe66667VFxcrAULFkiSBg8e7ODO6i6CNgAAAADgqv785z8rNzdXrq6uCg8P186dO6/pnu9bBUEbAAAAAHBFXbp0UU5OjqPbuKnwMDQAAAAAAAxE0AYAAAAAwEAEbQAAAAD1xk34UiXUIUb9fgjaAAAAAG56Li4ukqQzZ844uBPczMrLyyVJTk5O17WfGj8M7T//+Y9mzZqlDz74QKWlpWrTpo1WrVql8PBwSef/AjB//nytXLlSRUVFioiI0F//+ld16NDBto+ysjLNmDFDr7/+ukpLSxUdHa1XXnlFzZs3v66TAQAAAHBrcnJyUqNGjVRYWChJ8vDwkMlkcnBXuJlUVVXpxIkT8vDwkLPz9T03vEZbFxUVqWfPnurTp48++OAD+fv76/vvv1ejRo1sNYsXL1ZqaqrS09PVpk0bPf/88+rbt69yc3Pl5eUlSUpMTNTmzZu1YcMG+fn5afr06RowYIBycnKu+y8HAAAAAG5NFotFkmxhG6ipBg0aqEWLFtf9R5oaBe1FixYpKChIq1evto21atXK9m+r1aply5Zpzpw5Gjp0qCRpzZo1CggI0Pr16zVx4kQVFxdr1apVWrt2rWJiYiRJ69atU1BQkLZt26bY2NjrOiEAAAAAtyaTyaRmzZrJ399fFRUVjm4HNyFXV1c1aHD9d1jXKGi/++67io2N1UMPPaSsrCzddtttSkhI0IQJEyRJeXl5KigoUL9+/WzbmM1m9e7dW9nZ2Zo4caJycnJUUVFhVxMYGKiOHTsqOzv7skG7rKxMZWVlts8lJSU1PlEAAAAAtwYnJyeulIVD1Sho//DDD0pLS9O0adP09NNP6/PPP9cTTzwhs9msRx55RAUFBZKkgIAAu+0CAgJ05MgRSVJBQYFcXV3VuHHjS2oubH+xlJQUzZ8/vyatoj77KdPRHQAAAADAFdVoTryqqkpdu3ZVcnKyunTpookTJ2rChAlKS0uzq7v4enar1fqb17hfrWb27NkqLi62LceOHatJ2wAAAAAA3DA1CtrNmjVT+/bt7cbatWuno0ePSvr/Dx+4eGa6sLDQNsttsVhUXl6uoqKiK9ZczGw2y9vb224BAAAAAKAuqlHQ7tmz
p3Jzc+3Gvv32W7Vs2VKSFBwcLIvFooyMDNv68vJyZWVlKTIyUpIUHh4uFxcXu5r8/Hzt37/fVgMAAAAAwM2qRvdoT506VZGRkUpOTtawYcP0+eefa+XKlVq5cqWk85eMJyYmKjk5WaGhoQoNDVVycrI8PDw0cuRISZKPj4/i4+M1ffp0+fn5ydfXVzNmzFBYWJjtKeQAAAAAANysahS0u3fvrk2bNmn27NlasGCBgoODtWzZMo0aNcpWM3PmTJWWliohIUFFRUWKiIjQ1q1bbe/QlqSlS5fK2dlZw4YNU2lpqaKjo5Wens6TAQEAAAAANz2T1Wq1OrqJmiopKZGPj4+Ki4u5X/sWlLQxytEt1DtJwzMd3QIAAABQp9Ukh17/m7gBAAAAAIANQRsAAAAAAAMRtAEAAAAAMBBBGwAAAAAAAxG0AQAAAAAwUI1e7wXUCf9s5egO6p/hjm4AAAAAqD+Y0QYAAAAAwEAEbQAAAAAADETQBgAAAADAQARtAAAAAAAMRNAGAAAAAMBABG0AAAAAAAxE0AYAAAAAwEAEbQAAAAAADETQBgAAAADAQARtAAAAAAAMRNAGAAAAAMBABG0AAAAAAAxE0AYAAAAAwEAEbQAAAAAADETQBgAAAADAQARtAAAAAAAMRNAGAAAAAMBABG0AAAAAAAxE0AYAAAAAwEDOjm7glvBVkqM7AAAAAGCgzKRMR7dQr0QlRTm6BUMxow0AAAAAgIEI2gAAAAAAGIigDQAAAACAgbhH+wbITHN0BwAAAACAG4UZbQAAAAAADETQBgAAAADAQFw6DgAG43Ufxqtvr/wAAAD1GzPaAAAAAAAYiKANAAAAAICBuHQcAAAAdUpSZpKjW6h3kqKSHN0CcEshaAMA6jzuezcW97wDAFC7uHQcAAAAAAADMaN9A2SeOOzoFgAAAAAANwhBGwAAAHXLT5mO7gAArkuNLh1PSkqSyWSyWywWi2291WpVUlKSAgMD5e7urqioKB04cMBuH2VlZZoyZYqaNGkiT09PDRo0SD/++KMxZwMAAAAAgIPVeEa7Q4cO2rZtm+2zk5OT7d+LFy9Wamqq0tPT1aZNGz3//PPq27evcnNz5eXlJUlKTEzU5s2btWHDBvn5+Wn69OkaMGCAcnJy7PYFAABws+Ap2QCAX6tx0HZ2drabxb7AarVq2bJlmjNnjoYOHSpJWrNmjQICArR+/XpNnDhRxcXFWrVqldauXauYmBhJ0rp16xQUFKRt27YpNjb2Ok8HABwv83Cmo1uod6JaRTm6BQAAgGqrcdA+dOiQAgMDZTabFRERoeTkZIWEhCgvL08FBQXq16+frdZsNqt3797Kzs7WxIkTlZOTo4qKCruawMBAdezYUdnZ2VcM2mVlZSorK7N9LikpqWnbAK6CmRgAAADAODUK2hEREXrttdfUpk0b/fTTT3r++ecVGRmpAwcOqKCgQJIUEBBgt01AQICOHDkiSSooKJCrq6saN258Sc2F7S8nJSVF8+fPr0mrAAAAAFBruILNWFGKcnQLhqrRw9D69++vBx54QGFhYYqJidF7770n6fwl4heYTCa7baxW6yVjF/utmtmzZ6u4uNi2HDt2rCZtAwAAAABww1zX6708PT0VFhamQ4cOaciQIZLOz1o3a9bMVlNYWGib5bZYLCovL1dRUZHdrHZhYaEiIyOveByz2Syz2Xw9rQK4muWHHd1B/dKwlaM7AAAAte3UYUd3gDqsRjPaFysrK9PBgwfVrFkzBQcHy2KxKCMjw7a+vLxcWVlZthAdHh4uFxcXu5r8/Hzt37//qkEbAAAAAICbRY1mtGfMmKGBAweqRYsWKiws1PPPP6+SkhLFxcXJZDIpMTFRycnJCg0NVWhoqJKTk+Xh4aGRI0dKknx8fBQfH6/p06fLz89Pvr6+mjFjhu1SdAAAAAAAbnY1Cto//vijRowYoZ9//llNmzbV3XffrU8//VQtW7aUJM2cOVOlpaVKSEhQUVGR
IiIitHXrVts7tCVp6dKlcnZ21rBhw1RaWqro6Gilp6fzDm0AAAAAQL1Qo6C9YcOGq643mUxKSkpSUlLSFWvc3Nz08ssv6+WXX67JoQEAt7KfMh3dQT0T5egG6h9+o8b6ZytHd1D/DHd0A8Ct5bru0QYAAAAAAPau66njAIDL4CmkxnNv5egOAAAAqo2gDQAAcL241Bl13VdJju4AuKVw6TgAAAAAAAYiaAMAAAAAYCCCNgAAAAAABiJoAwAAAABgIII2AAAAAAAGImgDAAAAAGAggjYAAAAAAAYiaAMAAAAAYCCCNgAAAAAABiJoAwAAAABgIII2AAAAAAAGImgDAAAAAGAggjYAAAAAAAZydnQDAAAAAGpXZpqjOwBuLcxoAwAAAABgIII2AAAAAAAG4tJxAECdl3nisKNbqFeiHN0AAAD1HDPaAAAAAAAYiKANAAAAAICBuHQcAAAAqOe4BQe4sZjRBgAAAADAQARtAAAAAAAMRNAGAAAAAMBABG0AAAAAAAxE0AYAAAAAwEAEbQAAAAAADETQBgAAAADAQARtAAAAAAAMRNAGAAAAAMBABG0AAAAAAAxE0AYAAAAAwEAEbQAAAAAADETQBgAAAADAQM6ObgAAANxgXyU5ugMAAOo1ZrQBAAAAADAQQRsAAAAAAAMRtAEAAAAAMNB1Be2UlBSZTCYlJibaxqxWq5KSkhQYGCh3d3dFRUXpwIEDdtuVlZVpypQpatKkiTw9PTVo0CD9+OOP19MKAAAAAAB1wjUH7d27d2vlypW688477cYXL16s1NRULV++XLt375bFYlHfvn118uRJW01iYqI2bdqkDRs2aNeuXTp16pQGDBigysrKaz8TAAAAAADqgGsK2qdOndKoUaP06quvqnHjxrZxq9WqZcuWac6cORo6dKg6duyoNWvW6MyZM1q/fr0kqbi4WKtWrdKSJUsUExOjLl26aN26ddq3b5+2bdtmzFkBAAAAAOAg1xS0J02apPvvv18xMTF243l5eSooKFC/fv1sY2azWb1791Z2drYkKScnRxUVFXY1gYGB6tixo63mYmVlZSopKbFbAAAAAACoi2r8Hu0NGzbo3//+t3bv3n3JuoKCAklSQECA3XhAQICOHDliq3F1dbWbCb9Qc2H7i6WkpGj+/Pk1bRUAAAAAgBuuRjPax44d05NPPql169bJzc3tinUmk8nus9VqvWTsYlermT17toqLi23LsWPHatI2AAAAAAA3TI2Cdk5OjgoLCxUeHi5nZ2c5OzsrKytLL730kpydnW0z2RfPTBcWFtrWWSwWlZeXq6io6Io1FzObzfL29rZbAAAAAACoi2oUtKOjo7Vv3z7t3bvXtnTr1k2jRo3S3r17FRISIovFooyMDNs25eXlysrKUmRkpCQpPDxcLi4udjX5+fnav3+/rQYAAAAAgJtVje7R9vLyUseOHe3GPD095efnZxtPTExUcnKyQkNDFRoaquTkZHl4eGjkyJGSJB8fH8XHx2v69Ony8/OTr6+vZsyYobCwsEsergYAAIyXmeboDgAAqN9q/DC03zJz5kyVlpYqISFBRUVFioiI0NatW+Xl5WWrWbp0qZydnTVs2DCVlpYqOjpa6enpcnJyMrodAAAAAABuKJPVarU6uomaKikpkY+Pj4qLi2+K+7WTHhzr6BYAALCJatrK0S3UO5knDju6BQC4qSW9me7oFn5TTXLoNb1HGwAAAAAAXB5BGwAAAAAAAxG0AQAAAAAwEEEbAAAAAAADGf7UcQAAULfx4C4AAGoXM9oAAAAAABiIoA0AAAAAgIEI2gAAAAAAGIigDQAAAACAgQjaAAAAAAAYiKANAAAAAICBCNoAAAAAABiIoA0AAAAAgIEI2gAAAAAAGIigDQAAAACAgQjaAAAAAAAYiKANAAAAAICBCNoAAAAAABiIoA0AAAAAgIEI2gAAAAAAGIigDQAAAACAgQjaAAAAAAAYiKANAAAAAICBCNoAAAAA
ABiIoA0AAAAAgIEI2gAAAAAAGIigDQAAAACAgQjaAAAAAAAYiKANAAAAAICBCNoAAAAAABiIoA0AAAAAgIEI2gAAAAAAGIigDQAAAACAgQjaAAAAAAAYiKANAAAAAICBCNoAAAAAABiIoA0AAAAAgIEI2gAAAAAAGIigDQAAAACAgQjaAAAAAAAYiKANAAAAAICBahS009LSdOedd8rb21ve3t7q0aOHPvjgA9t6q9WqpKQkBQYGyt3dXVFRUTpw4IDdPsrKyjRlyhQ1adJEnp6eGjRokH788UdjzgYAAAAAAAerUdBu3ry5XnjhBX3xxRf64osv9Lvf/U6DBw+2henFixcrNTVVy5cv1+7du2WxWNS3b1+dPHnSto/ExERt2rRJGzZs0K5du3Tq1CkNGDBAlZWVxp4ZAAAAAAAOYLJardbr2YGvr69efPFFjRs3ToGBgUpMTNSsWbMknZ+9DggI0KJFizRx4kQVFxeradOmWrt2rYYPHy5JOn78uIKCgvT+++8rNja2WscsKSmRj4+PiouL5e3tfT3t3xBJD451dAsAAAAAUGclvZnu6BZ+U01y6DXfo11ZWakNGzbo9OnT6tGjh/Ly8lRQUKB+/frZasxms3r37q3s7GxJUk5OjioqKuxqAgMD1bFjR1vN5ZSVlamkpMRuAQAAAACgLqpx0N63b58aNmwos9msxx9/XJs2bVL79u1VUFAgSQoICLCrDwgIsK0rKCiQq6urGjdufMWay0lJSZGPj49tCQoKqmnbAAAAAADcEDUO2nfccYf27t2rTz/9VH/84x8VFxenr7/+2rbeZDLZ1Vut1kvGLvZbNbNnz1ZxcbFtOXbsWE3bBgAAAADghqhx0HZ1dVXr1q3VrVs3paSkqFOnTvrLX/4ii8UiSZfMTBcWFtpmuS0Wi8rLy1VUVHTFmssxm822J51fWAAAAAAAqIuu+z3aVqtVZWVlCg4OlsViUUZGhm1deXm5srKyFBkZKUkKDw+Xi4uLXU1+fr72799vqwEAAAAA4GbmXJPip59+Wv3791dQUJBOnjypDRs2KDMzUx9++KFMJpMSExOVnJys0NBQhYaGKjk5WR4eHho5cqQkycfHR/Hx8Zo+fbr8/Pzk6+urGTNmKCwsTDExMbVyggAAAAAA3Eg1Cto//fSTxowZo/z8fPn4+OjOO+/Uhx9+qL59+0qSZs6cqdLSUiUkJKioqEgRERHaunWrvLy8bPtYunSpnJ2dNWzYMJWWlio6Olrp6elycnIy9swAAAAAAHCA636PtiPwHm0AAAAAqD94jzYAAAAAALgigjYAAAAAAAYiaAMAAAAAYCCCNgAAAAAABiJoAwAAAABgIII2AAAAAAAGImgDAAAAAGAggjYAAAAAAAYiaAMAAAAAYCCCNgAAAAAABiJoAwAAAABgIII2AAAAAAAGImgDAAAAAGAggjYAAAAAAAYiaAMAAAAAYCCCNgAAAAAABiJoAwAAAABgIII2AAAAAAAGImgDAAAAAGAggjYAAAAAAAYiaAMAAAAAYCCCNgAAAAAABiJoAwAAAABgIII2AAAAAAAGImgDAAAAAGAggjYAAAAAAAYiaAMAAAAAYCCCNgAAAAAABiJoAwAAAABgIII2AAAAAAAGImgDAAAAAGAggjYAAAAAAAYiaAMAAAAAYCCCNgAAAAAABiJoAwAAAABgIII2AAAAAAAGImgDAAAAAGAggjYAAAAAAAYiaAMAAAAAYCCCNgAAAAAABqpR0E5JSVH37t3l5eUlf39/DRkyRLm5uXY1VqtVSUlJCgwMlLu7u6KionTgwAG7mrKyMk2ZMkVNmjSRp6enBg0apB9//PH6zwYAAAAAAAerUdDOysrSpEmT9OmnnyojI0Pnzp1Tv379dPr0aVvN4sWLlZqaquXLl2v37t2yWCzq27evTp48aatJTEzUpk2btGHDBu3atUunTp3SgAEDVFlZadyZAQAAAADgACar1Wq91o1PnDghf39/
ZWVl6d5775XValVgYKASExM1a9YsSednrwMCArRo0SJNnDhRxcXFatq0qdauXavhw4dLko4fP66goCC9//77io2N/c3jlpSUyMfHR8XFxfL29r7W9m+YpAfHOroFAAAAAKizkt5Md3QLv6kmOfS67tEuLi6WJPn6+kqS8vLyVFBQoH79+tlqzGazevfurezsbElSTk6OKioq7GoCAwPVsWNHW83FysrKVFJSYrcAAAAAAFAXXXPQtlqtmjZtmnr16qWOHTtKkgoKCiRJAQEBdrUBAQG2dQUFBXJ1dVXjxo2vWHOxlJQU+fj42JagoKBrbRsAAAAAgFp1zUF78uTJ+uqrr/T6669fss5kMtl9tlqtl4xd7Go1s2fPVnFxsW05duzYtbYNAAAAAECtuqagPWXKFL377rv65JNP1Lx5c9u4xWKRpEtmpgsLC22z3BaLReXl5SoqKrpizcXMZrO8vb3tFgAAAAAA6qIaBW2r1arJkyfr7bff1scff6zg4GC79cHBwbJYLMrIyLCNlZeXKysrS5GRkZKk8PBwubi42NXk5+dr//79thoAAAAAAG5WzjUpnjRpktavX6///u//lpeXl23m2sfHR+7u7jKZTEpMTFRycrJCQ0MVGhqq5ORkeXh4aOTIkbba+Ph4TZ8+XX5+fvL19dWMGTMUFhammJgY488QAAAAAIAbqEZBOy0tTZIUFRVlN7569WqNHTtWkjRz5kyVlpYqISFBRUVFioiI0NatW+Xl5WWrX7p0qZydnTVs2DCVlpYqOjpa6enpcnJyur6zAQAAAADAwa7rPdqOwnu0AQAAAKD+4D3aAAAAAADgigjaAAAAAAAYiKANAAAAAICBCNoAAAAAABiIoA0AAAAAgIEI2gAAAAAAGIigDQAAAACAgQjaAAAAAAAYiKANAAAAAICBCNoAAAAAABiIoA0AAAAAgIEI2gAAAAAAGIigDQAAAACAgQjaAAAAAAAYiKANAAAAAICBCNoAAAAAABiIoA0AAAAAgIEI2gAAAAAAGIigDQAAAACAgQjaAAAAAAAYiKANAAAAAICBCNoAAAAAABiIoA0AAAAAgIEI2gAAAAAAGIigDQAAAACAgQjaAAAAAAAYiKANAAAAAICBCNoAAAAAABiIoA0AAAAAgIEI2gAAAAAAGIigDQAAAACAgQjaAAAAAAAYiKANAAAAAICBCNoAAAAAABiIoA0AAAAAgIEI2gAAAAAAGIigDQAAAACAgQjaAAAAAAAYiKANAAAAAICBCNoAAAAAABioxkF7x44dGjhwoAIDA2UymfTOO+/YrbdarUpKSlJgYKDc3d0VFRWlAwcO2NWUlZVpypQpatKkiTw9PTVo0CD9+OOP13UiAAAAAADUBTUO2qdPn1anTp20fPnyy65fvHixUlNTtXz5cu3evVsWi0V9+/bVyZMnbTWJiYnatGmTNmzYoF27dunUqVMaMGCAKisrr/1MAAAAAACoA5xrukH//v3Vv3//y66zWq1atmyZ5syZo6FDh0qS1qxZo4CAAK1fv14TJ05UcXGxVq1apbVr1yomJkaStG7dOgUFBWnbtm2KjY29jtMBAAAAAMCxDL1HOy8vTwUFBerXr59tzGw2q3fv3srOzpYk5eTkqKKiwq4mMDBQHTt2tNVcrKysTCUlJXYLAAAAAAB1kaFBu6CgQJIUEBBgNx4QEGBbV1BQIFdXVzVu3PiKNRdLSUmRj4+PbQkKCjKybQAAAAAADFMrTx03mUx2n61W6yVjF7tazezZs1VcXGxbjh07ZlivAAAAAAAYydCgbbFYJOmSmenCwkLbLLfFYlF5ebmKioquWHMxs9ksb29vuwUAAAAAgLrI0KAdHBwsi8WijIwM21h5ebmysrIUGRkpSQoPD5eLi4tdTX5+vvbv32+rAQAAAADgZlXjp46fOnVK3333ne1zXl6e9u7dK19fX7Vo0UKJiYlKTk5WaGioQkNDlZycLA8PD40cOVKS5OPjo/j4
eE2fPl1+fn7y9fXVjBkzFBYWZnsKOQAAAAAAN6saB+0vvvhCffr0sX2eNm2aJCkuLk7p6emaOXOmSktLlZCQoKKiIkVERGjr1q3y8vKybbN06VI5Oztr2LBhKi0tVXR0tNLT0+Xk5GTAKQEAAAAA4Dgmq9VqdXQTNVVSUiIfHx8VFxffFPdrJz041tEtAAAAAECdlfRmuqNb+E01yaG18tRxAAAAAABuVQRtAAAAAAAMRNAGAAAAAMBABG0AAAAAAAxE0AYAAAAAwEAEbQAAAAAADETQBgAAAADAQARtAAAAAAAMRNAGAAAAAMBABG0AAAAAAAxE0AYAAAAAwEAEbQAAAAAADETQBgAAAADAQARtAAAAAAAMRNAGAAAAAMBABG0AAAAAAAxE0AYAAAAAwEAEbQAAAAAADETQBgAAAADAQARtAAAAAAAMRNAGAAAAAMBABG0AAAAAAAxE0AYAAAAAwEAEbQAAAAAADETQBgAAAADAQARtAAAAAAAMRNAGAAAAAMBABG0AAAAAAAxE0AYAAAAAwEAEbQAAAAAADETQBgAAAADAQARtAAAAAAAMRNAGAAAAAMBABG0AAAAAAAxE0AYAAAAAwEAEbQAAAAAADETQBgAAAADAQARtAAAAAAAMRNAGAAAAAMBABG0AAAAAAAzk0KD9yiuvKDg4WG5ubgoPD9fOnTsd2Q4AAAAAANfNYUF748aNSkxM1Jw5c7Rnzx7dc8896t+/v44ePeqolgAAAAAAuG4OC9qpqamKj4/X+PHj1a5dOy1btkxBQUFKS0tzVEsAAAAAAFw3hwTt8vJy5eTkqF+/fnbj/fr1U3Z2tiNaAgAAAADAEM6OOOjPP/+syspKBQQE2I0HBASooKDgkvqysjKVlZXZPhcXF0uSSkpKardRg5RVlDu6BQAAAACos26GbHehR6vV+pu1DgnaF5hMJrvPVqv1kjFJSklJ0fz58y8ZDwoKqrXeAAAAAAA3xgs+rzu6hWo7efKkfHx8rlrjkKDdpEkTOTk5XTJ7XVhYeMkstyTNnj1b06ZNs32uqqrS//7v/8rPz++ywbwuKSkpUVBQkI4dOyZvb29HtwPUGn7ruJXwe8ethN87biX83nE1VqtVJ0+eVGBg4G/WOiRou7q6Kjw8XBkZGfrDH/5gG8/IyNDgwYMvqTebzTKbzXZjjRo1qu02DeXt7c3/WHFL4LeOWwm/d9xK+L3jVsLvHVfyWzPZFzjs0vFp06ZpzJgx6tatm3r06KGVK1fq6NGjevzxxx3VEgAAAAAA181hQXv48OH65ZdftGDBAuXn56tjx456//331bJlS0e1BAAAAADAdXPow9ASEhKUkJDgyBZqndls1rPPPnvJpe9AfcNvHbcSfu+4lfB7x62E3zuMYrJW59nkAAAAAACgWho4ugEAAAAAAOoTgjYAAAAAAAYiaAMAAAAAYCCCNgAAAAAABiJo16JXXnlFwcHBcnNzU3h4uHbu3OnolgDDpaSkqHv37vLy8pK/v7+GDBmi3NxcR7cF3BApKSkymUxKTEx0dCtArfjPf/6j0aNHy8/PTx4eHurcubNycnIc3RZguHPnzmnu3LkKDg6Wu7u7QkJCtGDBAlVVVTm6NdykCNq1ZOPGjUpMTNScOXO0Z88e3XPPPerfv7+OHj3q6NYAQ2VlZWnSpEn69NNPlZGRoXPnzqlfv346ffq0o1sDatXu3bu1cuVK3XnnnY5uBagVRUVF6tmzp1xcXPTBBx/o66+/1pIlS9SoUSNHtwYYbtGiRfrb3/6m5cuX6+DBg1q8eLFefPFFvfzyy45uDTcpXu9VSyIiItS1a1elpaXZxtq1a6chQ4YoJSXFgZ0BtevEiRPy9/dXVlaW7r33Xke3A9SKU6dOqWvXrnrllVf0/PPPq3Pnzlq2bJmj2wIM9dRTT+lf//oXV+ThljBgwAAFBARo1apVtrEHHnhAHh4eWrt2rQM7w82KGe1aUF5erpycHPXr189uvF+/
fsrOznZQV8CNUVxcLEny9fV1cCdA7Zk0aZLuv/9+xcTEOLoVoNa8++676tatmx566CH5+/urS5cuevXVVx3dFlArevXqpe3bt+vbb7+VJH355ZfatWuX7rvvPgd3hpuVs6MbqI9+/vlnVVZWKiAgwG48ICBABQUFDuoKqH1Wq1XTpk1Tr1691LFjR0e3A9SKDRs26N///rd2797t6FaAWvXDDz8oLS1N06ZN09NPP63PP/9cTzzxhMxmsx555BFHtwcYatasWSouLlbbtm3l5OSkyspKLVy4UCNGjHB0a7hJEbRrkclksvtstVovGQPqk8mTJ+urr77Srl27HN0KUCuOHTumJ598Ulu3bpWbm5uj2wFqVVVVlbp166bk5GRJUpcuXXTgwAGlpaURtFHvbNy4UevWrdP69evVoUMH7d27V4mJiQoMDFRcXJyj28NNiKBdC5o0aSInJ6dLZq8LCwsvmeUG6ospU6bo3Xff1Y4dO9S8eXNHtwPUipycHBUWFio8PNw2VllZqR07dmj58uUqKyuTk5OTAzsEjNOsWTO1b9/ebqxdu3Z66623HNQRUHv+9Kc/6amnntLDDz8sSQoLC9ORI0eUkpJC0MY14R7tWuDq6qrw8HBlZGTYjWdkZCgyMtJBXQG1w2q1avLkyXr77bf18ccfKzg42NEtAbUmOjpa+/bt0969e21Lt27dNGrUKO3du5eQjXqlZ8+el7yu8dtvv1XLli0d1BFQe86cOaMGDeyjkZOTE6/3wjVjRruWTJs2TWPGjFG3bt3Uo0cPrVy5UkePHtXjjz/u6NYAQ02aNEnr16/Xf//3f8vLy8t2JYePj4/c3d0d3B1gLC8vr0ueP+Dp6Sk/Pz+eS4B6Z+rUqYqMjFRycrKGDRumzz//XCtXrtTKlSsd3RpguIEDB2rhwoVq0aKFOnTooD179ig1NVXjxo1zdGu4SfF6r1r0yiuvaPHixcrPz1fHjh21dOlSXneEeudKzx1YvXq1xo4de2ObARwgKiqK13uh3tqyZYtmz56tQ4cOKTg4WNOmTdOECRMc3RZguJMnT+qZZ57Rpk2bVFhYqMDAQI0YMULz5s2Tq6uro9vDTYigDQAAAACAgbhHGwAAAAAAAxG0AQAAAAAwEEEbAAAAAAADEbQBAAAAADAQQRsAAAAAAAMRtAEAAAAAMBBBGwAAAAAAAxG0AQAwQFJSkjp37mz7PHbsWA0ZMuSG93H48GGZTCbt3bv3mrbPzMyUyWTS//3f/xnaV3VcS+9GfM+OPGcAQP1E0AYA1Ftjx46VyWSSyWSSi4uLQkJCNGPGDJ0+fbrWj/2Xv/xF6enp1aq93nAMAADqFmdHNwAAQG36/e9/r9WrV6uiokI7d+7U+PHjdfr0aaWlpV1SW1FRIRcXF0OO6+PjY8h+AADAzYcZbQBAvWY2m2WxWBQUFKSRI0dq1KhReueddyT9/8u9//73vyskJERms1lWq1XFxcV67LHH5O/vL29vb/3ud7/Tl19+abffF154QQEBAfLy8lJ8fLzOnj1rt/7iS5qrqqq0aNEitW7dWmazWS1atNDChQslScHBwZKkLl26yGQyKSoqyrbd6tWr1a5dO7m5ualt27Z65ZVX7I7z+eefq0uXLnJzc1O3bt20Z8+e3/xOysrKNHPmTAUFBclsNis0NFSrVq26bO0vv/yiESNGqHnz5vLw8FBYWJhef/11u5o333xTYWFhcnd3l5+fn2JiYmxXDWRmZuquu+6Sp6enGjVqpJ49e+rIkSO/2aMkVVZWKj4+XsHBwXJ3d9cdd9yhv/zlL5etnT9/vu2/18SJE1VeXm5bZ7VatXjxYoWEhMjd3V2dOnXSm2++Wa0eAAC4FsxoAwBuKe7u7qqoqLB9/u677/TGG2/orbfekpOTkyTp/vvvl6+vr95//335+PhoxYoVio6O1rfffitfX1+98cYbevbZZ/XXv/5V99xzj9auXauXXnpJISEhVzzu7Nmz9eqr
r2rp0qXq1auX8vPz9c0330g6H5bvuusubdu2TR06dJCrq6sk6dVXX9Wzzz6r5cuXq0uXLtqzZ48mTJggT09PxcXF6fTp0xowYIB+97vfad26dcrLy9OTTz75m9/BI488ov/5n//RSy+9pE6dOikvL08///zzZWvPnj2r8PBwzZo1S97e3nrvvfc0ZswYhYSEKCIiQvn5+RoxYoQWL16sP/zhDzp58qR27twpq9Wqc+fOaciQIZowYYJef/11lZeX6/PPP5fJZKrWf6uqqio1b95cb7zxhpo0aaLs7Gw99thjatasmYYNG2ar2759u9zc3PTJJ5/o8OHDevTRR9WkSRPbHzLmzp2rt99+W2lpaQoNDdWOHTs0evRoNW3aVL17965WLwAA1IgVAIB6Ki4uzjp48GDb588++8zq5+dnHTZsmNVqtVqfffZZq4uLi7WwsNBWs337dqu3t7f17Nmzdvu6/fbbrStWrLBarVZrjx49rI8//rjd+oiICGunTp0ue+ySkhKr2Wy2vvrqq5ftMy8vzyrJumfPHrvxoKAg6/r16+3GnnvuOWuPHj2sVqvVumLFCquvr6/19OnTtvVpaWmX3dcFubm5VknWjIyMy67/5JNPrJKsRUVFl11vtVqt9913n3X69OlWq9VqzcnJsUqyHj58+JK6X375xSrJmpmZecV9/dqVvodfS0hIsD7wwAO2z3FxcZf9Dho2bGitrKy0njp1yurm5mbNzs622098fLx1xIgRVqu1eucMAEBNMKMNAKjXtmzZooYNG+rcuXOqqKjQ4MGD9fLLL9vWt2zZUk2bNrV9zsnJ0alTp+Tn52e3n9LSUn3//feSpIMHD+rxxx+3W9+jRw998sknl+3h4MGDKisrU3R0dLX7PnHihI4dO6b4+HhNmDDBNn7u3Dnb/d8HDx5Up06d5OHhYdfH1ezdu1dOTk7VnsmtrKzUCy+8oI0bN+o///mPysrKVFZWJk9PT0lSp06dFB0drbCwMMXGxqpfv3568MEH1bhxY/n6+mrs2LGKjY1V3759FRMTo2HDhqlZs2bV/h7+9re/6b/+67905MgRlZaWqry83O7p7hd6uPg7OHXqlI4dO6bCwkKdPXtWffv2tdumvLxcXbp0qXYfAADUBEEbAFCv9enTR2lpaXJxcVFgYOAlDzu7EBgvqKqqUrNmzZSZmXnJvho1anRNPbi7u9d4m6qqKknnLx+PiIiwW3fhEner1VrrvSxZskRLly7VsmXLFBYWJk9PTyUmJtrugXZyclJGRoays7O1detWvfzyy5ozZ44+++wzBQcHa/Xq1XriiSf04YcfauPGjZo7d64yMjJ09913/+ax33jjDU2dOlVLlixRjx495OXlpRdffFGfffZZtXo3mUy27/G9997TbbfdZrfebDbX6LsAAKC6eBgaAKBe8/T0VOvWrdWyZctqPVG8a9euKigokLOzs1q3bm23NGnSRJLUrl07ffrpp3bbXfz510JDQ+Xu7q7t27dfdv2Fe7IrKyttYwEBAbrtttv0ww8/XNLHhYentW/fXl9++aVKS0ur1YckhYWFqaqqSllZWVetu2Dnzp0aPHiwRo8erU6dOikkJESHDh2yqzGZTOrZs6fmz5+vPXv2yNXVVZs2bbKt79Kli2bPnq3s7Gx17NhR69evr/axIyMjlZCQoC5duqh169a2qwp+7XLfQcOGDdW8eXO1b99eZrNZR48eveR7DAoKqlYfAADUFDPaAAD8SkxMjHr06KEhQ4Zo0aJFuuOOO3T8+HG9//77GjJkiLp166Ynn3xScXFx6tatm3r16qV//OMfOnDgwBUfhubm5qZZs2Zp5syZcnV1Vc+ePXXixAkdOHBA8fHx8vf3l7u7uz788EM1b95cbm5u8vHxUVJSkp544gl5e3urf//+Kisr0xdffKGioiJNmzZNI0eO1Jw5cxQfH6+5c+fq8OHD+vOf/3zV82vVqpXi4uI0btw428PQjhw5osLC
QrsHjF3QunVrvfXWW8rOzlbjxo2VmpqqgoICtWvXTpL02Wefafv27erXr5/8/f312Wef6cSJE2rXrp3y8vK0cuVKDRo0SIGBgcrNzdW3336rRx55pFr/LVq3bq3XXntNH330kYKDg7V27Vrt3r3b9oeGC8rLy23fwZEjR/Tss89q8uTJatCggby8vDRjxgxNnTpVVVVV6tWrl0pKSpSdna2GDRsqLi6uWr0AAFATBG0AAH7FZDLp/fff15w5czRu3DidOHFCFotF9957rwICAiRJw4cP1/fff69Zs2bp7NmzeuCBB/THP/5RH3300RX3+8wzz8jZ2Vnz5s3T8ePH1axZM9t93s7OznrppZe0YMECzZs3T/fcc48yMzM1fvx4eXh46MUXX9TMmTPl6empsLAwJSYmSpIaNmyozZs36/HHH1eXLl3Uvn17LVq0SA888MBVzzEtLU1PP/20EhIS9Msvv6hFixZ6+umnr9h3Xl6eYmNj5eHhoccee0xDhgxRcXGxJMnb21s7duzQsmXLVFJSopYtW2rJkiXq37+/fvrpJ33zzTdas2aNfvnlFzVr1kyTJ0/WxIkTq/Xf4vHHH9fevXs1fPhwmUwmjRgxQgkJCfrggw/s6qKjoxUaGqp7771XZWVlevjhh5WUlGRb/9xzz8nf318pKSn64Ycf1KhRI3Xt2vWK5wwAwPUyWa/lBi8AAAAAAHBZ3KMNAAAAAICBCNoAAAAAABiIoA0AAAAAgIEI2gAAAAAAGIigDQAAAACAgQjaAAAAAAAYiKANAAAAAICBCNoAAAAAABiIoA0AAAAAgIEI2gAAAAAAGIigDQAAAACAgQjaAAAAAAAY6P8BQvZdCB4PPawAAAAASUVORK5CYII=",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Histograms_top_choice\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotPredictionHistogram(y_choice_tra,\n",
+ " y_prediction_b=y_choice_val,\n",
+ " y_prediction_c=y_choice_tes,\n",
+ " label_a=\"Training Set\",\n",
+ " label_b=\"Validation Set\",\n",
+ " label_c=\"Testing Set\",\n",
+ " figsize=(12, 5),\n",
+ " alpha=0.5,\n",
+ " xlabel_plot=\"Predicted class label\",\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 4: Histograms of the number of images for which the top-choice (predicted) class was each digit 0 through 9. Each histogram corresponds to a different data set: training, validation, or test. Note that these are overlapping histograms, not stacked. Compare with Figure 3, which shows the distributions of true class labels for each data set, and consider which classes are represented differently between the true labels and the predicted labels."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:37.114356Z",
+ "iopub.status.busy": "2025-02-09T23:13:37.114005Z",
+ "iopub.status.idle": "2025-02-09T23:13:38.178779Z",
+ "shell.execute_reply": "2025-02-09T23:13:38.178135Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:37.114322Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"Histograms_class_probabilities\"\\\n",
+ " + \"_\" + path_dict['run_label']\n",
+ "\n",
+ "plotPredictionHistogram(y_prob_tra,\n",
+ " y_prediction_b=y_prob_val,\n",
+ " y_prediction_c=y_prob_tes,\n",
+ " title_a='Training Set',\n",
+ " title_b='Validation Set',\n",
+ " title_c='Testing Set',\n",
+ " figsize=(15, 4),\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 5: Histograms of the number of images (y-axis) that had a probability (x-axis) of being each class 0 through 9 (light to dark shades)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In both Figures 3 and 4, the histograms have very similar shapes across the classification categories.\n",
+ "This is a good sign because it indicates the model is not heavily biased toward a particular class."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 4.3. Generalization Error\n",
+ "\n",
+ "The primary task in optimizing a network is to minimize the generalization error: the error the model makes on data it did not see during training."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 4.3.1. Loss History: History of Loss and Accuracy during Training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot the loss history for the validation and training sets. We reserve the test set for a 'blind' analysis."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:38.179887Z",
+ "iopub.status.busy": "2025-02-09T23:13:38.179456Z",
+ "iopub.status.idle": "2025-02-09T23:13:38.694470Z",
+ "shell.execute_reply": "2025-02-09T23:13:38.693818Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:38.179865Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"LossHistory\"\\\n",
+ " + \"_\"\\\n",
+ " + path_dict['run_label']\n",
+ "\n",
+ "plotLossHistory(history,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 6: Top panel: the loss as a function of epoch for the training and validation sets, which decreases with epoch, as it should while the model improves. Bottom panel: the loss residual (validation - training), which shows a dip at epoch 2. The dip indicates that the training and validation losses temporarily diverged, and that this was rectified in later epochs."
+ ]
+ },
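+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal sketch of how the loss residual in the bottom panel of Figure 6 can be computed, assuming the per-epoch losses are available as two equal-length sequences. The helper `demo_loss_residual` and the `demo_` values are illustrative placeholders, not outputs of this notebook, and the layout of `history` used by `plotLossHistory` may differ."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "\n",
+ "# Hypothetical helper (not part of this notebook's utilities): the residual\n",
+ "# is the validation loss minus the training loss at each epoch.\n",
+ "def demo_loss_residual(loss_tra, loss_val):\n",
+ "    loss_tra = np.asarray(loss_tra, dtype=float)\n",
+ "    loss_val = np.asarray(loss_val, dtype=float)\n",
+ "    if loss_tra.shape != loss_val.shape:\n",
+ "        raise ValueError('loss histories must have the same length')\n",
+ "    return loss_val - loss_tra\n",
+ "\n",
+ "\n",
+ "# Illustrative placeholder values only; these are not results from this run.\n",
+ "demo_loss_tra = [2.1, 1.4, 0.9, 0.6]\n",
+ "demo_loss_val = [2.0, 1.5, 1.0, 0.8]\n",
+ "demo_residual = demo_loss_residual(demo_loss_tra, demo_loss_val)\n",
+ "\n",
+ "# A residual that keeps growing with epoch is one common sign of overfitting.\n",
+ "print(demo_residual)"
+ ]
+ },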
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2024-10-26T18:50:41.751869Z",
+ "iopub.status.busy": "2024-10-26T18:50:41.751239Z",
+ "iopub.status.idle": "2024-10-26T18:50:41.757103Z",
+ "shell.execute_reply": "2024-10-26T18:50:41.756503Z",
+ "shell.execute_reply.started": "2024-10-26T18:50:41.751843Z"
+ }
+ },
+ "source": [
+ "#### 4.3.2. Confusion Matrix: Bias in Trained Model?"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compute and plot the confusion matrices for the training, validation, and test samples (left, middle, right)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:38.695466Z",
+ "iopub.status.busy": "2025-02-09T23:13:38.695205Z",
+ "iopub.status.idle": "2025-02-09T23:13:40.428933Z",
+ "shell.execute_reply": "2025-02-09T23:13:40.427883Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:38.695446Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')\n",
+ "figsize = (15, 3)\n",
+ "linewidths = 0.01\n",
+ "linecolor = 'white'\n",
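+ "# Rows of the confusion matrix are true labels and columns are predicted\n",
+ "# labels, so the heatmap's y-axis is the true label.\n",
+ "ylabel = \"True Label\"\n",
+ "xlabel = \"Predicted Label\"\n",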
+ "\n",
+ "cm_tra = confusion_matrix(y_true_tra, y_choice_tra)\n",
+ "cm_val = confusion_matrix(y_true_val, y_choice_val)\n",
+ "cm_tes = confusion_matrix(y_true_tes, y_choice_tes)\n",
+ "\n",
+ "df_cm_tra = pd.DataFrame(cm_tra / np.sum(cm_tra, axis=1)[:, None],\n",
+ " index=[i for i in classes],\n",
+ " columns=[i for i in classes])\n",
+ "df_cm_val = pd.DataFrame(cm_val / np.sum(cm_val, axis=1)[:, None],\n",
+ " index=[i for i in classes],\n",
+ " columns=[i for i in classes])\n",
+ "df_cm_tes = pd.DataFrame(cm_tes / np.sum(cm_tes, axis=1)[:, None],\n",
+ " index=[i for i in classes],\n",
+ " columns=[i for i in classes])\n",
+ "\n",
+ "fig, (axa, axb, axc) = plt.subplots(1, 3, figsize=figsize)\n",
+ "fig.subplots_adjust(wspace=0.5)\n",
+ "\n",
+ "ax1 = sns.heatmap(df_cm_tra, annot=False, linewidths=linewidths,\n",
+ " linecolor=linecolor, square=True, ax=axa)\n",
+ "_ = ax1.set(xlabel=xlabel, ylabel=ylabel, title=\"Training Data\")\n",
+ "\n",
+ "ax2 = sns.heatmap(df_cm_val, annot=False, linewidths=linewidths,\n",
+ " linecolor=linecolor, square=True, ax=axb)\n",
+ "_ = ax2.set(xlabel=xlabel, ylabel=ylabel, title=\"Validation Data\")\n",
+ "\n",
+ "ax3 = sns.heatmap(df_cm_tes, annot=False, linewidths=linewidths,\n",
+ " linecolor=linecolor, square=True, ax=axc)\n",
+ "_ = ax3.set(xlabel=xlabel, ylabel=ylabel, title=\"Test Data\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 7: Confusion matrices for the training, validation, and test data sets (left to right). Each cell shows the fraction (see the color bars) of the objects with a given true label that were predicted to have a given label. The diagonal represents images that were correctly classified. For a given class, each off-diagonal cell is either a false positive (another digit predicted as that class) or a false negative (that class predicted as another digit). Consider some examples. First, the class \"0\" is almost always predicted to be \"0\", as shown by the lighter color in the top-most, left-most cell; the other cells for true label \"0\" are completely dark. Second, consider the cell that represents a prediction of \"4\" when the true label is \"9\": that cell is not completely dark. A \"4\" has a shape similar to a \"9\"."
+ ]
+ },
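+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal sketch of how a row-normalized confusion matrix can be read, using a small three-class example with made-up labels. The helper `demo_confusion_counts` is a NumPy-only stand-in that follows the same convention as scikit-learn's `confusion_matrix` (rows are true labels, columns are predicted labels); all `demo_` names and values are assumptions for illustration."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "\n",
+ "# Hypothetical stand-in for sklearn.metrics.confusion_matrix.\n",
+ "def demo_confusion_counts(y_true, y_pred, n_classes):\n",
+ "    counts = np.zeros((n_classes, n_classes), dtype=int)\n",
+ "    # np.add.at accumulates correctly when a (true, pred) pair repeats.\n",
+ "    np.add.at(counts, (y_true, y_pred), 1)\n",
+ "    return counts\n",
+ "\n",
+ "\n",
+ "# Made-up labels for illustration only.\n",
+ "demo_y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])\n",
+ "demo_y_pred = np.array([0, 0, 1, 1, 1, 2, 2, 0, 2])\n",
+ "\n",
+ "demo_cm = demo_confusion_counts(demo_y_true, demo_y_pred, n_classes=3)\n",
+ "\n",
+ "# Divide each row by the number of objects with that true label.\n",
+ "demo_cm_norm = demo_cm / demo_cm.sum(axis=1)[:, None]\n",
+ "\n",
+ "# For each class, false negatives are the off-diagonal entries of its row\n",
+ "# and false positives are the off-diagonal entries of its column.\n",
+ "demo_fn = demo_cm.sum(axis=1) - np.diag(demo_cm)\n",
+ "demo_fp = demo_cm.sum(axis=0) - np.diag(demo_cm)\n",
+ "\n",
+ "print(demo_cm_norm)\n",
+ "print('FN per class:', demo_fn)\n",
+ "print('FP per class:', demo_fp)"
+ ]
+ },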
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vvAqrZwjVYBt"
+ },
+ "source": [
+ "#### 4.3.4. Investigating predictions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### 4.3.4.1. Definition of classification metrics"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-01-26T21:19:35.897282Z",
+ "iopub.status.busy": "2025-01-26T21:19:35.896488Z",
+ "iopub.status.idle": "2025-01-26T21:19:36.705672Z",
+ "shell.execute_reply": "2025-01-26T21:19:36.705022Z",
+ "shell.execute_reply.started": "2025-01-26T21:19:35.897242Z"
+ }
+ },
+ "source": [
+ "Consider the example in which the class of interest is \"2\". We then define the following outcomes for each image.\n",
+ "\n",
+ "* `True Positive (TP)`: a \"2\" classified as \"2\"\n",
+ "* `False Positive (FP)`: another digit classified as \"2\"\n",
+ "* `True Negative (TN)`: another digit classified as something other than \"2\"\n",
+ "* `False Negative (FN)`: a \"2\" classified as something other than \"2\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### 4.3.4.2. Explore the classification of the training data for an example class value"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Investigate the case in which the true digit label is \"2\"."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:40.430615Z",
+ "iopub.status.busy": "2025-02-09T23:13:40.430250Z",
+ "iopub.status.idle": "2025-02-09T23:13:40.575406Z",
+ "shell.execute_reply": "2025-02-09T23:13:40.574407Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:40.430564Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "class_value = 2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Find all objects that have that class value. Obtain the indices of the true positives, false positives, true negatives, and false negatives, and create subsets of the data according to those indices."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:40.576471Z",
+ "iopub.status.busy": "2025-02-09T23:13:40.576232Z",
+ "iopub.status.idle": "2025-02-09T23:13:40.727095Z",
+ "shell.execute_reply": "2025-02-09T23:13:40.725923Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:40.576451Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "TP count: 439\n",
+ "FP count: 58\n",
+ "TN count: 4435\n",
+ "FN count: 68\n"
+ ]
+ }
+ ],
+ "source": [
+ "ind_class_tp_tra = np.where((y_true_tra == class_value)\n",
+ " & (y_choice_tra == class_value))[0]\n",
+ "\n",
+ "ind_class_fp_tra = np.where((y_true_tra != class_value)\n",
+ " & (y_choice_tra == class_value))[0]\n",
+ "\n",
+ "ind_class_tn_tra = np.where((y_true_tra != class_value)\n",
+ " & (y_choice_tra != class_value))[0]\n",
+ "\n",
+ "ind_class_fn_tra = np.where((y_true_tra == class_value)\n",
+ " & (y_choice_tra != class_value))[0]\n",
+ "\n",
+ "x_tra_tp = x_tra[ind_class_tp_tra]\n",
+ "y_true_tra_tp = y_true_tra[ind_class_tp_tra]\n",
+ "y_choice_tra_tp = y_choice_tra[ind_class_tp_tra]\n",
+ "\n",
+ "x_tra_fp = x_tra[ind_class_fp_tra]\n",
+ "y_true_tra_fp = y_true_tra[ind_class_fp_tra]\n",
+ "y_choice_tra_fp = y_choice_tra[ind_class_fp_tra]\n",
+ "\n",
+ "x_tra_tn = x_tra[ind_class_tn_tra]\n",
+ "y_true_tra_tn = y_true_tra[ind_class_tn_tra]\n",
+ "y_choice_tra_tn = y_choice_tra[ind_class_tn_tra]\n",
+ "\n",
+ "x_tra_fn = x_tra[ind_class_fn_tra]\n",
+ "y_true_tra_fn = y_true_tra[ind_class_fn_tra]\n",
+ "y_choice_tra_fn = y_choice_tra[ind_class_fn_tra]\n",
+ "\n",
+ "n_tp = len(ind_class_tp_tra)\n",
+ "n_fp = len(ind_class_fp_tra)\n",
+ "n_tn = len(ind_class_tn_tra)\n",
+ "n_fn = len(ind_class_fn_tra)\n",
+ "\n",
+ "print(f\"TP count: {n_tp:4d}\")\n",
+ "print(f\"FP count: {n_fp:4d}\")\n",
+ "print(f\"TN count: {n_tn:4d}\")\n",
+ "print(f\"FN count: {n_fn:4d}\")"
+ ]
+ },
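+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The four counts can be summarized with standard metrics. This is a minimal sketch: `demo_precision_recall_f1` is a hypothetical helper, not part of this notebook's utilities, and the literal counts are copied from the output of the cell above that prints them."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hypothetical helper for illustration; true negatives are not needed here.\n",
+ "def demo_precision_recall_f1(tp, fp, fn):\n",
+ "    # Precision: fraction of objects predicted as the class that truly are.\n",
+ "    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0\n",
+ "    # Recall: fraction of objects of the class that were recovered.\n",
+ "    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0\n",
+ "    # F1: harmonic mean of precision and recall, written in terms of counts.\n",
+ "    denom = 2 * tp + fp + fn\n",
+ "    f1 = 2 * tp / denom if denom > 0 else 0.0\n",
+ "    return precision, recall, f1\n",
+ "\n",
+ "\n",
+ "# Counts copied from the printed output for class value 2 (training set).\n",
+ "demo_precision, demo_recall, demo_f1 = demo_precision_recall_f1(tp=439, fp=58, fn=68)\n",
+ "\n",
+ "print(f'precision: {demo_precision:.3f}')\n",
+ "print(f'recall:    {demo_recall:.3f}')\n",
+ "print(f'F1:        {demo_f1:.3f}')"
+ ]
+ },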
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:40.728669Z",
+ "iopub.status.busy": "2025-02-09T23:13:40.728301Z",
+ "iopub.status.idle": "2025-02-09T23:13:45.386481Z",
+ "shell.execute_reply": "2025-02-09T23:13:45.385906Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:40.728632Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "if n_tp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_TruePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_tp,\n",
+ " y_true_tra_tp,\n",
+ " y_choice_tra_tp,\n",
+ " title_main=\"True Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_FalsePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_fp,\n",
+ " y_true_tra_fp,\n",
+ " y_choice_tra_fp,\n",
+ " title_main=\"False Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_tn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_TrueNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_tn,\n",
+ " y_true_tra_tn,\n",
+ " y_choice_tra_tn,\n",
+ " title_main=\"True Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix'] + \"_\"\\\n",
+ " + \"ExampleImages_FalseNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayImageConfusion(x_tra_fn,\n",
+ " y_true_tra_fn,\n",
+ " y_choice_tra_fn,\n",
+ " title_main=\"False Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 8: Four panels of 10 images each, representing true positives (top), false positives (second), true negatives (third), and false negatives (bottom), for classification category 2.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot histograms of the pixel values of the true positive, false positive, true negative, and false negative images."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:45.387543Z",
+ "iopub.status.busy": "2025-02-09T23:13:45.387337Z",
+ "iopub.status.idle": "2025-02-09T23:13:50.807474Z",
+ "shell.execute_reply": "2025-02-09T23:13:50.806324Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:45.387525Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "if n_tp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleImages_TruePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_tp,\n",
+ " y_true_tra_tp,\n",
+ " y_choice_tra_tp,\n",
+ " title_main=\"True Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fp > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleImages_FalsePositives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_fp,\n",
+ " y_true_tra_fp,\n",
+ " y_choice_tra_fp,\n",
+ " title_main=\"False Positives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_tn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleImages_TrueNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_tn,\n",
+ " y_true_tra_tn,\n",
+ " y_choice_tra_tn,\n",
+ " title_main=\"True Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])\n",
+ "\n",
+ "if n_fn > 0:\n",
+ " file_prefix = path_dict['file_figure_prefix']\\\n",
+ " + \"_\" + \"ExampleImages_FalseNegatives_on_class_\"\\\n",
+ " + str(class_value) + \"_\" + path_dict['run_label']\n",
+ " plotArrayHistogramConfusion(x_tra_fn,\n",
+ " y_true_tra_fn,\n",
+ " y_choice_tra_fn,\n",
+ " title_main=\"False Negatives\",\n",
+ " num=10,\n",
+ " file_prefix=file_prefix,\n",
+ " file_location=path_dict['dir_data_figures'],\n",
+ " file_suffix=path_dict['file_figure_suffix'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 9: Histograms of the pixel flux values for the images shown in Figure 8. For these data, it is difficult to infer the reasons for classification errors such as false positives and false negatives from the histograms alone. If a particular digit is expected to have a particular distribution of pixel brightnesses, but the histogram for the predicted digit has a different distribution, that difference may help identify a pattern. This diagnostic may be more useful for physics-related data, like galaxy light profiles."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 4.3.4.3. Investigate morphological features of images with feature maps"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The last diagnostic of this tutorial is the `feature map`, which shows how a single layer processes the input image after the weights have been set during model training. The morphological elements in the feature maps, such as lines and edges, indicate which features of the image the layer responds to.\n",
+ "\n",
+ "The feature map is obtained by passing an input image through one layer of the trained neural network model. More specifically, we first choose an input datum (image) by setting the array index of interest. Then, we make the image a `tensor` with the appropriate number of dimensions using `unsqueeze`, and we transfer it to the `device`. Next, we pass it through the `conv1` layer, which was defined in the section above where we defined the `model`. Finally, we plot one feature map for each of the 32 convolutional kernels in the `conv1` layer of the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2025-02-09T23:13:50.809088Z",
+ "iopub.status.busy": "2025-02-09T23:13:50.808737Z",
+ "iopub.status.idle": "2025-02-09T23:13:55.192290Z",
+ "shell.execute_reply": "2025-02-09T23:13:55.191710Z",
+ "shell.execute_reply.started": "2025-02-09T23:13:50.809054Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "index_input_image = 1\n",
+ "number_conv_kernels = 32\n",
+ "\n",
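+ "# Select one image and convert it to a float32 tensor.\n",
+ "input_image = dataset.data[index_input_image].type(torch.float32)\n",
+ "input_image = input_image.clone().detach()\n",
+ "# Add a leading dimension so the tensor has the shape the layer expects.\n",
+ "input_image = input_image.unsqueeze(0)\n",
+ "input_image = input_image.to(device)\n",
+ "\n",
+ "# Pass the image through the first convolutional layer only.\n",
+ "with torch.no_grad():\n",
+ " feature_maps = model.conv1(input_image).cpu()\n",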
+ "\n",
+ "fig, ax = plt.subplots(4, 8, sharex=True, sharey=True, figsize=(16, 8))\n",
+ "\n",
+ "for i in range(0, number_conv_kernels):\n",
+ " row, col = i//8, i%8\n",
+ " ax[row][col].imshow(feature_maps[i])\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> Figure 10: `Feature maps` of one input from the `conv1` layer of the `model`. The units on the x- and y-axes are the pixel indices. There are 32 images because there are 32 convolutional kernels in the `conv1` layer.\n",
+ ">\n",
+ "> The feature maps derived from the trained model show which morphological features (e.g., lines and edges) are favored by the model. When the model is accurate, these feature maps will reflect the structure of the input image. In this example, the handwritten digit is \"0,\" and the feature maps all resemble a \"0\". Some maps have more clearly defined edges and \"0\"-like features than others, because each convolutional kernel has its own distinct set of weights."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Exercises for the Learner"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Each time you train a new model, re-run all the diagnostic plots.\n",
+ "\n",
+ "1. How do the loss and accuracy histories change when batch size is small or large? Why?\n",
+ "2. Does the NN take more or less time (more or fewer epochs) to converge when the input image data are normalized, compared to when they are not? Why?\n",
+ "3. How does the size of the training set affect the model's accuracy and loss -- keeping the number of epochs the same? Why?\n",
+ "4. How does the random seed for the weight initialization affect the model's accuracy and loss -- keeping the number of epochs the same?\n",
+ "5. Use the `time` module to estimate the time for the model fitting. Record that time. Increase and then decrease the number of weights in the NN by an order of magnitude. Train the NN for each of those models and record the times. How does the number of weights in the neural network affect the training time and model loss and accuracy?\n",
+ "6. Use the `time` module to estimate the time for the model fitting. Record that time. Increase and then decrease the number of layers in the NN. Train the NN for each of those models and record the times. How does the number of layers in the neural network affect the training time and model loss and accuracy?\n",
+ "7. Use the `time` module to estimate the time for the model fitting. Record that time. Add a convolutional layer to the NN. Train the new model and record the time. How does the additional convolutional layer affect the training time and model loss and accuracy?"
+ ]
+ }
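+ ,
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Exercises 5 through 7 ask for timing measurements with the `time` module. The cell below is a minimal sketch of one way to do this with `time.perf_counter`. The function `train_model_placeholder` is a hypothetical stand-in for the notebook's training loop, and the helper `time_call` is likewise introduced here only for illustration; replace the placeholder with the actual training call. When training on a GPU, operations may run asynchronously, so calling `torch.cuda.synchronize()` before reading the clock may give a more reliable measurement."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "\n",
+ "\n",
+ "def time_call(func, *args, **kwargs):\n",
+ "    \"\"\"Run `func` once and return (result, elapsed seconds).\"\"\"\n",
+ "    t_start = time.perf_counter()\n",
+ "    result = func(*args, **kwargs)\n",
+ "    elapsed = time.perf_counter() - t_start\n",
+ "    return result, elapsed\n",
+ "\n",
+ "\n",
+ "def train_model_placeholder(n_epochs):\n",
+ "    \"\"\"Hypothetical stand-in for a training loop.\"\"\"\n",
+ "    total = 0\n",
+ "    for epoch in range(n_epochs):\n",
+ "        total += sum(range(10_000))\n",
+ "    return total\n",
+ "\n",
+ "\n",
+ "result_demo, elapsed_demo = time_call(train_model_placeholder, n_epochs=5)\n",
+ "print(f\"elapsed time: {elapsed_demo:.4f} s\")"
+ ]
+ }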
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "machine_shape": "hm",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/README.md b/README.md
index 8ec5f12..9307751 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,7 @@ Tutorial titles in **bold** have Spanish-language versions.
| 13a. Using The Image Cutout Tool With DP0.2 | Demonstration of the use of the image cutout tool with a few science applications. |
| 14. Injecting Synthetic Sources Into Single-Visit Images | Inject artificial stars and galaxies into images. |
| 15. Survey Property Maps | Use the tools to visualize full-area survey property maps. |
+| 16a. Introduction to Tensorflow | Learn to classify images with AI-based classification algorithms. |
## DP0.3 Tutorials
@@ -119,7 +120,7 @@ The *content* of these notebooks are licensed under the Apache 2.0 License. Tha
| **13a. Using The Image Cutout Tool With DP0.2**
**Level:** Beginner
**Description:** This notebook demonstrates how to use the Rubin Image Cutout Service.
**Skills:** Run the Rubin Image Cutout Service for visual inspection of small cutouts of LSST images.
**Data Products:** Images (deepCoadd, calexp), catalogs (objectTable, diaObject, truthTables, ivoa.ObsCore).
**Packages:** PyVO, lsst.rsp.get_tap_service, lsst.pipe.tasks.registerImage, lsst.afw.display
|
| **14. Injecting Synthetic Sources Into Single-Visit Images**
**Level:** Advanced
**Description:** This tutorial demonstrates a method to inject artificial sources (stars and galaxies) into calexp images using the measured point-spread function of the given calexp image. Confirmation that the synthetic sources were correctly injected into the image is done by running a difference imaging task from the pipelines.
**Skills:** Use the `source_injection` tools to inject synthetic sources into images. Create a difference image from a `calexp` with injected sources.
**Data Products:** Butler calexp images and corresponding src catalogs, goodSeeingDiff_templateExp images, and injection_catalogs.
**Packages:** lsst.source.injection
|
| **15. Survey Property Maps**
**Level:** Intermediate
**Description:** Use the tools to visualize full-area survey property maps.
**Skills:** Load and visualize survey property maps using healsparse and skyproj.
**Data Products:** Survey property maps.
**Packages:** healsparse, skyproj, lsst.daf.butler
|
-
+| **16a. Introduction to Tensorflow**
**Level:** Beginner
**Description:** An introduction to the classification of images with AI-based classification algorithms.
**Skills:** Examine AI training data, prepare it for a classification task, perform classification with a neural network, and examine the diagnostics of the classification task.
**Data Products:** MNIST data.
**Packages:** sklearn, tensorflow
|
| Skills in **DP0.3** Tutorial Notebooks |
|---|