Skip to content

Commit 70074e7

Browse files
Add files via upload
1 parent 63aefc8 commit 70074e7

File tree

1 file changed

+150
-0
lines changed

1 file changed

+150
-0
lines changed

test_trainer.ipynb

+150
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,150 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import torch\n",
10+
"import lightning.pytorch as pl\n",
11+
"from torch.utils.data import DataLoader, Subset\n",
12+
"from torchgeo.datasets import VHR10\n",
13+
"from torchvision.transforms.functional import to_pil_image\n",
14+
"from matplotlib.patches import Rectangle\n",
15+
"import matplotlib.pyplot as plt\n",
16+
"import torch.nn.functional as F\n",
17+
"from torchgeo.trainers import InstanceSegmentationTask \n",
18+
"import matplotlib.patches as patches\n",
19+
"import numpy as np\n",
20+
"\n",
21+
def collate_fn(batch):
    """Collate variable-sized samples into one padded batch.

    Every image and mask stack is zero-padded on the right and bottom so that
    it matches the largest image height and width found in ``batch``.

    Args:
        batch: Sequence of sample dicts providing ``'image'``, ``'labels'``,
            ``'boxes'`` and ``'masks'`` tensors. Dimension 1 is read as the
            height and dimension 2 as the width of both images and masks.

    Returns:
        A dict with ``"image"`` (the stacked, padded images) and ``"target"``
        (one dict per sample holding int64 labels, float32 boxes and uint8
        padded masks).
    """
    canvas_h = max(item['image'].shape[1] for item in batch)
    canvas_w = max(item['image'].shape[2] for item in batch)

    def pad_to_canvas(tensor):
        # F.pad orders the padding as (left, right, top, bottom) for the last
        # two dimensions; only the right and bottom edges grow here.
        extra_w = canvas_w - tensor.shape[2]
        extra_h = canvas_h - tensor.shape[1]
        return F.pad(tensor, (0, extra_w, 0, extra_h))

    images = torch.stack([pad_to_canvas(item['image']) for item in batch])

    targets = []
    for item in batch:
        targets.append(
            {
                "labels": item["labels"].to(torch.int64),
                "boxes": item["boxes"].to(torch.float32),
                "masks": pad_to_canvas(item["masks"]).to(torch.uint8),
            }
        )

    return {"image": images, "target": targets}
44+
"\n",
45+
def _to_python(values):
    """Return tensor/array-like ``values`` as plain Python lists."""
    if hasattr(values, "detach"):
        # Drop autograd history and move off the accelerator before converting.
        values = values.detach().cpu()
    if hasattr(values, "tolist"):
        return values.tolist()
    return list(values)


def _draw_boxes(ax, boxes, labels, prefix, color):
    """Draw one outlined rectangle and a text tag per (box, label) pair."""
    for box, label in zip(boxes, labels):
        x1, y1, x2, y2 = (float(coord) for coord in box)
        ax.add_patch(
            Rectangle(
                (x1, y1), x2 - x1, y2 - y1,
                linewidth=2, edgecolor=color, facecolor='none',
            )
        )
        ax.text(x1, y1, f"{prefix}: {label}", color=color, fontsize=12)


def visualize_predictions(image, predictions, targets, score_threshold=0.0):
    """Visualize predictions and ground truth.

    Predicted boxes are drawn in red and ground-truth boxes in blue on top of
    the image, then the figure is shown with ``plt.show()``.

    Args:
        image: Image to display; passed to ``to_pil_image`` after optional
            rescaling (see below).
        predictions: Mapping with ``'boxes'`` (rows of x1, y1, x2, y2) and
            ``'labels'``; an optional ``'scores'`` entry is used for filtering.
        targets: Mapping with ``'boxes'`` and ``'labels'`` for the ground
            truth.
        score_threshold: Predictions whose score is below this value are not
            drawn. The default of 0.0 draws every prediction.
    """
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu()
        # to_pil_image multiplies floating-point tensors by 255 before casting
        # to bytes, so float data already above 1.0 would overflow.
        # NOTE(review): rescaling assumes such data is in the 8-bit 0-255
        # range -- confirm against the dataset being visualized.
        if image.is_floating_point() and image.numel() > 0 and image.max() > 1.0:
            image = (image / 255.0).clamp(0.0, 1.0)
    image = to_pil_image(image)

    _, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(image)

    # Predictions
    pred_boxes = _to_python(predictions['boxes'])
    pred_labels = _to_python(predictions['labels'])
    if score_threshold > 0 and 'scores' in predictions:
        scores = _to_python(predictions['scores'])
        kept = [
            (box, label)
            for box, label, score in zip(pred_boxes, pred_labels, scores)
            if float(score) >= score_threshold
        ]
        pred_boxes = [box for box, _ in kept]
        pred_labels = [label for _, label in kept]
    _draw_boxes(ax, pred_boxes, pred_labels, "Pred", 'red')

    # Ground truth
    _draw_boxes(
        ax,
        _to_python(targets['boxes']),
        _to_python(targets['labels']),
        "GT",
        'blue',
    )

    plt.show()
