Skip to content

Commit fedf993

Browse files
authored
TreeSatAI: Add new dataset (#2402)
* TreeSatAI: Add new dataset * Add versionadded * ruff * mypy fix * Add docs * Add test data * Add zip data * Add labels zip * Add tests * Add data module and tests * Capitalization is important * Fix uint casting to float * Use many-hot encoding to support multilabel classification * Silence overflow warnings * Add docs for data module
1 parent 6baa00d commit fedf993

File tree

54 files changed

+666
-1
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+666
-1
lines changed

docs/api/datamodules.rst

+5

docs/api/datasets.rst

+5

docs/api/datasets/non_geo_datasets.csv

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
5353
`SSL4EO-L Benchmark`_,S,Lansat & CDL,"CC0-1.0",25K,134,264x264,30,MSI
5454
`SSL4EO-L Benchmark`_,S,Lansat & NLCD,"CC0-1.0",25K,17,264x264,30,MSI
5555
`SustainBench Crop Yield`_,R,MODIS,"CC-BY-SA-4.0",11k,-,32x32,-,MSI
56+
`TreeSatAI`_,"C, R, S","Aerial, Sentinel-1/2",CC-BY-4.0,50K,"12, 15, 20","6, 20, 304","0.2, 10","CIR, MSI, SAR"
5657
`Tropical Cyclone`_,R,GOES 8--16,"CC-BY-4.0","108,110",-,256x256,4K--8K,MSI
5758
`UC Merced`_,C,USGS National Map,"public domain","2,100",21,256x256,0.3,RGB
5859
`USAVars`_,R,NAIP Aerial,"CC-BY-4.0",100K,-,-,4,"RGB, NIR"

tests/conf/treesatai.yaml

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
model:
2+
class_path: MultiLabelClassificationTask
3+
init_args:
4+
model: 'resnet18'
5+
in_channels: 19
6+
num_classes: 15
7+
loss: 'bce'
8+
data:
9+
class_path: TreeSatAIDataModule
10+
init_args:
11+
batch_size: 1
12+
dict_kwargs:
13+
root: 'tests/data/treesatai'
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
22 Bytes
Binary file not shown.
Binary file not shown.
9.05 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

tests/data/treesatai/data.py

+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Microsoft Corporation. All rights reserved.
4+
# Licensed under the MIT License.
5+
6+
import glob
7+
import json
8+
import os
9+
import random
10+
import shutil
11+
import zipfile
12+
13+
import numpy as np
14+
import rasterio
15+
from rasterio import Affine
16+
from rasterio.crs import CRS
17+
18+
SIZE = 32
19+
20+
random.seed(0)
21+
np.random.seed(0)
22+
23+
classes = (
24+
'Abies',
25+
'Acer',
26+
'Alnus',
27+
'Betula',
28+
'Cleared',
29+
'Fagus',
30+
'Fraxinus',
31+
'Larix',
32+
'Picea',
33+
'Pinus',
34+
'Populus',
35+
'Prunus',
36+
'Pseudotsuga',
37+
'Quercus',
38+
'Tilia',
39+
)
40+
41+
species = (
42+
'Acer_pseudoplatanus',
43+
'Alnus_spec',
44+
'Fagus_sylvatica',
45+
'Picea_abies',
46+
'Pseudotsuga_menziesii',
47+
'Quercus_petraea',
48+
'Quercus_rubra',
49+
)
50+
51+
profile = {
52+
'aerial': {
53+
'driver': 'GTiff',
54+
'dtype': 'uint8',
55+
'nodata': None,
56+
'width': SIZE,
57+
'height': SIZE,
58+
'count': 4,
59+
'crs': CRS.from_epsg(25832),
60+
'transform': Affine(
61+
0.19999999999977022, 0.0, 552245.4, 0.0, -0.19999999999938728, 5728215.0
62+
),
63+
},
64+
's1': {
65+
'driver': 'GTiff',
66+
'dtype': 'float32',
67+
'nodata': -9999.0,
68+
'width': SIZE // 16,
69+
'height': SIZE // 16,
70+
'count': 3,
71+
'crs': CRS.from_epsg(32632),
72+
'transform': Affine(10.0, 0.0, 552245.0, 0.0, -10.0, 5728215.0),
73+
},
74+
's2': {
75+
'driver': 'GTiff',
76+
'dtype': 'uint16',
77+
'nodata': None,
78+
'width': SIZE // 16,
79+
'height': SIZE // 16,
80+
'count': 12,
81+
'crs': CRS.from_epsg(32632),
82+
'transform': Affine(10.0, 0.0, 552241.6565, 0.0, -10.0, 5728211.6251),
83+
},
84+
}
85+
86+
multi_labels = {}
87+
for split in ['train', 'test']:
88+
with open(f'{split}_filenames.lst') as f:
89+
for filename in f:
90+
filename = filename.strip()
91+
for sensor in ['aerial', 's1', 's2']:
92+
kwargs = profile[sensor]
93+
directory = os.path.join(sensor, '60m')
94+
os.makedirs(directory, exist_ok=True)
95+
if 'int' in kwargs['dtype']:
96+
Z = np.random.randint(
97+
np.iinfo(kwargs['dtype']).min,
98+
np.iinfo(kwargs['dtype']).max,
99+
size=(kwargs['height'], kwargs['width']),
100+
dtype=kwargs['dtype'],
101+
)
102+
else:
103+
Z = np.random.rand(kwargs['height'], kwargs['width'])
104+
105+
path = os.path.join(directory, filename)
106+
with rasterio.open(path, 'w', **kwargs) as src:
107+
for i in range(1, kwargs['count'] + 1):
108+
src.write(Z, i)
109+
110+
k = random.randrange(1, 4)
111+
labels = random.choices(classes, k=k)
112+
pcts = np.random.rand(k)
113+
pcts /= np.sum(pcts)
114+
multi_labels[filename] = list(map(list, zip(labels, map(float, pcts))))
115+
116+
os.makedirs('labels', exist_ok=True)
117+
path = os.path.join('labels', 'TreeSatBA_v9_60m_multi_labels.json')
118+
with open(path, 'w') as f:
119+
json.dump(multi_labels, f)
120+
121+
for sensor in ['s1', 's2', 'labels']:
122+
shutil.make_archive(sensor, 'zip', '.', sensor)
123+
124+
for spec in species:
125+
path = f'aerial_60m_{spec}.zip'.lower()
126+
with zipfile.ZipFile(path, 'w') as f:
127+
for path in glob.iglob(os.path.join('aerial', '60m', f'{spec}_*.tif')):
128+
filename = os.path.split(path)[-1]
129+
f.write(path, arcname=filename)

tests/data/treesatai/labels.zip

750 Bytes
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"Picea_abies_3_46636_WEFL_NLF.tif": [["Prunus", 0.20692122963708826], ["Fraxinus", 0.7930787703629117]], "Pseudotsuga_menziesii_1_339575_BI_NLF.tif": [["Tilia", 0.4243067837573989], ["Larix", 0.5756932162426011]], "Quercus_rubra_1_92184_WEFL_NLF.tif": [["Tilia", 0.5816157697641007], ["Fagus", 0.4183842302358993]], "Fagus_sylvatica_9_29995_WEFL_NLF.tif": [["Larix", 1.0]], "Quercus_petraea_5_80549_WEFL_NLF.tif": [["Alnus", 0.5749721529276662], ["Acer", 0.4250278470723338]], "Acer_pseudoplatanus_3_5758_WEFL_NLF.tif": [["Tilia", 0.8430361090251272], ["Larix", 0.1569638909748729]], "Alnus_spec._5_13114_WEFL_NLF.tif": [["Pseudotsuga", 0.17881149698366108], ["Quercus", 0.38732907538618866], ["Cleared", 0.4338594276301503]], "Quercus_petraea_2_84375_WEFL_NLF.tif": [["Acer", 0.3909090505343164], ["Pseudotsuga", 0.2628926194326892], ["Cleared", 0.34619833003299444]], "Picea_abies_2_46896_WEFL_NLF.tif": [["Acer", 0.4953810312272686], ["Fraxinus", 0.0006659055704136941], ["Pinus", 0.5039530632023177]], "Acer_pseudoplatanus_4_6058_WEFL_NLF.tif": [["Tilia", 1.0]]}

