
Commit e42c404

Add BRIGHT dataset (#2520)
* bright
* bright tests
* bright
* run ruff
* mypy and docs
* ruff on data.py
* ruff on bright
* docs
* ruff
* rm datamodule
* coverage
* request
* Update docs/api/datasets/non_geo_datasets.csv
  Co-authored-by: Adam J. Stewart <[email protected]>

---------

Co-authored-by: Adam J. Stewart <[email protected]>
1 parent 662a883 commit e42c404

22 files changed: +560 -0 lines

docs/api/datasets.rst

+5 lines
@@ -221,6 +221,11 @@ BioMassters
 
 .. autoclass:: BioMassters
 
+BRIGHT
+^^^^^^
+
+.. autoclass:: BRIGHTDFC2025
+
 CaBuAr
 ^^^^^^

docs/api/datasets/non_geo_datasets.csv

+1 line
@@ -3,6 +3,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
 `Benin Cashew Plantations`_,S,Airbus Pléiades,"CC-BY-4.0",70,6,"1,122x1,186",10,MSI
 `BigEarthNet`_,C,Sentinel-1/2,"CDLA-Permissive-1.0","590,326",19--43,120x120,10,"SAR, MSI"
 `BioMassters`_,R,Sentinel-1/2 and Lidar,"CC-BY-4.0",,,256x256, 10, "SAR, MSI"
+`BRIGHT`_,CD,"MAXAR, NAIP, Capella, Umbra","CC-BY-4.0 AND CC-BY-NC-4.0",3239,4,"0.1-1","RGB,SAR"
 `CaBuAr`_,CD,Sentinel-2,"OpenRAIL",424,2,512x512,20,MSI
 `CaFFe`_,S,"Sentinel-1, TerraSAR-X, TanDEM-X, ENVISAT, ERS-1/2, ALOS PALSAR, and RADARSAT-1","CC-BY-4.0","19092","2 or 4","512x512",6-20,"SAR"
 `ChaBuD`_,CD,Sentinel-2,"OpenRAIL",356,2,512x512,10,MSI

tests/data/bright/data.py

+117 lines
@@ -0,0 +1,117 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil

import numpy as np
import rasterio

ROOT = '.'
DATA_DIR = 'dfc25_track2_trainval'

TRAIN_FILE = 'train_setlevel.txt'
HOLDOUT_FILE = 'holdout_setlevel.txt'
VAL_FILE = 'val_setlevel.txt'

TRAIN_IDS = [
    'bata-explosion_00000049',
    'bata-explosion_00000014',
    'bata-explosion_00000047',
]
HOLDOUT_IDS = ['turkey-earthquake_00000413']
VAL_IDS = ['val-disaster_00000001', 'val-disaster_00000002']

SIZE = 32


def make_dirs() -> None:
    paths = [
        os.path.join(ROOT, DATA_DIR),
        os.path.join(ROOT, DATA_DIR, 'train', 'pre-event'),
        os.path.join(ROOT, DATA_DIR, 'train', 'post-event'),
        os.path.join(ROOT, DATA_DIR, 'train', 'target'),
        os.path.join(ROOT, DATA_DIR, 'val', 'pre-event'),
        os.path.join(ROOT, DATA_DIR, 'val', 'post-event'),
        os.path.join(ROOT, DATA_DIR, 'val', 'target'),
    ]
    for p in paths:
        os.makedirs(p, exist_ok=True)


def write_list_file(filename: str, ids: list[str]) -> None:
    file_path = os.path.join(ROOT, DATA_DIR, filename)
    with open(file_path, 'w') as f:
        for sid in ids:
            f.write(f'{sid}\n')


def write_tif(filepath: str, channels: int) -> None:
    data = np.random.randint(0, 255, (channels, SIZE, SIZE), dtype=np.uint8)
    # transform = from_origin(0, 0, 1, 1)
    crs = 'epsg:4326'
    with rasterio.open(
        filepath,
        'w',
        driver='GTiff',
        height=SIZE,
        width=SIZE,
        count=channels,
        crs=crs,
        dtype=data.dtype,
        compress='lzw',
        # transform=transform,
    ) as dst:
        dst.write(data)


def populate_data(ids: list[str], dir_name: str, with_target: bool = True) -> None:
    for sid in ids:
        pre_path = os.path.join(
            ROOT, DATA_DIR, dir_name, 'pre-event', f'{sid}_pre_disaster.tif'
        )
        write_tif(pre_path, channels=3)
        post_path = os.path.join(
            ROOT, DATA_DIR, dir_name, 'post-event', f'{sid}_post_disaster.tif'
        )
        write_tif(post_path, channels=1)
        if with_target:
            target_path = os.path.join(
                ROOT, DATA_DIR, dir_name, 'target', f'{sid}_building_damage.tif'
            )
            write_tif(target_path, channels=1)


def main() -> None:
    make_dirs()

    # Write the ID lists to text files
    write_list_file(TRAIN_FILE, TRAIN_IDS)
    write_list_file(HOLDOUT_FILE, HOLDOUT_IDS)
    write_list_file(VAL_FILE, VAL_IDS)

    # Generate TIF files for the train (with target) and val (no target) splits
    populate_data(TRAIN_IDS, 'train', with_target=True)
    populate_data(HOLDOUT_IDS, 'train', with_target=True)
    populate_data(VAL_IDS, 'val', with_target=False)

    # zip and compute md5
    zip_filename = os.path.join(ROOT, 'dfc25_track2_trainval')
    shutil.make_archive(zip_filename, 'zip', ROOT, DATA_DIR)

    def md5(fname: str) -> str:
        hash_md5 = hashlib.md5()
        with open(fname, 'rb') as f:
            for chunk in iter(lambda: f.read(4096), b''):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()

    md5sum = md5(zip_filename + '.zip')
    print(f'MD5 checksum: {md5sum}')


if __name__ == '__main__':
    main()
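
For reference, a minimal sketch (not part of this commit) of how the fixture could be regenerated and inspected; it assumes the script is executed from tests/data/bright/, where it writes dfc25_track2_trainval.zip next to data.py, and uses only the standard library:

    import subprocess
    import zipfile

    # Regenerate the dummy data; data.py prints 'MD5 checksum: ...' at the end
    subprocess.run(['python', 'data.py'], check=True)

    # List the archive layout: the three *_setlevel.txt lists plus train/ and val/
    # imagery, e.g. dfc25_track2_trainval/train/pre-event/bata-explosion_00000049_pre_disaster.tif
    with zipfile.ZipFile('dfc25_track2_trainval.zip') as zf:
        for name in sorted(zf.namelist()):
            print(name)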
tests/data/bright/dfc25_track2_trainval.zip

43.7 KB (binary file not shown)
tests/data/bright/dfc25_track2_trainval/holdout_setlevel.txt

@@ -0,0 +1 @@
turkey-earthquake_00000413
tests/data/bright/dfc25_track2_trainval/train_setlevel.txt

@@ -0,0 +1,3 @@
bata-explosion_00000049
bata-explosion_00000014
bata-explosion_00000047
tests/data/bright/dfc25_track2_trainval/val_setlevel.txt

@@ -0,0 +1,2 @@
val-disaster_00000001
val-disaster_00000002

tests/datasets/test_bright.py

+89 lines
@@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import BRIGHTDFC2025, DatasetNotFoundError


class TestBRIGHTDFC2025:
    @pytest.fixture(params=['train', 'val', 'test'])
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> BRIGHTDFC2025:
        md5 = '7b0e24d45fb2d9a4f766196702586414'
        monkeypatch.setattr(BRIGHTDFC2025, 'md5', md5)
        url = os.path.join('tests', 'data', 'bright', 'dfc25_track2_trainval.zip')
        monkeypatch.setattr(BRIGHTDFC2025, 'url', url)
        root = tmp_path
        split = request.param
        transforms = nn.Identity()
        return BRIGHTDFC2025(root, split, transforms, download=True, checksum=True)

    def test_getitem(self, dataset: BRIGHTDFC2025) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x['image_pre'], torch.Tensor)
        assert x['image_pre'].shape[0] == 3
        assert isinstance(x['image_post'], torch.Tensor)
        assert x['image_post'].shape[0] == 3
        assert x['image_pre'].shape[-2:] == x['image_post'].shape[-2:]
        if dataset.split != 'test':
            assert isinstance(x['mask'], torch.Tensor)
            assert x['image_pre'].shape[-2:] == x['mask'].shape[-2:]

    def test_len(self, dataset: BRIGHTDFC2025) -> None:
        if dataset.split == 'train':
            assert len(dataset) == 3
        elif dataset.split == 'val':
            assert len(dataset) == 1
        else:
            assert len(dataset) == 2

    def test_already_downloaded(self, dataset: BRIGHTDFC2025) -> None:
        BRIGHTDFC2025(root=dataset.root)

    def test_not_yet_extracted(self, tmp_path: Path) -> None:
        filename = 'dfc25_track2_trainval.zip'
        dir = os.path.join('tests', 'data', 'bright')
        shutil.copyfile(
            os.path.join(dir, filename), os.path.join(str(tmp_path), filename)
        )
        BRIGHTDFC2025(root=str(tmp_path))

    def test_invalid_split(self) -> None:
        with pytest.raises(AssertionError):
            BRIGHTDFC2025(split='foo')

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            BRIGHTDFC2025(tmp_path)

    def test_corrupted(self, tmp_path: Path) -> None:
        with open(os.path.join(tmp_path, 'dfc25_track2_trainval.zip'), 'w') as f:
            f.write('bad')
        with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
            BRIGHTDFC2025(root=tmp_path, checksum=True)

    def test_plot(self, dataset: BRIGHTDFC2025) -> None:
        dataset.plot(dataset[0], suptitle='Test')
        plt.close()

        if dataset.split != 'test':
            sample = dataset[0]
            sample['prediction'] = torch.clone(sample['mask'])
            dataset.plot(sample, suptitle='Prediction')
            plt.close()

            del sample['mask']
            dataset.plot(sample, suptitle='Only Prediction')
            plt.close()

torchgeo/datasets/__init__.py

+2 lines
@@ -11,6 +11,7 @@
 from .benin_cashews import BeninSmallHolderCashews
 from .bigearthnet import BigEarthNet
 from .biomassters import BioMassters
+from .bright import BRIGHTDFC2025
 from .cabuar import CaBuAr
 from .caffe import CaFFe
 from .cbf import CanadianBuildingFootprints
@@ -152,6 +153,7 @@
 
 __all__ = (
     'ADVANCE',
+    'BRIGHTDFC2025',
     'CDL',
     'COWC',
     'DFC2022',
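
With the class exported, a minimal usage sketch (not part of this commit), based on the constructor arguments and sample keys exercised in tests/datasets/test_bright.py; the root path is a placeholder, and download=True fetches and extracts dfc25_track2_trainval.zip into it:

    import torch.nn as nn

    from torchgeo.datasets import BRIGHTDFC2025

    # split can be 'train', 'val', or 'test'; 'test' samples have no 'mask'
    ds = BRIGHTDFC2025(
        root='data/bright',
        split='train',
        transforms=nn.Identity(),
        download=True,
        checksum=True,
    )
    sample = ds[0]                     # dict with 'image_pre', 'image_post', 'mask'
    print(sample['image_pre'].shape)   # 3-channel pre-event optical image
    print(sample['image_post'].shape)  # post-event SAR image
    ds.plot(sample, suptitle='BRIGHT sample')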