67+
"\n",
68+
def plot_losses(train_losses, val_losses):
    """Show the training and validation loss curves on a single figure.

    Args:
        train_losses: Sequence of training-loss values, one per epoch.
        val_losses: Sequence of validation-loss values, one per epoch.

    Each sequence is plotted against a 1-based epoch index of its own length,
    so the two sequences do not have to be equally long.
    """
    plt.figure(figsize=(10, 5))

    curves = (
        (train_losses, 'Training Loss', 'o'),
        (val_losses, 'Validation Loss', 's'),
    )
    for losses, legend_label, marker_style in curves:
        epochs = range(1, len(losses) + 1)
        plt.plot(epochs, losses, label=legend_label, marker=marker_style)

    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss Over Epochs')
    plt.legend()
    plt.grid()
    plt.show()
79+
"\n",
80+
# Epoch budget for the Trainer; defined once so the training length cannot
# drift between several hard-coded copies.
MAX_EPOCHS = 5


class LossHistory(pl.Callback):
    """Record the logged ``train_loss`` / ``val_loss`` once per epoch.

    ``trainer.fit`` must be called a single time: a Trainer keeps its epoch
    counter between calls, so repeated ``fit`` calls on a Trainer that already
    reached ``max_epochs`` return immediately and yield no new metrics.
    """

    def __init__(self):
        super().__init__()
        self.train_losses = []
        self.val_losses = []

    def on_train_epoch_end(self, trainer, pl_module):
        # Validation (when enabled) runs before this hook, so both values of
        # the current epoch are present in callback_metrics if the task logs
        # them under these keys; missing keys are simply skipped.
        metrics = trainer.callback_metrics
        for key, history in (
            ("train_loss", self.train_losses),
            ("val_loss", self.val_losses),
        ):
            value = metrics.get(key)
            if value is not None:
                history.append(float(value))


# Initialize VHR-10 dataset
train_dataset = VHR10(root="data", split="positive", transforms=None, download=True)
val_dataset = VHR10(root="data", split="positive", transforms=None)

# Subset for quick experimentation (adjust N as needed).
# Training uses the first N samples and validation the following N, so the
# validation loss is measured on images the model was not trained on. Both
# ranges are clamped to the dataset length.
N = 100
_num_samples = len(train_dataset)
_train_indices = list(range(min(N, _num_samples)))
_val_indices = list(range(len(_train_indices), min(2 * N, _num_samples)))
train_subset = Subset(train_dataset, _train_indices)
val_subset = Subset(val_dataset, _val_indices)

if __name__ == '__main__':
    import multiprocessing
    multiprocessing.set_start_method('spawn', force=True)

    # NOTE(review): with the 'spawn' start method, worker processes must be
    # able to import collate_fn; if this runs inside a notebook kernel and the
    # workers fail to start, set num_workers=0 -- confirm in your environment.
    train_loader = DataLoader(
        train_subset,
        batch_size=8,
        shuffle=True,
        num_workers=1,
        collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        val_subset,
        batch_size=8,
        shuffle=False,
        num_workers=1,
        collate_fn=collate_fn,
    )

    # Trainer setup
    loss_history = LossHistory()
    trainer = pl.Trainer(
        max_epochs=MAX_EPOCHS,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        callbacks=[loss_history],
    )

    # NOTE(review): argument names/values below follow the torchgeo build this
    # notebook was written against -- confirm against the installed version.
    task = InstanceSegmentationTask(
        model="mask_rcnn",
        backbone="resnet50",
        weights="imagenet",  # Pretrained on ImageNet
        num_classes=11,  # VHR-10 has 10 classes + 1 background
        lr=1e-3,
        freeze_backbone=False,
    )

    print('\nSTART TRAINING\n')
    # One fit call runs all MAX_EPOCHS epochs; the callback collects one loss
    # value per epoch along the way.
    trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)
    train_losses = loss_history.train_losses
    val_losses = loss_history.val_losses

    plot_losses(train_losses, val_losses)

    # Optional evaluation: trainer.test(task, dataloaders=val_loader)

    # Inference and Visualization
    # Put the module in eval mode and disable autograd so the detector is run
    # for inference, and place the input on the module's current device.
    task.eval()
    sample = train_dataset[1]
    image = sample['image'].unsqueeze(0).to(task.device)
    with torch.no_grad():
        predictions = task.predict_step({"image": image}, batch_idx=0)
    # Hand plain CPU tensors to the plotting code.
    prediction = {
        key: value.detach().cpu() if isinstance(value, torch.Tensor) else value
        for key, value in predictions[0].items()
    }
    visualize_predictions(image[0].cpu(), prediction, sample)
133+
"\n"
134+
]
135+
}
136+
],
137+
"metadata": {
138+
"kernelspec": {
139+
"display_name": "Python 3",
140+
"language": "python",
141+
"name": "python3"
142+
},
143+
"language_info": {
144+
"name": "python",
145+
"version": "3.12.0"
146+
}
147+
},
148+
"nbformat": 4,
149+
"nbformat_minor": 2
150+
}

0 commit comments

Comments
 (0)