tests/data/treesatai/s1.zip

4.24 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

tests/data/treesatai/s2.zip

4.15 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Acer_pseudoplatanus_4_6058_WEFL_NLF.tif
+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
Picea_abies_3_46636_WEFL_NLF.tif
2+
Pseudotsuga_menziesii_1_339575_BI_NLF.tif
3+
Quercus_rubra_1_92184_WEFL_NLF.tif
4+
Fagus_sylvatica_9_29995_WEFL_NLF.tif
5+
Quercus_petraea_5_80549_WEFL_NLF.tif
6+
Acer_pseudoplatanus_3_5758_WEFL_NLF.tif
7+
Alnus_spec._5_13114_WEFL_NLF.tif
8+
Quercus_petraea_2_84375_WEFL_NLF.tif
9+
Picea_abies_2_46896_WEFL_NLF.tif

tests/datasets/test_treesatai.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
import glob
5+
import os
6+
import shutil
7+
from pathlib import Path
8+
9+
import matplotlib.pyplot as plt
10+
import pytest
11+
import torch.nn as nn
12+
from pytest import MonkeyPatch
13+
from torch import Tensor
14+
15+
from torchgeo.datasets import DatasetNotFoundError, TreeSatAI
16+
17+
root = os.path.join('tests', 'data', 'treesatai')
18+
md5s = {
19+
'aerial_60m_acer_pseudoplatanus.zip': '',
20+
'labels.zip': '',
21+
's1.zip': '',
22+
's2.zip': '',
23+
'test_filenames.lst': '',
24+
'train_filenames.lst': '',
25+
}
26+
27+
28+
class TestTreeSatAI:
29+
@pytest.fixture
30+
def dataset(self, monkeypatch: MonkeyPatch) -> TreeSatAI:
31+
monkeypatch.setattr(TreeSatAI, 'url', root + os.sep)
32+
monkeypatch.setattr(TreeSatAI, 'md5s', md5s)
33+
transforms = nn.Identity()
34+
return TreeSatAI(root, transforms=transforms)
35+
36+
def test_getitem(self, dataset: TreeSatAI) -> None:
37+
x = dataset[0]
38+
assert isinstance(x, dict)
39+
assert isinstance(x['label'], Tensor)
40+
for sensor in dataset.sensors:
41+
assert isinstance(x[f'image_{sensor}'], Tensor)
42+
43+
def test_len(self, dataset: TreeSatAI) -> None:
44+
assert len(dataset) == 9
45+
46+
def test_download(self, dataset: TreeSatAI, tmp_path: Path) -> None:
47+
TreeSatAI(tmp_path, download=True)
48+
49+
def test_extract(self, dataset: TreeSatAI, tmp_path: Path) -> None:
50+
for file in glob.iglob(os.path.join(root, '*.*')):
51+
shutil.copy(file, tmp_path)
52+
TreeSatAI(tmp_path)
53+
54+
def test_not_downloaded(self, tmp_path: Path) -> None:
55+
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
56+
TreeSatAI(tmp_path)
57+
58+
def test_plot(self, dataset: TreeSatAI) -> None:
59+
x = dataset[0]
60+
x['prediction'] = x['label']
61+
dataset.plot(x)
62+
plt.close()

tests/trainers/test_classification.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def test_freeze_backbone(self, model_name: str) -> None:
237237

238238
class TestMultiLabelClassificationTask:
239239
@pytest.mark.parametrize(
240-
'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2']
240+
'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2', 'treesatai']
241241
)
242242
def test_trainer(
243243
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool

torchgeo/datamodules/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from .ssl4eo import SSL4EOLDataModule, SSL4EOS12DataModule
4949
from .ssl4eo_benchmark import SSL4EOLBenchmarkDataModule
5050
from .sustainbench_crop_yield import SustainBenchCropYieldDataModule
51+
from .treesatai import TreeSatAIDataModule
5152
from .ucmerced import UCMercedDataModule
5253
from .usavars import USAVarsDataModule
5354
from .utils import MisconfigurationException
@@ -110,6 +111,7 @@
110111
'SpaceNet6DataModule',
111112
'SpaceNetBaseDataModule',
112113
'SustainBenchCropYieldDataModule',
114+
'TreeSatAIDataModule',
113115
'TropicalCycloneDataModule',
114116
'UCMercedDataModule',
115117
'USAVarsDataModule',

0 commit comments

Comments
 (0)