|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "execution_count": null, |
| 6 | + "metadata": {}, |
| 7 | + "outputs": [], |
| 8 | + "source": [ |
| 9 | + "import torch\n", |
| 10 | + "import lightning.pytorch as pl\n", |
| 11 | + "from torch.utils.data import DataLoader, Subset\n", |
| 12 | + "from torchgeo.datasets import VHR10\n", |
| 13 | + "from torchvision.transforms.functional import to_pil_image\n", |
| 14 | + "from matplotlib.patches import Rectangle\n", |
| 15 | + "import matplotlib.pyplot as plt\n", |
| 16 | + "import torch.nn.functional as F\n", |
| 17 | + "from torchgeo.trainers import InstanceSegmentationTask \n", |
| 18 | + "import matplotlib.patches as patches\n", |
| 19 | + "import numpy as np\n", |
| 20 | + "\n", |
| 21 | + "def collate_fn(batch):\n", |
| 22 | + " \"\"\"Custom collate function for DataLoader.\"\"\"\n", |
| 23 | + " max_height = max(sample['image'].shape[1] for sample in batch)\n", |
| 24 | + " max_width = max(sample['image'].shape[2] for sample in batch)\n", |
| 25 | + "\n", |
| 26 | + " images = torch.stack([\n", |
| 27 | + " F.pad(sample['image'], (0, max_width - sample['image'].shape[2], 0, max_height - sample['image'].shape[1]))\n", |
| 28 | + " for sample in batch\n", |
| 29 | + " ])\n", |
| 30 | + "\n", |
| 31 | + " targets = [\n", |
| 32 | + " {\n", |
| 33 | + " \"labels\": sample[\"labels\"].to(torch.int64),\n", |
| 34 | + " \"boxes\": sample[\"boxes\"].to(torch.float32),\n", |
| 35 | + " \"masks\": F.pad(\n", |
| 36 | + " sample[\"masks\"],\n", |
| 37 | + " (0, max_width - sample[\"masks\"].shape[2], 0, max_height - sample[\"masks\"].shape[1]),\n", |
| 38 | + " ).to(torch.uint8),\n", |
| 39 | + " }\n", |
| 40 | + " for sample in batch\n", |
| 41 | + " ]\n", |
| 42 | + "\n", |
| 43 | + " return {\"image\": images, \"target\": targets}\n", |
| 44 | + "\n", |
| 45 | + "def visualize_predictions(image, predictions, targets):\n", |
| 46 | + " \"\"\"Visualize predictions and ground truth.\"\"\"\n", |
| 47 | + " image = to_pil_image(image)\n", |
| 48 | + "\n", |
| 49 | + " fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n", |
| 50 | + " ax.imshow(image)\n", |
| 51 | + "\n", |
| 52 | + " # Predictions\n", |
| 53 | + " for box, label in zip(predictions['boxes'], predictions['labels']):\n", |
| 54 | + " x1, y1, x2, y2 = box\n", |
| 55 | + " rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')\n", |
| 56 | + " ax.add_patch(rect)\n", |
| 57 | + " ax.text(x1, y1, f\"Pred: {label.item()}\", color='red', fontsize=12)\n", |
| 58 | + "\n", |
| 59 | + " # Ground truth\n", |
| 60 | + " for box, label in zip(targets['boxes'], targets['labels']):\n", |
| 61 | + " x1, y1, x2, y2 = box\n", |
| 62 | + " rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='blue', facecolor='none')\n", |
| 63 | + " ax.add_patch(rect)\n", |
| 64 | + " ax.text(x1, y1, f\"GT: {label.item()}\", color='blue', fontsize=12)\n", |
| 65 | + "\n", |
| 66 | + " plt.show()\n", |
| 67 | + "\n", |
| 68 | + "def plot_losses(train_losses, val_losses):\n", |
| 69 | + " \"\"\"Plot training and validation losses over epochs.\"\"\"\n", |
| 70 | + " plt.figure(figsize=(10, 5))\n", |
| 71 | + " plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss', marker='o')\n", |
| 72 | + " plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss', marker='s')\n", |
| 73 | + " plt.xlabel('Epochs')\n", |
| 74 | + " plt.ylabel('Loss')\n", |
| 75 | + " plt.title('Training and Validation Loss Over Epochs')\n", |
| 76 | + " plt.legend()\n", |
| 77 | + " plt.grid()\n", |
| 78 | + " plt.show()\n", |
| 79 | + "\n", |
| 80 | + "# Initialize VHR-10 dataset\n", |
| 81 | + "train_dataset = VHR10(root=\"data\", split=\"positive\", transforms=None, download=True)\n", |
| 82 | + "val_dataset = VHR10(root=\"data\", split=\"positive\", transforms=None)\n", |
| 83 | + "\n", |
| 84 | + "# Subset for quick experimentation (adjust N as needed)\n", |
| 85 | + "N = 100\n", |
| 86 | + "train_subset = Subset(train_dataset, list(range(N)))\n", |
| 87 | + "val_subset = Subset(val_dataset, list(range(N)))\n", |
| 88 | + "\n", |
| 89 | + "if __name__ == '__main__':\n", |
| 90 | + " import multiprocessing\n", |
| 91 | + " multiprocessing.set_start_method('spawn', force=True)\n", |
| 92 | + "\n", |
| 93 | + " train_loader = DataLoader(train_subset, batch_size=8, shuffle=True, num_workers=1, collate_fn=collate_fn)\n", |
| 94 | + " val_loader = DataLoader(val_subset, batch_size=8, shuffle=False, num_workers=1, collate_fn=collate_fn)\n", |
| 95 | + "\n", |
| 96 | + " # Trainer setup\n", |
| 97 | + " trainer = pl.Trainer(\n", |
| 98 | + " max_epochs=5, \n", |
| 99 | + " accelerator=\"gpu\" if torch.cuda.is_available() else \"cpu\",\n", |
| 100 | + " devices=1\n", |
| 101 | + " )\n", |
| 102 | + "\n", |
| 103 | + " task = InstanceSegmentationTask(\n", |
| 104 | + " model=\"mask_rcnn\", \n", |
| 105 | + " backbone=\"resnet50\", \n", |
| 106 | + " weights=\"imagenet\", # Pretrained on ImageNet\n", |
| 107 | + " num_classes=11, # VHR-10 has 10 classes + 1 background\n", |
| 108 | + " lr=1e-3, \n", |
| 109 | + " freeze_backbone=False \n", |
| 110 | + " )\n", |
| 111 | + "\n", |
| 112 | + " print('\\nSTART TRAINING\\n')\n", |
| 113 | + " # trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)\n", |
| 114 | + " train_losses, val_losses = [], []\n", |
| 115 | + " for epoch in range(5):\n", |
| 116 | + " trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)\n", |
| 117 | + " train_loss = task.trainer.callback_metrics.get(\"train_loss\")\n", |
| 118 | + " val_loss = task.trainer.callback_metrics.get(\"val_loss\")\n", |
| 119 | + " if train_loss is not None:\n", |
| 120 | + " train_losses.append(train_loss.item())\n", |
| 121 | + " if val_loss is not None:\n", |
| 122 | + " val_losses.append(val_loss.item())\n", |
| 123 | + " \n", |
| 124 | + " plot_losses(train_losses, val_losses)\n", |
| 125 | + "\n", |
| 126 | + " #trainer.test(task, dataloaders=val_loader)\n", |
| 127 | + "\n", |
| 128 | + " # Inference and Visualization\n", |
| 129 | + " sample = train_dataset[1]\n", |
| 130 | + " image = sample['image'].unsqueeze(0) \n", |
| 131 | + " predictions = task.predict_step({\"image\": image}, batch_idx=0)\n", |
| 132 | + " visualize_predictions(image[0], predictions[0], sample)\n", |
| 133 | + "\n" |
| 134 | + ] |
| 135 | + } |
| 136 | + ], |
| 137 | + "metadata": { |
| 138 | + "kernelspec": { |
| 139 | + "display_name": "Python 3", |
| 140 | + "language": "python", |
| 141 | + "name": "python3" |
| 142 | + }, |
| 143 | + "language_info": { |
| 144 | + "name": "python", |
| 145 | + "version": "3.12.0" |
| 146 | + } |
| 147 | + }, |
| 148 | + "nbformat": 4, |
| 149 | + "nbformat_minor": 2 |
| 150 | +} |
0 commit comments