
Commit e0a6912

Add TorchGeo CLI tutorial (#2479)
* Add TorchGeo CLI tutorial
* Pass checkpoint path
1 parent ec96458 commit e0a6912


2 files changed, +294 -0 lines changed

docs/tutorials/basic_usage.rst

+2
@@ -7,6 +7,7 @@ The following tutorials introduce the basic concepts and components of TorchGeo:
 * `Spectral Indices <indices.ipynb>`_: Visualizing and appending spectral indices
 * `Pretrained Weights <pretrained_weights.ipynb>`_: Models and pretrained weights
 * `Lightning Trainers <trainers.ipynb>`_: PyTorch Lightning data modules and trainers
+* `Command-Line Interface <cli.ipynb>`_: TorchGeo's command-line interface
 
 .. toctree::
    :hidden:
@@ -16,3 +17,4 @@ The following tutorials introduce the basic concepts and components of TorchGeo:
    indices
    pretrained_weights
    trainers
+   cli

docs/tutorials/cli.ipynb

+292
@@ -0,0 +1,292 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "16421d50-8d7a-4972-b06f-160fd890cc86",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Copyright (c) Microsoft Corporation. All rights reserved.\n",
+    "# Licensed under the MIT License."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e563313d",
+   "metadata": {},
+   "source": [
+    "# Command-Line Interface\n",
+    "\n",
+    "_Written by: Adam J. Stewart_\n",
+    "\n",
+    "TorchGeo provides a command-line interface based on [LightningCLI](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html) that allows users to combine our data modules and trainers from the comfort of the command line. This no-code solution can be attractive for both beginners and experts, as it offers flexibility and reproducibility. In this tutorial, we demonstrate some of the features of this interface."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8c1f4156",
+   "metadata": {},
+   "source": [
+    "## Setup\n",
+    "\n",
+    "First, we install TorchGeo. In addition to the Python library, this also installs a `torchgeo` executable."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3f0d31a8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install torchgeo"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7801ab8b-0ee3-40ac-88c2-4bdc29bb4e1b",
+   "metadata": {},
+   "source": [
+    "## Subcommands\n",
+    "\n",
+    "The `torchgeo` command has a number of *subcommands* that can be run. The `--help` flag can be used to list them."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a6ccac4e-7f20-4aa8-b851-27234ffd259f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!torchgeo --help"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "19ee017d-0d8f-41c6-8e7c-68495c7e62b6",
+   "metadata": {},
+   "source": [
+    "## Trainer\n",
+    "\n",
+    "Below, we run `--help` on the `fit` subcommand to see what options are available to us. `fit` is used to train and validate a model, and we can customize many aspects of the training process."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "afe1dc9d-4cee-43b0-ae30-200c64d3401a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!torchgeo fit --help"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b437860c-b406-4150-b30b-8aa895eebfcd",
+   "metadata": {},
+   "source": [
+    "## Model\n",
+    "\n",
+    "We must first select an `nn.Module` model architecture to train and a `lightning.pytorch.LightningModule` trainer to train it. We will experiment with the `ClassificationTask` trainer and use the `--model.help` flag to see what options we can customize. Any of TorchGeo's built-in trainers, or trainers written by the user, can be used in this way."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7cd9bbd0-17c9-4e87-b10d-ea846c39bc24",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!torchgeo fit --model.help ClassificationTask"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3daacd8d-64f4-4357-bdf3-759295a14224",
+   "metadata": {},
+   "source": [
+    "## Data\n",
+    "\n",
+    "We must also select a `Dataset` we would like to train on and a `lightning.pytorch.LightningDataModule` we can use to access the train/val/test split and any augmentations to apply to the data. Similarly, we use the `--data.help` flag to see what options are available for the `EuroSAT100DataModule`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "136eb59f-6662-44af-82e9-c55bdb3f17ac",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!torchgeo fit --data.help EuroSAT100DataModule"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8039cb67-ee18-4b41-8bf5-0e939493f5bb",
+   "metadata": {},
+   "source": [
+    "## Config\n",
+    "\n",
+    "Now that we have seen all important configuration options, we can put them together in a YAML file. LightningCLI supports YAML, JSON, and command-line configuration. Although we write this file using Python in this tutorial, you would normally write it in your favorite text editor."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e25c8efb-ed8c-4795-862c-bfb84cc84e1f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import tempfile\n",
+    "\n",
+    "root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n",
+    "config = f\"\"\"\n",
+    "trainer:\n",
+    "  max_epochs: 1\n",
+    "  default_root_dir: '{root}'\n",
+    "model:\n",
+    "  class_path: ClassificationTask\n",
+    "  init_args:\n",
+    "    model: 'resnet18'\n",
+    "    in_channels: 13\n",
+    "    num_classes: 10\n",
+    "data:\n",
+    "  class_path: EuroSAT100DataModule\n",
+    "  init_args:\n",
+    "    batch_size: 8\n",
+    "  dict_kwargs:\n",
+    "    root: '{root}'\n",
+    "    download: true\n",
+    "\"\"\"\n",
+    "os.makedirs(root, exist_ok=True)\n",
+    "with open(os.path.join(root, 'config.yaml'), 'w') as f:\n",
+    "    f.write(config)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a661b8d7-2dc9-4a30-8842-bd52d130e080",
+   "metadata": {},
+   "source": [
+    "This YAML file has three sections:\n",
+    "\n",
+    "* `trainer`: Arguments to pass to the [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html)\n",
+    "* `model`: Arguments to pass to the task\n",
+    "* `data`: Arguments to pass to the data module\n",
+    "\n",
+    "The `class_path` gives the class to instantiate, `init_args` lists arguments declared in that class's `__init__` signature, and `dict_kwargs` lists additional keyword arguments that are passed through `**kwargs`."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e132f933-4edf-42bb-b585-e0d8ceb65eab",
+   "metadata": {},
+   "source": [
+    "## Training\n",
+    "\n",
+    "We can now train our model like so."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f84b0739-c9e7-4057-8864-98ab69a11f64",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!torchgeo fit --config {root}/config.yaml"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "cb1557f1-6cc0-46da-909c-836911acb248",
+   "metadata": {},
+   "source": [
+    "## Validation\n",
+    "\n",
+    "Now that we have a trained model, we can evaluate performance on the validation set. Note that we need to explicitly pass in the location of the checkpoint from the previous run."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b9cbb4f4-1879-4ae7-bae4-2c24d49a4a61",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import glob\n",
+    "\n",
+    "checkpoint = glob.glob(\n",
+    "    os.path.join(root, 'lightning_logs', 'version_0', 'checkpoints', '*.ckpt')\n",
+    ")[0]\n",
+    "\n",
+    "!torchgeo validate --config {root}/config.yaml --ckpt_path {checkpoint}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ba816fc3-5cac-4cbc-a6ef-effc6c9faa61",
+   "metadata": {},
+   "source": [
+    "## Testing\n",
+    "\n",
+    "After finishing our hyperparameter tuning, we can calculate and report the final test performance."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f1faa997-9f81-4847-94fc-5a8bb7687369",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!torchgeo test --config {root}/config.yaml --ckpt_path {checkpoint}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f5383d30-8f76-44a2-8366-e6fcbd1e6042",
+   "metadata": {},
+   "source": [
+    "## Additional Reading\n",
+    "\n",
+    "Lightning CLI has many more features that are worth learning. You can learn more by reading the following set of tutorials:\n",
+    "\n",
+    "* [Configure hyperparameters from the CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html)"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "provenance": []
+  },
+  "execution": {
+   "timeout": 1200
+  },
+  "gpuClass": "standard",
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.13.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
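
A note on the checkpoint lookup in the Validation cell: the glob hard-codes `version_0`, the directory Lightning's default logger creates for the first run under `default_root_dir`. Re-running `fit` writes to `version_1`, `version_2`, and so on, so the cell would keep returning the first run's checkpoint. A minimal sketch of a more forgiving lookup, assuming the `lightning_logs/version_*/checkpoints/*.ckpt` layout used in the notebook; `latest_checkpoint` is a hypothetical helper, not part of TorchGeo or Lightning:

```python
import glob
import os


def latest_checkpoint(root: str) -> str:
    """Return the most recently modified .ckpt file under root/lightning_logs.

    Hypothetical helper: assumes the default layout
    <root>/lightning_logs/version_<N>/checkpoints/<name>.ckpt.
    """
    pattern = os.path.join(
        root, 'lightning_logs', 'version_*', 'checkpoints', '*.ckpt'
    )
    candidates = glob.glob(pattern)
    if not candidates:
        # Fail with a clear message instead of an IndexError from `[0]`.
        raise FileNotFoundError(f'No checkpoints match {pattern}')
    # Pick the newest file by modification time, whichever run produced it.
    return max(candidates, key=os.path.getmtime)
```

In the notebook, `checkpoint = latest_checkpoint(root)` could stand in for the `glob.glob(...)[0]` expression, so that repeated runs validate and test the newest checkpoint.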
