{
-  "cells": [
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {},
-      "outputs": [],
-      "source": [
-        "import torch\n",
-        "import lightning.pytorch as pl\n",
-        "from torch.utils.data import DataLoader, Subset\n",
-        "from torchgeo.datasets import VHR10\n",
-        "from torchvision.transforms.functional import to_pil_image\n",
-        "from matplotlib.patches import Rectangle\n",
-        "import matplotlib.pyplot as plt\n",
-        "import torch.nn.functional as F\n",
-        "from torchgeo.trainers import InstanceSegmentationTask \n",
-        "import matplotlib.patches as patches\n",
-        "import numpy as np\n",
-        "\n",
-        "def collate_fn(batch):\n",
-        "    \"\"\"Custom collate function for DataLoader.\"\"\"\n",
-        "    max_height = max(sample['image'].shape[1] for sample in batch)\n",
-        "    max_width = max(sample['image'].shape[2] for sample in batch)\n",
-        "\n",
-        "    images = torch.stack([\n",
-        "        F.pad(sample['image'], (0, max_width - sample['image'].shape[2], 0, max_height - sample['image'].shape[1]))\n",
-        "        for sample in batch\n",
-        "    ])\n",
-        "\n",
-        "    targets = [\n",
-        "        {\n",
-        "            \"labels\": sample[\"labels\"].to(torch.int64),\n",
-        "            \"boxes\": sample[\"boxes\"].to(torch.float32),\n",
-        "            \"masks\": F.pad(\n",
-        "                sample[\"masks\"],\n",
-        "                (0, max_width - sample[\"masks\"].shape[2], 0, max_height - sample[\"masks\"].shape[1]),\n",
-        "            ).to(torch.uint8),\n",
-        "        }\n",
-        "        for sample in batch\n",
-        "    ]\n",
-        "\n",
-        "    return {\"image\": images, \"target\": targets}\n",
-        "\n",
-        "def visualize_predictions(image, predictions, targets):\n",
-        "    \"\"\"Visualize predictions and ground truth.\"\"\"\n",
-        "    image = to_pil_image(image)\n",
-        "\n",
-        "    fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
-        "    ax.imshow(image)\n",
-        "\n",
-        "    # Predictions\n",
-        "    for box, label in zip(predictions['boxes'], predictions['labels']):\n",
-        "        x1, y1, x2, y2 = box\n",
-        "        rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')\n",
-        "        ax.add_patch(rect)\n",
-        "        ax.text(x1, y1, f\"Pred: {label.item()}\", color='red', fontsize=12)\n",
-        "\n",
-        "    # Ground truth\n",
-        "    for box, label in zip(targets['boxes'], targets['labels']):\n",
-        "        x1, y1, x2, y2 = box\n",
-        "        rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='blue', facecolor='none')\n",
-        "        ax.add_patch(rect)\n",
-        "        ax.text(x1, y1, f\"GT: {label.item()}\", color='blue', fontsize=12)\n",
-        "\n",
-        "    plt.show()\n",
-        "\n",
-        "def plot_losses(train_losses, val_losses):\n",
-        "    \"\"\"Plot training and validation losses over epochs.\"\"\"\n",
-        "    plt.figure(figsize=(10, 5))\n",
-        "    plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss', marker='o')\n",
-        "    plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss', marker='s')\n",
-        "    plt.xlabel('Epochs')\n",
-        "    plt.ylabel('Loss')\n",
-        "    plt.title('Training and Validation Loss Over Epochs')\n",
-        "    plt.legend()\n",
-        "    plt.grid()\n",
-        "    plt.show()\n",
-        "\n",
-        "# Initialize VHR-10 dataset\n",
-        "train_dataset = VHR10(root=\"data\", split=\"positive\", transforms=None, download=True)\n",
-        "val_dataset = VHR10(root=\"data\", split=\"positive\", transforms=None)\n",
-        "\n",
-        "# Subset for quick experimentation (adjust N as needed)\n",
-        "N = 100\n",
-        "train_subset = Subset(train_dataset, list(range(N)))\n",
-        "val_subset = Subset(val_dataset, list(range(N)))\n",
-        "\n",
-        "if __name__ == '__main__':\n",
-        "    import multiprocessing\n",
-        "    multiprocessing.set_start_method('spawn', force=True)\n",
-        "\n",
-        "    train_loader = DataLoader(train_subset, batch_size=8, shuffle=True, num_workers=1, collate_fn=collate_fn)\n",
-        "    val_loader = DataLoader(val_subset, batch_size=8, shuffle=False, num_workers=1, collate_fn=collate_fn)\n",
-        "\n",
-        "    # Trainer setup\n",
-        "    trainer = pl.Trainer(\n",
-        "        max_epochs=5, \n",
-        "        accelerator=\"gpu\" if torch.cuda.is_available() else \"cpu\",\n",
-        "        devices=1\n",
-        "    )\n",
-        "\n",
-        "    task = InstanceSegmentationTask(\n",
-        "        model=\"mask_rcnn\", \n",
-        "        backbone=\"resnet50\", \n",
-        "        weights=\"imagenet\", # Pretrained on ImageNet\n",
-        "        num_classes=11, # VHR-10 has 10 classes + 1 background\n",
-        "        lr=1e-3, \n",
-        "        freeze_backbone=False \n",
-        "    )\n",
-        "\n",
-        "    print('\\nSTART TRAINING\\n')\n",
-        "    # trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)\n",
-        "    train_losses, val_losses = [], []\n",
-        "    for epoch in range(5):\n",
-        "        trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)\n",
-        "        train_loss = task.trainer.callback_metrics.get(\"train_loss\")\n",
-        "        val_loss = task.trainer.callback_metrics.get(\"val_loss\")\n",
-        "        if train_loss is not None:\n",
-        "            train_losses.append(train_loss.item())\n",
-        "        if val_loss is not None:\n",
-        "            val_losses.append(val_loss.item())\n",
-        "\n",
-        "    plot_losses(train_losses, val_losses)\n",
-        "\n",
-        "    #trainer.test(task, dataloaders=val_loader)\n",
-        "\n",
-        "    # Inference and Visualization\n",
-        "    sample = train_dataset[1]\n",
-        "    image = sample['image'].unsqueeze(0) \n",
-        "    predictions = task.predict_step({\"image\": image}, batch_idx=0)\n",
-        "    visualize_predictions(image[0], predictions[0], sample)\n",
-        "\n"
-      ]
-    }
-  ],
-  "metadata": {
-    "kernelspec": {
-      "display_name": "Python 3",
-      "language": "python",
-      "name": "python3"
+  "cells": [
+    {
+      "cell_type": "code",
+      "execution_count": 1,
+      "metadata": {
+        "id": "gQBpL3DTHh2v",
+        "outputId": "0f96c780-21d7-42fe-fc97-db77155f826c",
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 383
+        }
+      },
+      "outputs": [
+        {
+          "output_type": "error",
+          "ename": "ModuleNotFoundError",
+          "evalue": "No module named 'lightning'",
+          "traceback": [
+            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+            "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
+            "\u001b[0;32m<ipython-input-1-830d39adcfdd>\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mlightning\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpytorch\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mpl\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorchgeo\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdatasets\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mVHR10\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorchvision\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunctional\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mto_pil_image\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+            "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'lightning'",
+            "",
+            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n"
+          ],
+          "errorDetails": {
+            "actions": [
+              {
+                "action": "open_url",
+                "actionText": "Open Examples",
+                "url": "/notebooks/snippets/importing_libraries.ipynb"
+              }
+            ]
+          }
+        }
+      ],
+      "source": [
+        "!pip install torch torchvision torchgeo lightning matplotlib\n",
+        "\n",
+        "import torch\n",
+        "import lightning.pytorch as pl\n",
+        "from torch.utils.data import DataLoader, Subset\n",
+        "from torchgeo.datasets import VHR10\n",
+        "from torchvision.transforms.functional import to_pil_image\n",
+        "from matplotlib.patches import Rectangle\n",
+        "import matplotlib.pyplot as plt\n",
+        "import torch.nn.functional as F\n",
+        "from torchgeo.trainers import InstanceSegmentationTask\n",
+        "import matplotlib.patches as patches\n",
+        "import numpy as np\n",
+        "\n",
+        "def collate_fn(batch):\n",
+        "    \"\"\"Custom collate function for DataLoader.\"\"\"\n",
+        "    max_height = max(sample['image'].shape[1] for sample in batch)\n",
+        "    max_width = max(sample['image'].shape[2] for sample in batch)\n",
+        "\n",
+        "    images = torch.stack([\n",
+        "        F.pad(sample['image'], (0, max_width - sample['image'].shape[2], 0, max_height - sample['image'].shape[1]))\n",
+        "        for sample in batch\n",
+        "    ])\n",
+        "\n",
+        "    targets = [\n",
+        "        {\n",
+        "            \"labels\": sample[\"labels\"].to(torch.int64),\n",
+        "            \"boxes\": sample[\"boxes\"].to(torch.float32),\n",
+        "            \"masks\": F.pad(\n",
+        "                sample[\"masks\"],\n",
+        "                (0, max_width - sample[\"masks\"].shape[2], 0, max_height - sample[\"masks\"].shape[1]),\n",
+        "            ).to(torch.uint8),\n",
+        "        }\n",
+        "        for sample in batch\n",
+        "    ]\n",
+        "\n",
+        "    return {\"image\": images, \"target\": targets}\n",
+        "\n",
+        "def visualize_predictions(image, predictions, targets):\n",
+        "    \"\"\"Visualize predictions and ground truth.\"\"\"\n",
+        "    image = to_pil_image(image)\n",
+        "\n",
+        "    fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
+        "    ax.imshow(image)\n",
+        "\n",
+        "    # Predictions\n",
+        "    for box, label in zip(predictions['boxes'], predictions['labels']):\n",
+        "        x1, y1, x2, y2 = box\n",
+        "        rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')\n",
+        "        ax.add_patch(rect)\n",
+        "        ax.text(x1, y1, f\"Pred: {label.item()}\", color='red', fontsize=12)\n",
+        "\n",
+        "    # Ground truth\n",
+        "    for box, label in zip(targets['boxes'], targets['labels']):\n",
+        "        x1, y1, x2, y2 = box\n",
+        "        rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='blue', facecolor='none')\n",
+        "        ax.add_patch(rect)\n",
+        "        ax.text(x1, y1, f\"GT: {label.item()}\", color='blue', fontsize=12)\n",
+        "\n",
+        "    plt.show()\n",
+        "\n",
+        "def plot_losses(train_losses, val_losses):\n",
+        "    \"\"\"Plot training and validation losses over epochs.\"\"\"\n",
+        "    plt.figure(figsize=(10, 5))\n",
+        "    plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss', marker='o')\n",
+        "    plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss', marker='s')\n",
+        "    plt.xlabel('Epochs')\n",
+        "    plt.ylabel('Loss')\n",
+        "    plt.title('Training and Validation Loss Over Epochs')\n",
+        "    plt.legend()\n",
+        "    plt.grid()\n",
+        "    plt.show()\n",
+        "\n",
+        "# Initialize VHR-10 dataset\n",
+        "train_dataset = VHR10(root=\"data\", split=\"positive\", transforms=None, download=True)\n",
+        "val_dataset = VHR10(root=\"data\", split=\"positive\", transforms=None)\n",
+        "\n",
+        "# Subset for quick experimentation (adjust N as needed)\n",
+        "N = 100\n",
+        "train_subset = Subset(train_dataset, list(range(N)))\n",
+        "val_subset = Subset(val_dataset, list(range(N)))\n",
+        "\n",
+        "if __name__ == '__main__':\n",
+        "    import multiprocessing\n",
+        "    multiprocessing.set_start_method('spawn', force=True)\n",
+        "\n",
+        "    train_loader = DataLoader(train_subset, batch_size=8, shuffle=True, num_workers=1, collate_fn=collate_fn)\n",
+        "    val_loader = DataLoader(val_subset, batch_size=8, shuffle=False, num_workers=1, collate_fn=collate_fn)\n",
+        "\n",
+        "    # Trainer setup\n",
+        "    trainer = pl.Trainer(\n",
+        "        max_epochs=5,\n",
+        "        accelerator=\"gpu\" if torch.cuda.is_available() else \"cpu\",\n",
+        "        devices=1\n",
+        "    )\n",
+        "\n",
+        "    task = InstanceSegmentationTask(\n",
+        "        model=\"mask_rcnn\",\n",
+        "        backbone=\"resnet50\",\n",
+        "        weights=\"imagenet\", # Pretrained on ImageNet\n",
+        "        num_classes=11, # VHR-10 has 10 classes + 1 background\n",
+        "        lr=1e-3,\n",
+        "        freeze_backbone=False\n",
+        "    )\n",
+        "\n",
+        "    print('\\nSTART TRAINING\\n')\n",
+        "    # trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)\n",
+        "    train_losses, val_losses = [], []\n",
+        "    for epoch in range(5):\n",
+        "        trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)\n",
+        "        train_loss = task.trainer.callback_metrics.get(\"train_loss\")\n",
+        "        val_loss = task.trainer.callback_metrics.get(\"val_loss\")\n",
+        "        if train_loss is not None:\n",
+        "            train_losses.append(train_loss.item())\n",
+        "        if val_loss is not None:\n",
+        "            val_losses.append(val_loss.item())\n",
+        "\n",
+        "    plot_losses(train_losses, val_losses)\n",
+        "\n",
+        "    #trainer.test(task, dataloaders=val_loader)\n",
+        "\n",
+        "    # Inference and Visualization\n",
+        "    sample = train_dataset[1]\n",
+        "    image = sample['image'].unsqueeze(0)\n",
+        "    predictions = task.predict_step({\"image\": image}, batch_idx=0)\n",
+        "    visualize_predictions(image[0], predictions[0], sample)\n",
+        "\n"
+      ]
+    }
+  ],
+  "metadata": {
+    "kernelspec": {
+      "display_name": "Python 3",
+      "name": "python3"
+    },
+    "language_info": {
+      "name": "python",
+      "version": "3.12.0"
+    },
+    "colab": {
+      "provenance": [],
+      "gpuType": "T4"
+    },
+    "accelerator": "GPU"
  },
-    "language_info": {
-      "name": "python",
-      "version": "3.12.0"
-    }
-  },
-  "nbformat": 4,
-  "nbformat_minor": 2
-}
+  "nbformat": 4,
+  "nbformat_minor": 0
+}