
Commit b3de001

Created with Colab
1 parent 70074e7 commit b3de001

File tree

1 file changed: +184 −148 lines changed


test_trainer.ipynb

+184 −148
@@ -1,150 +1,186 @@
 {
-"cells": [
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"import torch\n",
-"import lightning.pytorch as pl\n",
-"from torch.utils.data import DataLoader, Subset\n",
-"from torchgeo.datasets import VHR10\n",
-"from torchvision.transforms.functional import to_pil_image\n",
-"from matplotlib.patches import Rectangle\n",
-"import matplotlib.pyplot as plt\n",
-"import torch.nn.functional as F\n",
-"from torchgeo.trainers import InstanceSegmentationTask \n",
-"import matplotlib.patches as patches\n",
-"import numpy as np\n",
-"\n",
-"def collate_fn(batch):\n",
-" \"\"\"Custom collate function for DataLoader.\"\"\"\n",
-" max_height = max(sample['image'].shape[1] for sample in batch)\n",
-" max_width = max(sample['image'].shape[2] for sample in batch)\n",
-"\n",
-" images = torch.stack([\n",
-" F.pad(sample['image'], (0, max_width - sample['image'].shape[2], 0, max_height - sample['image'].shape[1]))\n",
-" for sample in batch\n",
-" ])\n",
-"\n",
-" targets = [\n",
-" {\n",
-" \"labels\": sample[\"labels\"].to(torch.int64),\n",
-" \"boxes\": sample[\"boxes\"].to(torch.float32),\n",
-" \"masks\": F.pad(\n",
-" sample[\"masks\"],\n",
-" (0, max_width - sample[\"masks\"].shape[2], 0, max_height - sample[\"masks\"].shape[1]),\n",
-" ).to(torch.uint8),\n",
-" }\n",
-" for sample in batch\n",
-" ]\n",
-"\n",
-" return {\"image\": images, \"target\": targets}\n",
-"\n",
-"def visualize_predictions(image, predictions, targets):\n",
-" \"\"\"Visualize predictions and ground truth.\"\"\"\n",
-" image = to_pil_image(image)\n",
-"\n",
-" fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
-" ax.imshow(image)\n",
-"\n",
-" # Predictions\n",
-" for box, label in zip(predictions['boxes'], predictions['labels']):\n",
-" x1, y1, x2, y2 = box\n",
-" rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')\n",
-" ax.add_patch(rect)\n",
-" ax.text(x1, y1, f\"Pred: {label.item()}\", color='red', fontsize=12)\n",
-"\n",
-" # Ground truth\n",
-" for box, label in zip(targets['boxes'], targets['labels']):\n",
-" x1, y1, x2, y2 = box\n",
-" rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='blue', facecolor='none')\n",
-" ax.add_patch(rect)\n",
-" ax.text(x1, y1, f\"GT: {label.item()}\", color='blue', fontsize=12)\n",
-"\n",
-" plt.show()\n",
-"\n",
-"def plot_losses(train_losses, val_losses):\n",
-" \"\"\"Plot training and validation losses over epochs.\"\"\"\n",
-" plt.figure(figsize=(10, 5))\n",
-" plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss', marker='o')\n",
-" plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss', marker='s')\n",
-" plt.xlabel('Epochs')\n",
-" plt.ylabel('Loss')\n",
-" plt.title('Training and Validation Loss Over Epochs')\n",
-" plt.legend()\n",
-" plt.grid()\n",
-" plt.show()\n",
-"\n",
-"# Initialize VHR-10 dataset\n",
-"train_dataset = VHR10(root=\"data\", split=\"positive\", transforms=None, download=True)\n",
-"val_dataset = VHR10(root=\"data\", split=\"positive\", transforms=None)\n",
-"\n",
-"# Subset for quick experimentation (adjust N as needed)\n",
-"N = 100\n",
-"train_subset = Subset(train_dataset, list(range(N)))\n",
-"val_subset = Subset(val_dataset, list(range(N)))\n",
-"\n",
-"if __name__ == '__main__':\n",
-" import multiprocessing\n",
-" multiprocessing.set_start_method('spawn', force=True)\n",
-"\n",
-" train_loader = DataLoader(train_subset, batch_size=8, shuffle=True, num_workers=1, collate_fn=collate_fn)\n",
-" val_loader = DataLoader(val_subset, batch_size=8, shuffle=False, num_workers=1, collate_fn=collate_fn)\n",
-"\n",
-" # Trainer setup\n",
-" trainer = pl.Trainer(\n",
-" max_epochs=5, \n",
-" accelerator=\"gpu\" if torch.cuda.is_available() else \"cpu\",\n",
-" devices=1\n",
-" )\n",
-"\n",
-" task = InstanceSegmentationTask(\n",
-" model=\"mask_rcnn\", \n",
-" backbone=\"resnet50\", \n",
-" weights=\"imagenet\", # Pretrained on ImageNet\n",
-" num_classes=11, # VHR-10 has 10 classes + 1 background\n",
-" lr=1e-3, \n",
-" freeze_backbone=False \n",
-" )\n",
-"\n",
-" print('\\nSTART TRAINING\\n')\n",
-" # trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)\n",
-" train_losses, val_losses = [], []\n",
-" for epoch in range(5):\n",
-" trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)\n",
-" train_loss = task.trainer.callback_metrics.get(\"train_loss\")\n",
-" val_loss = task.trainer.callback_metrics.get(\"val_loss\")\n",
-" if train_loss is not None:\n",
-" train_losses.append(train_loss.item())\n",
-" if val_loss is not None:\n",
-" val_losses.append(val_loss.item())\n",
-" \n",
-" plot_losses(train_losses, val_losses)\n",
-"\n",
-" #trainer.test(task, dataloaders=val_loader)\n",
-"\n",
-" # Inference and Visualization\n",
-" sample = train_dataset[1]\n",
-" image = sample['image'].unsqueeze(0) \n",
-" predictions = task.predict_step({\"image\": image}, batch_idx=0)\n",
-" visualize_predictions(image[0], predictions[0], sample)\n",
-"\n"
-]
-}
-],
-"metadata": {
-"kernelspec": {
-"display_name": "Python 3",
-"language": "python",
-"name": "python3"
+"cells": [
+{
+"cell_type": "code",
+"execution_count": 1,
+"metadata": {
+"id": "gQBpL3DTHh2v",
+"outputId": "0f96c780-21d7-42fe-fc97-db77155f826c",
+"colab": {
+"base_uri": "https://localhost:8080/",
+"height": 383
+}
+},
+"outputs": [
+{
+"output_type": "error",
+"ename": "ModuleNotFoundError",
+"evalue": "No module named 'lightning'",
+"traceback": [
+"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
+"\u001b[0;32m<ipython-input-1-830d39adcfdd>\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mlightning\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpytorch\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mpl\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorchgeo\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdatasets\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mVHR10\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorchvision\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunctional\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mto_pil_image\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'lightning'",
+"",
+"\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n"
+],
+"errorDetails": {
+"actions": [
+{
+"action": "open_url",
+"actionText": "Open Examples",
+"url": "/notebooks/snippets/importing_libraries.ipynb"
+}
+]
+}
+}
+],
+"source": [
+"!pip install torch torchvision torchgeo lightning matplotlib\n",
+"\n",
+"import torch\n",
+"import lightning.pytorch as pl\n",
+"from torch.utils.data import DataLoader, Subset\n",
+"from torchgeo.datasets import VHR10\n",
+"from torchvision.transforms.functional import to_pil_image\n",
+"from matplotlib.patches import Rectangle\n",
+"import matplotlib.pyplot as plt\n",
+"import torch.nn.functional as F\n",
+"from torchgeo.trainers import InstanceSegmentationTask\n",
+"import matplotlib.patches as patches\n",
+"import numpy as np\n",
+"\n",
+"def collate_fn(batch):\n",
+" \"\"\"Custom collate function for DataLoader.\"\"\"\n",
+" max_height = max(sample['image'].shape[1] for sample in batch)\n",
+" max_width = max(sample['image'].shape[2] for sample in batch)\n",
+"\n",
+" images = torch.stack([\n",
+" F.pad(sample['image'], (0, max_width - sample['image'].shape[2], 0, max_height - sample['image'].shape[1]))\n",
+" for sample in batch\n",
+" ])\n",
+"\n",
+" targets = [\n",
+" {\n",
+" \"labels\": sample[\"labels\"].to(torch.int64),\n",
+" \"boxes\": sample[\"boxes\"].to(torch.float32),\n",
+" \"masks\": F.pad(\n",
+" sample[\"masks\"],\n",
+" (0, max_width - sample[\"masks\"].shape[2], 0, max_height - sample[\"masks\"].shape[1]),\n",
+" ).to(torch.uint8),\n",
+" }\n",
+" for sample in batch\n",
+" ]\n",
+"\n",
+" return {\"image\": images, \"target\": targets}\n",
+"\n",
+"def visualize_predictions(image, predictions, targets):\n",
+" \"\"\"Visualize predictions and ground truth.\"\"\"\n",
+" image = to_pil_image(image)\n",
+"\n",
+" fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
+" ax.imshow(image)\n",
+"\n",
+" # Predictions\n",
+" for box, label in zip(predictions['boxes'], predictions['labels']):\n",
+" x1, y1, x2, y2 = box\n",
+" rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')\n",
+" ax.add_patch(rect)\n",
+" ax.text(x1, y1, f\"Pred: {label.item()}\", color='red', fontsize=12)\n",
+"\n",
+" # Ground truth\n",
+" for box, label in zip(targets['boxes'], targets['labels']):\n",
+" x1, y1, x2, y2 = box\n",
+" rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='blue', facecolor='none')\n",
+" ax.add_patch(rect)\n",
+" ax.text(x1, y1, f\"GT: {label.item()}\", color='blue', fontsize=12)\n",
+"\n",
+" plt.show()\n",
+"\n",
+"def plot_losses(train_losses, val_losses):\n",
+" \"\"\"Plot training and validation losses over epochs.\"\"\"\n",
+" plt.figure(figsize=(10, 5))\n",
+" plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss', marker='o')\n",
+" plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss', marker='s')\n",
+" plt.xlabel('Epochs')\n",
+" plt.ylabel('Loss')\n",
+" plt.title('Training and Validation Loss Over Epochs')\n",
+" plt.legend()\n",
+" plt.grid()\n",
+" plt.show()\n",
+"\n",
+"# Initialize VHR-10 dataset\n",
+"train_dataset = VHR10(root=\"data\", split=\"positive\", transforms=None, download=True)\n",
+"val_dataset = VHR10(root=\"data\", split=\"positive\", transforms=None)\n",
+"\n",
+"# Subset for quick experimentation (adjust N as needed)\n",
+"N = 100\n",
+"train_subset = Subset(train_dataset, list(range(N)))\n",
+"val_subset = Subset(val_dataset, list(range(N)))\n",
+"\n",
+"if __name__ == '__main__':\n",
+" import multiprocessing\n",
+" multiprocessing.set_start_method('spawn', force=True)\n",
+"\n",
+" train_loader = DataLoader(train_subset, batch_size=8, shuffle=True, num_workers=1, collate_fn=collate_fn)\n",
+" val_loader = DataLoader(val_subset, batch_size=8, shuffle=False, num_workers=1, collate_fn=collate_fn)\n",
+"\n",
+" # Trainer setup\n",
+" trainer = pl.Trainer(\n",
+" max_epochs=5,\n",
+" accelerator=\"gpu\" if torch.cuda.is_available() else \"cpu\",\n",
+" devices=1\n",
+" )\n",
+"\n",
+" task = InstanceSegmentationTask(\n",
+" model=\"mask_rcnn\",\n",
+" backbone=\"resnet50\",\n",
+" weights=\"imagenet\", # Pretrained on ImageNet\n",
+" num_classes=11, # VHR-10 has 10 classes + 1 background\n",
+" lr=1e-3,\n",
+" freeze_backbone=False\n",
+" )\n",
+"\n",
+" print('\\nSTART TRAINING\\n')\n",
+" # trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)\n",
+" train_losses, val_losses = [], []\n",
+" for epoch in range(5):\n",
+" trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)\n",
+" train_loss = task.trainer.callback_metrics.get(\"train_loss\")\n",
+" val_loss = task.trainer.callback_metrics.get(\"val_loss\")\n",
+" if train_loss is not None:\n",
+" train_losses.append(train_loss.item())\n",
+" if val_loss is not None:\n",
+" val_losses.append(val_loss.item())\n",
+"\n",
+" plot_losses(train_losses, val_losses)\n",
+"\n",
+" #trainer.test(task, dataloaders=val_loader)\n",
+"\n",
+" # Inference and Visualization\n",
+" sample = train_dataset[1]\n",
+" image = sample['image'].unsqueeze(0)\n",
+" predictions = task.predict_step({\"image\": image}, batch_idx=0)\n",
+" visualize_predictions(image[0], predictions[0], sample)\n",
+"\n"
+]
+}
+],
+"metadata": {
+"kernelspec": {
+"display_name": "Python 3",
+"name": "python3"
+},
+"language_info": {
+"name": "python",
+"version": "3.12.0"
+},
+"colab": {
+"provenance": [],
+"gpuType": "T4"
+},
+"accelerator": "GPU"
 },
-"language_info": {
-"name": "python",
-"version": "3.12.0"
-}
-},
-"nbformat": 4,
-"nbformat_minor": 2
-}
+"nbformat": 4,
+"nbformat_minor": 0
+}
