Add WorldStrat dataset #2558

Open
wants to merge 12 commits into base: main
6 changes: 5 additions & 1 deletion docs/api/datasets.rst
@@ -200,7 +200,7 @@ Non-geospatial Datasets

:class:`NonGeoDataset` is designed for datasets that lack geospatial information. These datasets can still be combined using :class:`ConcatDataset <torch.utils.data.ConcatDataset>`.
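For intuition, the combination that :class:`ConcatDataset <torch.utils.data.ConcatDataset>` performs is just cumulative-length index arithmetic. A stdlib sketch of that mechanism (`TinyConcat` is a hypothetical illustration, not the torch implementation):

```python
import bisect


class TinyConcat:
    """Minimal stand-in for ConcatDataset-style indexing.

    Cumulative sizes are precomputed once; __getitem__ bisects into them
    to find the owning dataset, then localizes the index.
    """

    def __init__(self, datasets):
        self.datasets = list(datasets)
        self.cumulative = []
        total = 0
        for d in self.datasets:
            total += len(d)
            self.cumulative.append(total)

    def __len__(self):
        return self.cumulative[-1] if self.cumulative else 0

    def __getitem__(self, idx):
        # Which dataset does idx fall into?
        ds_idx = bisect.bisect_right(self.cumulative, idx)
        prev = self.cumulative[ds_idx - 1] if ds_idx > 0 else 0
        return self.datasets[ds_idx][idx - prev]


combined = TinyConcat([['a', 'b'], ['c']])
print(len(combined), combined[2])  # 3 c
```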

.. csv-table:: C = classification, R = regression, S = semantic segmentation, I = instance segmentation, T = time series, CD = change detection, OD = object detection, IC = image captioning
.. csv-table:: C = classification, R = regression, S = semantic segmentation, I = instance segmentation, T = time series, CD = change detection, OD = object detection, IC = image captioning, SR = super resolution
:widths: 15 7 15 20 12 11 12 15 13
:header-rows: 1
:align: center
@@ -528,6 +528,10 @@ Western USA Live Fuel Moisture

.. autoclass:: WesternUSALiveFuelMoisture

WorldStrat
^^^^^^^^^^^
.. autoclass:: WorldStrat

xView2
^^^^^^

1 change: 1 addition & 0 deletions docs/api/datasets/non_geo_datasets.csv
@@ -62,5 +62,6 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`Vaihingen`_,S,Aerial,-,33,6,"1,281--3,816",0.09,RGB
`VHR-10`_,I,"Google Earth, Vaihingen","MIT",800,10,"358--1,728",0.08--2,RGB
`Western USA Live Fuel Moisture`_,R,"Landsat8, Sentinel-1","CC-BY-NC-ND-4.0",2615,-,-,-,-
`WorldStrat`_,SR,"Sentinel-2,SPOT 6/7","CC-BY-NC-4.0 and CC-BY-4.0",,-,-,"1.5 and 10","RGB,MSI"
Collaborator
Suggested change
`WorldStrat`_,SR,"Sentinel-2,SPOT 6/7","CC-BY-NC-4.0 and CC-BY-4.0",,-,"1.5m and 10m","RGB,MSI"
`WorldStrat`_,SR,"Sentinel-2,SPOT 6/7","CC-BY-NC-4.0 AND CC-BY-4.0",,-,"1.5m and 10m","RGB,MSI"

AND and OR are formally defined license operators in SPDX: https://spdx.dev/learn/handling-license-info/
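For illustration, a conjunctive SPDX expression can be split into its component license IDs. This is a naive sketch only; the real SPDX grammar also allows `OR`, `WITH`, and parenthesized grouping:

```python
def spdx_conjuncts(expression: str) -> list[str]:
    """Split a simple 'A AND B' SPDX expression into its license IDs.

    Illustrative only: does not handle OR, WITH, or parentheses.
    """
    return [part.strip() for part in expression.split(' AND ')]


print(spdx_conjuncts('CC-BY-NC-4.0 AND CC-BY-4.0'))
# ['CC-BY-NC-4.0', 'CC-BY-4.0']
```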

`xView2`_,CD,Maxar,"CC-BY-NC-SA-4.0","3,732",4,"1,024x1,024",0.8,RGB
`ZueriCrop`_,"I, T",Sentinel-2,CC-BY-NC-4.0,116K,48,24x24,10,MSI
Binary file added tests/data/worldstrat/AOI001/AOI001_pan.tiff
Binary file added tests/data/worldstrat/AOI001/AOI001_ps.tiff
Binary file added tests/data/worldstrat/AOI001/AOI001_rgb.png
Binary file added tests/data/worldstrat/AOI001/AOI001_rgbn.tiff
Binary file added tests/data/worldstrat/AOI002/AOI002_pan.tiff
Binary file added tests/data/worldstrat/AOI002/AOI002_ps.tiff
Binary file added tests/data/worldstrat/AOI002/AOI002_rgb.png
Binary file added tests/data/worldstrat/AOI002/AOI002_rgbn.tiff
Binary file added tests/data/worldstrat/AOI003/AOI003_pan.tiff
Binary file added tests/data/worldstrat/AOI003/AOI003_ps.tiff
Binary file added tests/data/worldstrat/AOI003/AOI003_rgb.png
Binary file added tests/data/worldstrat/AOI003/AOI003_rgbn.tiff
Binary file added tests/data/worldstrat/AOI004/AOI004_pan.tiff
Binary file added tests/data/worldstrat/AOI004/AOI004_ps.tiff
Binary file added tests/data/worldstrat/AOI004/AOI004_rgb.png
Binary file added tests/data/worldstrat/AOI004/AOI004_rgbn.tiff
227 changes: 227 additions & 0 deletions tests/data/worldstrat/data.py
@@ -0,0 +1,227 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil
import tarfile
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import rasterio
from PIL import Image


def create_dummy_worldstrat(root: str, img_size: int = 64) -> None:
"""Create dummy WorldStrat dataset."""
os.makedirs(root, exist_ok=True)

tiles = {'train': ['AOI001', 'AOI002'], 'val': ['AOI003'], 'test': ['AOI004']}

metadata = []
split_info = []

# Generate 4 dates for time series
base_date = datetime(2021, 1, 1)
dates = [base_date + timedelta(days=i * 30) for i in range(4)]

for split, tile_list in tiles.items():
for tile in tile_list:
            tile_dir = os.path.join(root, tile)
            if os.path.exists(tile_dir):
                shutil.rmtree(tile_dir)
l1c_dir = os.path.join(tile_dir, 'L1C')
l2a_dir = os.path.join(tile_dir, 'L2A')
os.makedirs(l1c_dir, exist_ok=True)
os.makedirs(l2a_dir, exist_ok=True)

# High-res images (single timestep)
hr_ps = np.random.randint(0, 255, (4, img_size, img_size), dtype=np.uint16)
with rasterio.open(
os.path.join(tile_dir, f'{tile}_ps.tiff'),
'w',
driver='GTiff',
height=img_size,
width=img_size,
count=4,
dtype=np.uint16,
transform=rasterio.Affine(1.0, 0, 0, 0, 1.0, 0),
crs=rasterio.crs.CRS.from_epsg(4326),
) as dst:
dst.write(hr_ps)

hr_pan = np.random.randint(0, 255, (1, img_size, img_size), dtype=np.uint16)
with rasterio.open(
os.path.join(tile_dir, f'{tile}_pan.tiff'),
'w',
driver='GTiff',
height=img_size,
width=img_size,
count=1,
dtype=np.uint16,
transform=rasterio.Affine(1.0, 0, 0, 0, 1.0, 0),
crs=rasterio.crs.CRS.from_epsg(4326),
) as dst:
dst.write(hr_pan)

# High-res RGBN (4 channels)
hr_rgbn_png = np.random.randint(
0, 255, (img_size, img_size, 4), dtype=np.uint8
)
rgbn_img = Image.fromarray(hr_rgbn_png, mode='RGBA')
            rgbn_img.save(os.path.join(tile_dir, f'{tile}_rgb.png'))

            # Low-res RGBN (dimensions must match the rasterio profile below)
            lr_rgbn = np.random.randint(
                0, 255, (4, img_size // 8, img_size // 8), dtype=np.uint16
            )
with rasterio.open(
os.path.join(tile_dir, f'{tile}_rgbn.tiff'),
'w',
driver='GTiff',
height=img_size // 8,
width=img_size // 8,
count=4,
dtype=np.uint16,
transform=rasterio.Affine(1.0, 0, 0, 0, 1.0, 0),
crs=rasterio.crs.CRS.from_epsg(4326),
) as dst:
dst.write(lr_rgbn)

# Time series data
for date in dates:
date_str = date.strftime('%Y%m%d')

# L1C (13 bands)
l1c = np.random.randint(
0, 255, (13, img_size // 8, img_size // 8), dtype=np.uint16
)
with rasterio.open(
os.path.join(l1c_dir, f'{tile}_{date_str}_L1C_data.tiff'),
'w',
driver='GTiff',
height=img_size // 8,
width=img_size // 8,
count=13,
dtype=np.uint16,
transform=rasterio.Affine(
9.553533791820828e-05,
0.0,
92.38122406971227,
0.0,
-9.096299266268611e-05,
20.83381094772868,
),
crs=rasterio.crs.CRS.from_epsg(4326),
) as dst:
dst.write(l1c)

# L2A (12 bands)
l2a = np.random.randint(
0, 255, (12, img_size // 8, img_size // 8), dtype=np.uint16
)
with rasterio.open(
os.path.join(l2a_dir, f'{tile}_{date_str}_L2A_data.tiff'),
'w',
driver='GTiff',
height=img_size // 8,
width=img_size // 8,
count=12,
dtype=np.uint16,
transform=rasterio.Affine(
9.553533791820828e-05,
0.0,
92.38122406971227,
0.0,
-9.096299266268611e-05,
20.83381094772868,
),
crs=rasterio.crs.CRS.from_epsg(4326),
) as dst:
dst.write(l2a)

# Metadata with date
for date in dates:
metadata.append(
{
'tile_id': tile,
'lon': np.random.uniform(-180, 180),
'lat': np.random.uniform(-90, 90),
'lowres_date': date.strftime('%Y-%m-%d'),
'highres_date': date.strftime('%Y-%m-%d'),
}
)

split_info.append({'tile': tile, 'split': split})

pd.DataFrame(metadata).to_csv(os.path.join(root, 'metadata.csv'), index=False)
pd.DataFrame(split_info).to_csv(
os.path.join(root, 'stratified_train_val_test_split.csv'), index=False
)


def create_archives(root: str) -> None:
"""Create compressed archives and compute checksums."""
# Create archive structure
archives = {
'hr_dataset.tar.gz': ['_ps.tiff', '_pan.tiff', '_rgbn.tiff', '_rgb.png'],
'lr_dataset_l1c.tar.gz': ['L1C'],
'lr_dataset_l2a.tar.gz': ['L2A'],
}

checksums = {}

# Create each archive
for archive_name, patterns in archives.items():
archive_path = os.path.join(root, archive_name)
with tarfile.open(archive_path, 'w:gz') as tar:
for aoi in os.listdir(root):
aoi_path = os.path.join(root, aoi)
if not os.path.isdir(aoi_path) or aoi.startswith('.'):
continue

# Add files matching patterns
for pattern in patterns:
if pattern.startswith('_'): # High-res files
src = os.path.join(aoi_path, f'{aoi}{pattern}')
if os.path.exists(src):
tar.add(src, os.path.join(aoi, os.path.basename(src)))
else: # L1C/L2A directories
src_dir = os.path.join(aoi_path, pattern)
if os.path.exists(src_dir):
for f in os.listdir(src_dir):
src = os.path.join(src_dir, f)
tar.add(src, os.path.join(aoi, pattern, f))

checksums[archive_name] = compute_md5(archive_path)

# Add CSV files
for csv_file in ['metadata.csv', 'stratified_train_val_test_split.csv']:
checksums[csv_file] = compute_md5(os.path.join(root, csv_file))

# Print checksums in format matching file_info_dict
print('\nfile_info_dict entries:')
for filename, checksum in checksums.items():
name = filename.replace('.tar.gz', '').replace('.csv', '')
        print(f"'{name}': {{")
        print(f"    'filename': '{filename}',")
        print(f"    'md5': '{checksum}',")
        print('},')


def compute_md5(filepath: str) -> str:
"""Compute MD5 checksum of a file."""
md5_hash = hashlib.md5()
with open(filepath, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b''):
md5_hash.update(chunk)
return md5_hash.hexdigest()
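The archive layout and chunked-checksum flow used by this script can be exercised with a self-contained round-trip. File names below are dummies; only `tarfile` and `hashlib` behavior is demonstrated (note that gzip embeds a timestamp, so the archive digest itself varies between runs):

```python
import hashlib
import os
import tarfile
import tempfile


def md5_of(path: str) -> str:
    """Chunked MD5, the same approach as compute_md5 in the script."""
    h = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b''):
            h.update(chunk)
    return h.hexdigest()


with tempfile.TemporaryDirectory() as tmp:
    member = os.path.join(tmp, 'AOI001_ps.tiff')
    with open(member, 'wb') as f:
        f.write(b'dummy raster bytes')
    archive = os.path.join(tmp, 'hr_dataset.tar.gz')
    # arcname controls the path stored inside the archive, mirroring
    # tar.add(src, os.path.join(aoi, ...)) in create_archives.
    with tarfile.open(archive, 'w:gz') as tar:
        tar.add(member, arcname='AOI001/AOI001_ps.tiff')
    with tarfile.open(archive, 'r:gz') as tar:
        names = tar.getnames()
    print(names)                 # ['AOI001/AOI001_ps.tiff']
    print(len(md5_of(archive)))  # 32
```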


if __name__ == '__main__':
root_dir = '.'
create_dummy_worldstrat(root_dir)
create_archives(root_dir)
Binary file added tests/data/worldstrat/hr_dataset.tar.gz
Binary file added tests/data/worldstrat/lr_dataset_l1c.tar.gz
Binary file added tests/data/worldstrat/lr_dataset_l2a.tar.gz
17 changes: 17 additions & 0 deletions tests/data/worldstrat/metadata.csv
@@ -0,0 +1,17 @@
tile_id,lon,lat,lowres_date,highres_date
AOI001,-72.49323312470923,31.416848064048338,2021-01-01,2021-01-01
AOI001,-126.3528881568365,13.34418943549825,2021-01-31,2021-01-31
AOI001,-11.483139921316052,-41.97402005032554,2021-03-02,2021-03-02
AOI001,114.50774674932461,25.58214154752858,2021-04-01,2021-04-01
AOI002,93.77701948238666,-6.858958480524052,2021-01-01,2021-01-01
AOI002,128.6311156852766,-10.428155119947036,2021-01-31,2021-01-31
AOI002,-102.74252094868638,-30.123240390715097,2021-03-02,2021-03-02
AOI002,46.61807533873554,-64.22765197380306,2021-04-01,2021-04-01
AOI003,-11.268454561094615,7.782902072540182,2021-01-01,2021-01-01
AOI003,-87.06842058693209,-60.12054227975885,2021-01-31,2021-01-31
AOI003,138.66972658299403,63.383086985924876,2021-03-02,2021-03-02
AOI003,97.3938715896681,-69.44847561317725,2021-04-01,2021-04-01
AOI004,-151.62046114896762,-7.936329304506657,2021-01-01,2021-01-01
AOI004,118.6429156021232,14.345764808459649,2021-01-31,2021-01-31
AOI004,-179.31698488612918,-14.237879388686977,2021-03-02,2021-03-02
AOI004,70.31208033800476,-57.04972551995917,2021-04-01,2021-04-01
5 changes: 5 additions & 0 deletions tests/data/worldstrat/stratified_train_val_test_split.csv
@@ -0,0 +1,5 @@
tile,split
AOI001,train
AOI002,train
AOI003,val
AOI004,test
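A dataset class would typically turn this CSV into a split-to-tiles mapping before filtering samples. A minimal stdlib sketch (the `splits` structure is illustrative, not TorchGeo's actual implementation; the `tile` and `split` column names come from the file above):

```python
import csv
import io

# Inlined copy of stratified_train_val_test_split.csv from this PR.
split_csv = """tile,split
AOI001,train
AOI002,train
AOI003,val
AOI004,test
"""

# Group tile IDs by their assigned split.
splits: dict[str, list[str]] = {}
for row in csv.DictReader(io.StringIO(split_csv)):
    splits.setdefault(row['split'], []).append(row['tile'])

print(splits['train'])  # ['AOI001', 'AOI002']
print(splits['test'])   # ['AOI004']
```

This matches the fixture's expected lengths in `test_len` below: two train tiles, one each for val and test.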
112 changes: 112 additions & 0 deletions tests/datasets/test_worldstrat.py
@@ -0,0 +1,112 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import DatasetNotFoundError, WorldStrat


class TestWorldStrat:
@pytest.fixture(params=['train', 'val', 'test'])
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> WorldStrat:
url = os.path.join('tests', 'data', 'worldstrat')

file_info_dict = {
'hr_dataset': {
'url': os.path.join(url, 'hr_dataset.tar.gz'),
'filename': 'hr_dataset.tar.gz',
'md5': 'e395f3357c6d97e5fee1baaffcaa31bd',
},
'lr_dataset_l1c': {
'url': os.path.join(url, 'lr_dataset_l1c.tar.gz'),
'filename': 'lr_dataset_l1c.tar.gz',
'md5': '24db4553ea14b8c8253c13c297d6c862',
},
'lr_dataset_l2a': {
'url': os.path.join(url, 'lr_dataset_l2a.tar.gz'),
'filename': 'lr_dataset_l2a.tar.gz',
'md5': 'a4237eb6fb6a96ef3f52a4e9bf6ee754',
},
'metadata': {
'url': os.path.join(url, 'metadata.csv'),
'filename': 'metadata.csv',
'md5': '6d2ced33b6dc2c25a5c067d34d2c1738',
},
'train_val_test_split': {
'url': os.path.join(url, 'stratified_train_val_test_split.csv'),
'filename': 'stratified_train_val_test_split.csv',
'md5': 'c6941d2c0f044d716ea5f0ab4277cba6',
},
}
monkeypatch.setattr(WorldStrat, 'file_info_dict', file_info_dict)
root = tmp_path
split = request.param
transforms = nn.Identity()
return WorldStrat(
root, split=split, transforms=transforms, download=True, checksum=True
)

def test_getitem(self, dataset: WorldStrat) -> None:
x = dataset[0]
assert isinstance(x, dict)
for modality in dataset.modalities:
assert isinstance(x[f'image_{modality}'], torch.Tensor)

def test_len(self, dataset: WorldStrat) -> None:
if dataset.split == 'train':
assert len(dataset) == 2
else:
assert len(dataset) == 1

def test_already_downloaded(self, dataset: WorldStrat) -> None:
WorldStrat(root=dataset.root)

def test_not_yet_extracted(self, tmp_path: Path) -> None:
file_list = [
'hr_dataset.tar.gz',
'lr_dataset_l1c.tar.gz',
'lr_dataset_l2a.tar.gz',
'metadata.csv',
'stratified_train_val_test_split.csv',
]
        data_dir = os.path.join('tests', 'data', 'worldstrat')
        for filename in file_list:
            shutil.copyfile(
                os.path.join(data_dir, filename), os.path.join(str(tmp_path), filename)
            )
WorldStrat(root=str(tmp_path))

def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
WorldStrat(split='foo')

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
WorldStrat(tmp_path)

def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, 'hr_dataset.tar.gz'), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Archive'):
WorldStrat(root=tmp_path, checksum=True)

def test_plot(self, dataset: WorldStrat) -> None:
dataset.plot(dataset[0], suptitle='Test')
plt.close()

def test_pred_plot(self, dataset: WorldStrat) -> None:
x = dataset[0]
x['prediction'] = x['image_hr_rgbn']
dataset.plot(x, suptitle='Test')
plt.close()