diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8147863f095..e261a1b6f6b 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -200,7 +200,7 @@ Non-geospatial Datasets :class:`NonGeoDataset` is designed for datasets that lack geospatial information. These datasets can still be combined using :class:`ConcatDataset `. -.. csv-table:: C = classification, R = regression, S = semantic segmentation, I = instance segmentation, T = time series, CD = change detection, OD = object detection, IC = image captioning +.. csv-table:: C = classification, R = regression, S = semantic segmentation, I = instance segmentation, T = time series, CD = change detection, OD = object detection, IC = image captioning, SR = super resolution :widths: 15 7 15 20 12 11 12 15 13 :header-rows: 1 :align: center @@ -528,6 +528,10 @@ Western USA Live Fuel Moisture .. autoclass:: WesternUSALiveFuelMoisture +WorldStrat +^^^^^^^^^^^ +.. autoclass:: WorldStrat + xView2 ^^^^^^ diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv index 1defcb032bd..35bdf26060c 100644 --- a/docs/api/datasets/non_geo_datasets.csv +++ b/docs/api/datasets/non_geo_datasets.csv @@ -62,5 +62,6 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `Vaihingen`_,S,Aerial,-,33,6,"1,281--3,816",0.09,RGB `VHR-10`_,I,"Google Earth, Vaihingen","MIT",800,10,"358--1,728",0.08--2,RGB `Western USA Live Fuel Moisture`_,R,"Landsat8, Sentinel-1","CC-BY-NC-ND-4.0",2615,-,-,-,- +`WorldStrat`_,SR,"Sentinel-2,SPOT 6/7","CC-BY-NC-4.0 and CC-BY-4.0",,-,"1.5m and 10m","RGB,MSI" `xView2`_,CD,Maxar,"CC-BY-NC-SA-4.0","3,732",4,"1,024x1,024",0.8,RGB `ZueriCrop`_,"I, T",Sentinel-2,CC-BY-NC-4.0,116K,48,24x24,10,MSI diff --git a/tests/data/worldstrat/AOI001/AOI001_pan.tiff b/tests/data/worldstrat/AOI001/AOI001_pan.tiff new file mode 100644 index 00000000000..ccdf0909f22 Binary files /dev/null and b/tests/data/worldstrat/AOI001/AOI001_pan.tiff differ diff --git a/tests/data/worldstrat/AOI001/AOI001_ps.tiff b/tests/data/worldstrat/AOI001/AOI001_ps.tiff new file mode 100644 index 00000000000..fb463956a82 Binary files /dev/null and b/tests/data/worldstrat/AOI001/AOI001_ps.tiff differ diff --git a/tests/data/worldstrat/AOI001/AOI001_rgb.png b/tests/data/worldstrat/AOI001/AOI001_rgb.png new file mode 100644 index 00000000000..f8cdd12be6d Binary files /dev/null and b/tests/data/worldstrat/AOI001/AOI001_rgb.png differ diff --git a/tests/data/worldstrat/AOI001/AOI001_rgbn.tiff b/tests/data/worldstrat/AOI001/AOI001_rgbn.tiff new file mode 100644 index 00000000000..e19a3b13fc5 Binary files /dev/null and b/tests/data/worldstrat/AOI001/AOI001_rgbn.tiff differ diff --git a/tests/data/worldstrat/AOI001/L1C/AOI001_20210101_L1C_data.tiff b/tests/data/worldstrat/AOI001/L1C/AOI001_20210101_L1C_data.tiff new file mode 100644 index 00000000000..ece14a2dd38 Binary files /dev/null and b/tests/data/worldstrat/AOI001/L1C/AOI001_20210101_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI001/L1C/AOI001_20210131_L1C_data.tiff b/tests/data/worldstrat/AOI001/L1C/AOI001_20210131_L1C_data.tiff new file mode 100644 index 00000000000..f7b11b9314d Binary files /dev/null and b/tests/data/worldstrat/AOI001/L1C/AOI001_20210131_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI001/L1C/AOI001_20210302_L1C_data.tiff b/tests/data/worldstrat/AOI001/L1C/AOI001_20210302_L1C_data.tiff new file mode 100644 index 00000000000..f6c2e8f53c0 Binary files /dev/null and b/tests/data/worldstrat/AOI001/L1C/AOI001_20210302_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI001/L1C/AOI001_20210401_L1C_data.tiff b/tests/data/worldstrat/AOI001/L1C/AOI001_20210401_L1C_data.tiff new file mode 100644 index 00000000000..af7b51bb8c9 Binary files /dev/null and b/tests/data/worldstrat/AOI001/L1C/AOI001_20210401_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI001/L2A/AOI001_20210101_L2A_data.tiff b/tests/data/worldstrat/AOI001/L2A/AOI001_20210101_L2A_data.tiff new file mode 100644 index 00000000000..1a62b873f95 Binary files /dev/null and b/tests/data/worldstrat/AOI001/L2A/AOI001_20210101_L2A_data.tiff differ diff --git a/tests/data/worldstrat/AOI001/L2A/AOI001_20210131_L2A_data.tiff b/tests/data/worldstrat/AOI001/L2A/AOI001_20210131_L2A_data.tiff new file mode 100644 index 00000000000..b14dc966b47 Binary files /dev/null and b/tests/data/worldstrat/AOI001/L2A/AOI001_20210131_L2A_data.tiff differ diff --git a/tests/data/worldstrat/AOI001/L2A/AOI001_20210302_L2A_data.tiff b/tests/data/worldstrat/AOI001/L2A/AOI001_20210302_L2A_data.tiff new file mode 100644 index 00000000000..a67b086597f Binary files /dev/null and b/tests/data/worldstrat/AOI001/L2A/AOI001_20210302_L2A_data.tiff differ diff --git a/tests/data/worldstrat/AOI001/L2A/AOI001_20210401_L2A_data.tiff b/tests/data/worldstrat/AOI001/L2A/AOI001_20210401_L2A_data.tiff new file mode 100644 index 00000000000..b74116e1c85 Binary files /dev/null and b/tests/data/worldstrat/AOI001/L2A/AOI001_20210401_L2A_data.tiff differ diff --git a/tests/data/worldstrat/AOI002/AOI002_pan.tiff b/tests/data/worldstrat/AOI002/AOI002_pan.tiff new file mode 100644 index 00000000000..48db74db1ee Binary files /dev/null and b/tests/data/worldstrat/AOI002/AOI002_pan.tiff differ diff --git a/tests/data/worldstrat/AOI002/AOI002_ps.tiff b/tests/data/worldstrat/AOI002/AOI002_ps.tiff new file mode 100644 index 00000000000..cbca4a82798 Binary files /dev/null and b/tests/data/worldstrat/AOI002/AOI002_ps.tiff differ diff --git a/tests/data/worldstrat/AOI002/AOI002_rgb.png b/tests/data/worldstrat/AOI002/AOI002_rgb.png new file mode 100644 index 00000000000..2c0434f1a72 Binary files /dev/null and b/tests/data/worldstrat/AOI002/AOI002_rgb.png differ diff --git a/tests/data/worldstrat/AOI002/AOI002_rgbn.tiff b/tests/data/worldstrat/AOI002/AOI002_rgbn.tiff new file mode 100644 index 00000000000..a4507beea1c Binary files /dev/null and b/tests/data/worldstrat/AOI002/AOI002_rgbn.tiff differ diff --git a/tests/data/worldstrat/AOI002/L1C/AOI002_20210101_L1C_data.tiff b/tests/data/worldstrat/AOI002/L1C/AOI002_20210101_L1C_data.tiff new file mode 100644 index 00000000000..46e24bd0dfe Binary files /dev/null and b/tests/data/worldstrat/AOI002/L1C/AOI002_20210101_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI002/L1C/AOI002_20210131_L1C_data.tiff b/tests/data/worldstrat/AOI002/L1C/AOI002_20210131_L1C_data.tiff new file mode 100644 index 00000000000..60303d4ca9b Binary files /dev/null and b/tests/data/worldstrat/AOI002/L1C/AOI002_20210131_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI002/L1C/AOI002_20210302_L1C_data.tiff b/tests/data/worldstrat/AOI002/L1C/AOI002_20210302_L1C_data.tiff new file mode 100644 index 00000000000..345905a2062 Binary files /dev/null and b/tests/data/worldstrat/AOI002/L1C/AOI002_20210302_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI002/L1C/AOI002_20210401_L1C_data.tiff b/tests/data/worldstrat/AOI002/L1C/AOI002_20210401_L1C_data.tiff new file mode 100644 index 00000000000..20708544fb3 Binary files /dev/null and b/tests/data/worldstrat/AOI002/L1C/AOI002_20210401_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI002/L2A/AOI002_20210101_L2A_data.tiff b/tests/data/worldstrat/AOI002/L2A/AOI002_20210101_L2A_data.tiff new file mode 100644 index 00000000000..39cc35ebdd1 Binary files /dev/null and b/tests/data/worldstrat/AOI002/L2A/AOI002_20210101_L2A_data.tiff differ diff --git a/tests/data/worldstrat/AOI002/L2A/AOI002_20210131_L2A_data.tiff b/tests/data/worldstrat/AOI002/L2A/AOI002_20210131_L2A_data.tiff new file mode 100644 index 00000000000..115b076db92 Binary files /dev/null and b/tests/data/worldstrat/AOI002/L2A/AOI002_20210131_L2A_data.tiff differ diff --git a/tests/data/worldstrat/AOI002/L2A/AOI002_20210302_L2A_data.tiff b/tests/data/worldstrat/AOI002/L2A/AOI002_20210302_L2A_data.tiff new file mode 100644 index 00000000000..9f0e9cc92c6 Binary files /dev/null and b/tests/data/worldstrat/AOI002/L2A/AOI002_20210302_L2A_data.tiff differ diff --git a/tests/data/worldstrat/AOI002/L2A/AOI002_20210401_L2A_data.tiff b/tests/data/worldstrat/AOI002/L2A/AOI002_20210401_L2A_data.tiff new file mode 100644 index 00000000000..f1d25955350 Binary files /dev/null and b/tests/data/worldstrat/AOI002/L2A/AOI002_20210401_L2A_data.tiff differ diff --git a/tests/data/worldstrat/AOI003/AOI003_pan.tiff b/tests/data/worldstrat/AOI003/AOI003_pan.tiff new file mode 100644 index 00000000000..0a679849301 Binary files /dev/null and b/tests/data/worldstrat/AOI003/AOI003_pan.tiff differ diff --git a/tests/data/worldstrat/AOI003/AOI003_ps.tiff b/tests/data/worldstrat/AOI003/AOI003_ps.tiff new file mode 100644 index 00000000000..60b909611a3 Binary files /dev/null and b/tests/data/worldstrat/AOI003/AOI003_ps.tiff differ diff --git a/tests/data/worldstrat/AOI003/AOI003_rgb.png b/tests/data/worldstrat/AOI003/AOI003_rgb.png new file mode 100644 index 00000000000..9cbf7883cb3 Binary files /dev/null and b/tests/data/worldstrat/AOI003/AOI003_rgb.png differ diff --git a/tests/data/worldstrat/AOI003/AOI003_rgbn.tiff b/tests/data/worldstrat/AOI003/AOI003_rgbn.tiff new file mode 100644 index 00000000000..866b0669f2a Binary files /dev/null and b/tests/data/worldstrat/AOI003/AOI003_rgbn.tiff differ diff --git a/tests/data/worldstrat/AOI003/L1C/AOI003_20210101_L1C_data.tiff b/tests/data/worldstrat/AOI003/L1C/AOI003_20210101_L1C_data.tiff new file mode 100644 index 00000000000..c178a5c566d Binary files /dev/null and b/tests/data/worldstrat/AOI003/L1C/AOI003_20210101_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI003/L1C/AOI003_20210131_L1C_data.tiff b/tests/data/worldstrat/AOI003/L1C/AOI003_20210131_L1C_data.tiff new file mode 100644 index 00000000000..d85d8eefa9c Binary files /dev/null and b/tests/data/worldstrat/AOI003/L1C/AOI003_20210131_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI003/L1C/AOI003_20210302_L1C_data.tiff b/tests/data/worldstrat/AOI003/L1C/AOI003_20210302_L1C_data.tiff new file mode 100644 index 00000000000..9386eda9823 Binary files /dev/null and b/tests/data/worldstrat/AOI003/L1C/AOI003_20210302_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI003/L1C/AOI003_20210401_L1C_data.tiff b/tests/data/worldstrat/AOI003/L1C/AOI003_20210401_L1C_data.tiff new file mode 100644 index 00000000000..0189419a101 Binary files /dev/null and b/tests/data/worldstrat/AOI003/L1C/AOI003_20210401_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI003/L2A/AOI003_20210101_L2A_data.tiff b/tests/data/worldstrat/AOI003/L2A/AOI003_20210101_L2A_data.tiff new file mode 100644 index 00000000000..3591610c2f0 Binary files /dev/null and b/tests/data/worldstrat/AOI003/L2A/AOI003_20210101_L2A_data.tiff differ diff --git a/tests/data/worldstrat/AOI003/L2A/AOI003_20210131_L2A_data.tiff b/tests/data/worldstrat/AOI003/L2A/AOI003_20210131_L2A_data.tiff new file mode 100644 index 00000000000..cacce382a55 Binary files /dev/null and b/tests/data/worldstrat/AOI003/L2A/AOI003_20210131_L2A_data.tiff differ diff --git a/tests/data/worldstrat/AOI003/L2A/AOI003_20210302_L2A_data.tiff b/tests/data/worldstrat/AOI003/L2A/AOI003_20210302_L2A_data.tiff new file mode 100644 index 00000000000..fe5cc2eb455 Binary files /dev/null and b/tests/data/worldstrat/AOI003/L2A/AOI003_20210302_L2A_data.tiff differ diff --git a/tests/data/worldstrat/AOI003/L2A/AOI003_20210401_L2A_data.tiff b/tests/data/worldstrat/AOI003/L2A/AOI003_20210401_L2A_data.tiff new file mode 100644 index 00000000000..0baf06cb3a3 Binary files /dev/null and b/tests/data/worldstrat/AOI003/L2A/AOI003_20210401_L2A_data.tiff differ diff --git a/tests/data/worldstrat/AOI004/AOI004_pan.tiff b/tests/data/worldstrat/AOI004/AOI004_pan.tiff new file mode 100644 index 00000000000..b9574c1bf45 Binary files /dev/null and b/tests/data/worldstrat/AOI004/AOI004_pan.tiff differ diff --git a/tests/data/worldstrat/AOI004/AOI004_ps.tiff b/tests/data/worldstrat/AOI004/AOI004_ps.tiff new file mode 100644 index 00000000000..c0ae85bb0e4 Binary files /dev/null and b/tests/data/worldstrat/AOI004/AOI004_ps.tiff differ diff --git a/tests/data/worldstrat/AOI004/AOI004_rgb.png b/tests/data/worldstrat/AOI004/AOI004_rgb.png new file mode 100644 index 00000000000..3fe2ba989a3 Binary files /dev/null and b/tests/data/worldstrat/AOI004/AOI004_rgb.png differ diff --git a/tests/data/worldstrat/AOI004/AOI004_rgbn.tiff b/tests/data/worldstrat/AOI004/AOI004_rgbn.tiff new file mode 100644 index 00000000000..087a7678b6d Binary files /dev/null and b/tests/data/worldstrat/AOI004/AOI004_rgbn.tiff differ diff --git a/tests/data/worldstrat/AOI004/L1C/AOI004_20210101_L1C_data.tiff b/tests/data/worldstrat/AOI004/L1C/AOI004_20210101_L1C_data.tiff new file mode 100644 index 00000000000..2d3f22f2514 Binary files /dev/null and b/tests/data/worldstrat/AOI004/L1C/AOI004_20210101_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI004/L1C/AOI004_20210131_L1C_data.tiff b/tests/data/worldstrat/AOI004/L1C/AOI004_20210131_L1C_data.tiff new file mode 100644 index 00000000000..69aeada28d8 Binary files /dev/null and b/tests/data/worldstrat/AOI004/L1C/AOI004_20210131_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI004/L1C/AOI004_20210302_L1C_data.tiff b/tests/data/worldstrat/AOI004/L1C/AOI004_20210302_L1C_data.tiff new file mode 100644 index 00000000000..e1024cb539f Binary files /dev/null and b/tests/data/worldstrat/AOI004/L1C/AOI004_20210302_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI004/L1C/AOI004_20210401_L1C_data.tiff b/tests/data/worldstrat/AOI004/L1C/AOI004_20210401_L1C_data.tiff new file mode 100644 index 00000000000..c92a9a34c8a Binary files /dev/null and b/tests/data/worldstrat/AOI004/L1C/AOI004_20210401_L1C_data.tiff differ diff --git a/tests/data/worldstrat/AOI004/L2A/AOI004_20210101_L2A_data.tiff b/tests/data/worldstrat/AOI004/L2A/AOI004_20210101_L2A_data.tiff new file mode 100644 index 00000000000..0a37ad569d2 Binary files /dev/null and b/tests/data/worldstrat/AOI004/L2A/AOI004_20210101_L2A_data.tiff differ diff --git a/tests/data/worldstrat/AOI004/L2A/AOI004_20210131_L2A_data.tiff b/tests/data/worldstrat/AOI004/L2A/AOI004_20210131_L2A_data.tiff new file mode 100644 index 00000000000..23d6bab1746 Binary files /dev/null and b/tests/data/worldstrat/AOI004/L2A/AOI004_20210131_L2A_data.tiff differ diff --git a/tests/data/worldstrat/AOI004/L2A/AOI004_20210302_L2A_data.tiff b/tests/data/worldstrat/AOI004/L2A/AOI004_20210302_L2A_data.tiff new file mode 100644 index 00000000000..b5602dc30f0 Binary files /dev/null and b/tests/data/worldstrat/AOI004/L2A/AOI004_20210302_L2A_data.tiff differ diff --git a/tests/data/worldstrat/AOI004/L2A/AOI004_20210401_L2A_data.tiff b/tests/data/worldstrat/AOI004/L2A/AOI004_20210401_L2A_data.tiff new file mode 100644 index 00000000000..82cd8de7228 Binary files /dev/null and b/tests/data/worldstrat/AOI004/L2A/AOI004_20210401_L2A_data.tiff differ diff --git a/tests/data/worldstrat/data.py b/tests/data/worldstrat/data.py new file mode 100644 index 00000000000..4d5a208e342 --- /dev/null +++ b/tests/data/worldstrat/data.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import shutil +import tarfile +from datetime import datetime, timedelta + +import numpy as np +import pandas as pd +import rasterio +from PIL import Image + + +def create_dummy_worldstrat(root: str, img_size: int = 64) -> None: + """Create dummy WorldStrat dataset.""" + os.makedirs(root, exist_ok=True) + + tiles = {'train': ['AOI001', 'AOI002'], 'val': ['AOI003'], 'test': ['AOI004']} + + metadata = [] + split_info = [] + + # Generate 4 dates for time series + base_date = datetime(2021, 1, 1) + dates = [base_date + timedelta(days=i * 30) for i in range(4)] + + for split, tile_list in tiles.items(): + for tile in tile_list: + if os.path.exists(os.path.join(root, tile)): + shutil.rmtree(os.path.join(root, tile)) + tile_dir = os.path.join(root, tile) + l1c_dir = os.path.join(tile_dir, 'L1C') + l2a_dir = os.path.join(tile_dir, 'L2A') + os.makedirs(l1c_dir, exist_ok=True) + os.makedirs(l2a_dir, exist_ok=True) + + # High-res images (single timestep) + hr_ps = np.random.randint(0, 255, (4, img_size, img_size), dtype=np.uint16) + with rasterio.open( + os.path.join(tile_dir, f'{tile}_ps.tiff'), + 'w', + driver='GTiff', + height=img_size, + width=img_size, + count=4, + dtype=np.uint16, + transform=rasterio.Affine(1.0, 0, 0, 0, 1.0, 0), + crs=rasterio.crs.CRS.from_epsg(4326), + ) as dst: + dst.write(hr_ps) + + hr_pan = np.random.randint(0, 255, (1, img_size, img_size), dtype=np.uint16) + with rasterio.open( + os.path.join(tile_dir, f'{tile}_pan.tiff'), + 'w', + driver='GTiff', + height=img_size, + width=img_size, + count=1, + dtype=np.uint16, + transform=rasterio.Affine(1.0, 0, 0, 0, 1.0, 0), + crs=rasterio.crs.CRS.from_epsg(4326), + ) as dst: + dst.write(hr_pan) + + # High-res RGBN (4 channels) + hr_rgbn_png = np.random.randint( + 0, 255, (img_size, img_size, 4), dtype=np.uint8 + ) + rgbn_img = Image.fromarray(hr_rgbn_png, mode='RGBA') + rgbn_img.save(os.path.join(tile, f'{tile}_rgb.png')) + + # Low-res RGBN + lr_rgbn = np.random.randint( + 0, 255, (4, img_size // 4, img_size // 4), dtype=np.uint16 + ) + with rasterio.open( + os.path.join(tile_dir, f'{tile}_rgbn.tiff'), + 'w', + driver='GTiff', + height=img_size // 8, + width=img_size // 8, + count=4, + dtype=np.uint16, + transform=rasterio.Affine(1.0, 0, 0, 0, 1.0, 0), + crs=rasterio.crs.CRS.from_epsg(4326), + ) as dst: + dst.write(lr_rgbn) + + # Time series data + for date in dates: + date_str = date.strftime('%Y%m%d') + + # L1C (13 bands) + l1c = np.random.randint( + 0, 255, (13, img_size // 8, img_size // 8), dtype=np.uint16 + ) + with rasterio.open( + os.path.join(l1c_dir, f'{tile}_{date_str}_L1C_data.tiff'), + 'w', + driver='GTiff', + height=img_size // 8, + width=img_size // 8, + count=13, + dtype=np.uint16, + transform=rasterio.Affine( + 9.553533791820828e-05, + 0.0, + 92.38122406971227, + 0.0, + -9.096299266268611e-05, + 20.83381094772868, + ), + crs=rasterio.crs.CRS.from_epsg(4326), + ) as dst: + dst.write(l1c) + + # L2A (12 bands) + l2a = np.random.randint( + 0, 255, (12, img_size // 8, img_size // 8), dtype=np.uint16 + ) + with rasterio.open( + os.path.join(l2a_dir, f'{tile}_{date_str}_L2A_data.tiff'), + 'w', + driver='GTiff', + height=img_size // 8, + width=img_size // 8, + count=12, + dtype=np.uint16, + transform=rasterio.Affine( + 9.553533791820828e-05, + 0.0, + 92.38122406971227, + 0.0, + -9.096299266268611e-05, + 20.83381094772868, + ), + crs=rasterio.crs.CRS.from_epsg(4326), + ) as dst: + dst.write(l2a) + + # Metadata with date + for date in dates: + metadata.append( + { + 'tile_id': tile, + 'lon': np.random.uniform(-180, 180), + 'lat': np.random.uniform(-90, 90), + 'lowres_date': date.strftime('%Y-%m-%d'), + 'highres_date': date.strftime('%Y-%m-%d'), + } + ) + + split_info.append({'tile': tile, 'split': split}) + + pd.DataFrame(metadata).to_csv(os.path.join(root, 'metadata.csv'), index=False) + pd.DataFrame(split_info).to_csv( + os.path.join(root, 'stratified_train_val_test_split.csv'), index=False + ) + + +def create_archives(root: str) -> None: + """Create compressed archives and compute checksums.""" + # Create archive structure + archives = { + 'hr_dataset.tar.gz': ['_ps.tiff', '_pan.tiff', '_rgbn.tiff', '_rgb.png'], + 'lr_dataset_l1c.tar.gz': ['L1C'], + 'lr_dataset_l2a.tar.gz': ['L2A'], + } + + checksums = {} + + # Create each archive + for archive_name, patterns in archives.items(): + archive_path = os.path.join(root, archive_name) + with tarfile.open(archive_path, 'w:gz') as tar: + for aoi in os.listdir(root): + aoi_path = os.path.join(root, aoi) + if not os.path.isdir(aoi_path) or aoi.startswith('.'): + continue + + # Add files matching patterns + for pattern in patterns: + if pattern.startswith('_'): # High-res files + src = os.path.join(aoi_path, f'{aoi}{pattern}') + if os.path.exists(src): + tar.add(src, os.path.join(aoi, os.path.basename(src))) + else: # L1C/L2A directories + src_dir = os.path.join(aoi_path, pattern) + if os.path.exists(src_dir): + for f in os.listdir(src_dir): + src = os.path.join(src_dir, f) + tar.add(src, os.path.join(aoi, pattern, f)) + + checksums[archive_name] = compute_md5(archive_path) + + # Add CSV files + for csv_file in ['metadata.csv', 'stratified_train_val_test_split.csv']: + checksums[csv_file] = compute_md5(os.path.join(root, csv_file)) + + # Print checksums in format matching file_info_dict + print('\nfile_info_dict entries:') + for filename, checksum in checksums.items(): + name = filename.replace('.tar.gz', '').replace('.csv', '') + print(f"'{name}': {{") + print(f" 'filename': '{filename}',") + print(f" 'md5': '{checksum}',") + print('},') + + +def compute_md5(filepath: str) -> str: + """Compute MD5 checksum of a file.""" + md5_hash = hashlib.md5() + with open(filepath, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b''): + md5_hash.update(chunk) + return md5_hash.hexdigest() + + +if __name__ == '__main__': + root_dir = '.' + create_dummy_worldstrat(root_dir) + create_archives(root_dir) diff --git a/tests/data/worldstrat/hr_dataset.tar.gz b/tests/data/worldstrat/hr_dataset.tar.gz new file mode 100644 index 00000000000..7b4d8aecd27 Binary files /dev/null and b/tests/data/worldstrat/hr_dataset.tar.gz differ diff --git a/tests/data/worldstrat/lr_dataset_l1c.tar.gz b/tests/data/worldstrat/lr_dataset_l1c.tar.gz new file mode 100644 index 00000000000..3ed83f57e33 Binary files /dev/null and b/tests/data/worldstrat/lr_dataset_l1c.tar.gz differ diff --git a/tests/data/worldstrat/lr_dataset_l2a.tar.gz b/tests/data/worldstrat/lr_dataset_l2a.tar.gz new file mode 100644 index 00000000000..c9d10e32f69 Binary files /dev/null and b/tests/data/worldstrat/lr_dataset_l2a.tar.gz differ diff --git a/tests/data/worldstrat/metadata.csv b/tests/data/worldstrat/metadata.csv new file mode 100644 index 00000000000..9a736b2786c --- /dev/null +++ b/tests/data/worldstrat/metadata.csv @@ -0,0 +1,17 @@ +tile_id,lon,lat,lowres_date,highres_date +AOI001,-72.49323312470923,31.416848064048338,2021-01-01,2021-01-01 +AOI001,-126.3528881568365,13.34418943549825,2021-01-31,2021-01-31 +AOI001,-11.483139921316052,-41.97402005032554,2021-03-02,2021-03-02 +AOI001,114.50774674932461,25.58214154752858,2021-04-01,2021-04-01 +AOI002,93.77701948238666,-6.858958480524052,2021-01-01,2021-01-01 +AOI002,128.6311156852766,-10.428155119947036,2021-01-31,2021-01-31 +AOI002,-102.74252094868638,-30.123240390715097,2021-03-02,2021-03-02 +AOI002,46.61807533873554,-64.22765197380306,2021-04-01,2021-04-01 +AOI003,-11.268454561094615,7.782902072540182,2021-01-01,2021-01-01 +AOI003,-87.06842058693209,-60.12054227975885,2021-01-31,2021-01-31 +AOI003,138.66972658299403,63.383086985924876,2021-03-02,2021-03-02 +AOI003,97.3938715896681,-69.44847561317725,2021-04-01,2021-04-01 +AOI004,-151.62046114896762,-7.936329304506657,2021-01-01,2021-01-01 +AOI004,118.6429156021232,14.345764808459649,2021-01-31,2021-01-31 +AOI004,-179.31698488612918,-14.237879388686977,2021-03-02,2021-03-02 +AOI004,70.31208033800476,-57.04972551995917,2021-04-01,2021-04-01 diff --git a/tests/data/worldstrat/stratified_train_val_test_split.csv b/tests/data/worldstrat/stratified_train_val_test_split.csv new file mode 100644 index 00000000000..8677704bd78 --- /dev/null +++ b/tests/data/worldstrat/stratified_train_val_test_split.csv @@ -0,0 +1,5 @@ +tile,split +AOI001,train +AOI002,train +AOI003,val +AOI004,test diff --git a/tests/datasets/test_worldstrat.py b/tests/datasets/test_worldstrat.py new file mode 100644 index 00000000000..803d18eb870 --- /dev/null +++ b/tests/datasets/test_worldstrat.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch + +from torchgeo.datasets import DatasetNotFoundError, WorldStrat + + +class TestWorldStrat: + @pytest.fixture(params=['train', 'val', 'test']) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> WorldStrat: + url = os.path.join('tests', 'data', 'worldstrat') + + file_info_dict = { + 'hr_dataset': { + 'url': os.path.join(url, 'hr_dataset.tar.gz'), + 'filename': 'hr_dataset.tar.gz', + 'md5': 'e395f3357c6d97e5fee1baaffcaa31bd', + }, + 'lr_dataset_l1c': { + 'url': os.path.join(url, 'lr_dataset_l1c.tar.gz'), + 'filename': 'lr_dataset_l1c.tar.gz', + 'md5': '24db4553ea14b8c8253c13c297d6c862', + }, + 'lr_dataset_l2a': { + 'url': os.path.join(url, 'lr_dataset_l2a.tar.gz'), + 'filename': 'lr_dataset_l2a.tar.gz', + 'md5': 'a4237eb6fb6a96ef3f52a4e9bf6ee754', + }, + 'metadata': { + 'url': os.path.join(url, 'metadata.csv'), + 'filename': 'metadata.csv', + 'md5': '6d2ced33b6dc2c25a5c067d34d2c1738', + }, + 'train_val_test_split': { + 'url': os.path.join(url, 'stratified_train_val_test_split.csv'), + 'filename': 'stratified_train_val_test_split.csv', + 'md5': 'c6941d2c0f044d716ea5f0ab4277cba6', + }, + } + monkeypatch.setattr(WorldStrat, 'file_info_dict', file_info_dict) + root = tmp_path + split = request.param + transforms = nn.Identity() + return WorldStrat( + root, split=split, transforms=transforms, download=True, checksum=True + ) + + def test_getitem(self, dataset: WorldStrat) -> None: + x = dataset[0] + assert isinstance(x, dict) + for modality in dataset.modalities: + assert isinstance(x[f'image_{modality}'], torch.Tensor) + + def test_len(self, dataset: WorldStrat) -> None: + if dataset.split == 'train': + assert len(dataset) == 2 + else: + assert len(dataset) == 1 + + def test_already_downloaded(self, dataset: WorldStrat) -> None: + WorldStrat(root=dataset.root) + + def test_not_yet_extracted(self, tmp_path: Path) -> None: + file_list = [ + 'hr_dataset.tar.gz', + 'lr_dataset_l1c.tar.gz', + 'lr_dataset_l2a.tar.gz', + 'metadata.csv', + 'stratified_train_val_test_split.csv', + ] + dir = os.path.join('tests', 'data', 'worldstrat') + for filename in file_list: + shutil.copyfile( + os.path.join(dir, filename), os.path.join(str(tmp_path), filename) + ) + WorldStrat(root=str(tmp_path)) + + def test_invalid_split(self) -> None: + with pytest.raises(AssertionError): + WorldStrat(split='foo') + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + WorldStrat(tmp_path) + + def test_corrupted(self, tmp_path: Path) -> None: + with open(os.path.join(tmp_path, 'hr_dataset.tar.gz'), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Archive'): + WorldStrat(root=tmp_path, checksum=True) + + def test_plot(self, dataset: WorldStrat) -> None: + dataset.plot(dataset[0], suptitle='Test') + plt.close() + + def test_pred_plot(self, dataset: WorldStrat) -> None: + x = dataset[0] + x['prediction'] = x['image_hr_rgbn'] + dataset.plot(x, suptitle='Test') + plt.close() diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 1d644c6fc69..b930152eb89 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -149,6 +149,7 @@ from .vaihingen import Vaihingen2D from .vhr10 import VHR10 from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture +from .worldstrat import WorldStrat from .xview import XView2 from .zuericrop import ZueriCrop @@ -293,6 +294,7 @@ 'Vaihingen2D', 'VectorDataset', 'WesternUSALiveFuelMoisture', + 'WorldStrat', 'XView2', 'ZueriCrop', 'concat_samples', diff --git a/torchgeo/datasets/worldstrat.py b/torchgeo/datasets/worldstrat.py new file mode 100644 index 00000000000..7bfb612dd1b --- /dev/null +++ b/torchgeo/datasets/worldstrat.py @@ -0,0 +1,368 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""WorldStrat Dataset.""" + +import os +from collections.abc import Callable, Sequence +from glob import glob +from typing import ClassVar + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import rasterio +import torch +from matplotlib.figure import Figure +from PIL import Image +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import ( + Path, + array_to_tensor, + check_integrity, + download_and_extract_archive, + download_url, + extract_archive, + percentile_normalization, +) + + +class WorldStrat(NonGeoDataset): + """WorldStrat dataset. + + `WorldStrat `_ is a multi-modal dataset covering nearly 10,000km2 of matched high and low resolution + satellite imagery across the globe. High-resolution SPOT 6/7 imagery comes at a resolution of 1.5m/pixel and is matched with a time-series + of Sentinel 2 data. + + Dataset features: + + * High resolution (1.5m/pixel) Airbus SPOT 6/7 imagery with RGBN channels + * Low resolution (8x lower) Sentinel 2 L1C and L2A + * globally distributed areas of interest around the world + + + Dataset format: + + * pixel dimensions vary across AOI tiles + * all modalities are 'tif' files except for 'hr_rgbn' which is 'png' + * 'hr_ps', 'hr_pan', 'hr_rgbn' are high resolution data + * 'lr_rgbn' is low resolution data and roughly 4x lower resolution than 'hr_rgbn' + * 'l1c' and 'l2a' are Sentinel-2 data with 13 and 12 bands respectively and roughly 8x lower resolution than 'hr_rgbn' + + If you use this dataset in your research, please cite the following entries: + + * https://zenodo.org/records/6810792 + * https://arxiv.org/abs/2207.06418 + + .. versionadded:: 0.7 + """ + + modality_titles: ClassVar[dict[str, str]] = { + 'l1c': 'Sentinel-2 L1C', + 'l2a': 'Sentinel-2 L2A', + 'lr_rgbn': 'Low-res RGBN', + 'hr_ps': 'High-res PS', + 'hr_pan': 'High-res PAN', + 'hr_rgbn': 'High-res RGB', + } + + all_modalities = ('hr_ps', 'hr_pan', 'hr_rgbn', 'lr_rgbn', 'l1c', 'l2a') + + valid_splits = ('train', 'val', 'test') + + file_info_dict: ClassVar[dict[str, dict[str, str]]] = { + 'hr_dataset': { + 'url': 'https://zenodo.org/records/6810792/files/hr_dataset.tar.gz?download=1', + 'filename': 'hr_dataset.tar.gz', + 'md5': 'ca7167334006f3c17f9071f14c435335', + }, + 'lr_dataset_l1c': { + 'url': 'https://zenodo.org/records/6810792/files/lr_dataset_l1c.tar.gz?download=1', + 'filename': 'lr_dataset_l1c.tar.gz', + 'md5': 'd2dcafa207b1e1bc6c754607f15e9ed6', + }, + 'lr_dataset_l2a': { + 'url': 'https://zenodo.org/records/6810792/files/lr_dataset_l2a.tar.gz?download=1', + 'filename': 'lr_dataset_l2a.tar.gz', + 'md5': '8cfc6a477cee9e9cd8b20ea27227de65', + }, + 'metadata': { + 'url': 'https://zenodo.org/records/6810792/files/metadata.csv?download=1', + 'filename': 'metadata.csv', + 'md5': 'dfeb3348e79b719bf03c230d5d258839', + }, + 'train_val_test_split': { + 'url': 'https://zenodo.org/records/6810792/files/stratified_train_val_test_split.csv?download=1', + 'filename': 'stratified_train_val_test_split.csv', + 'md5': '745035835d835280aa0298a9dc1996d1', + }, + } + + def __init__( + self, + root: Path = 'data', + modalities: Sequence[str] = all_modalities, + split: str = 'train', + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize the WorldStrat dataset. + + Args: + root: Root directory where the dataset can be found. + modalities: Sequence of input modalities to load, choose from + 'hr_ps', 'hr_pan', 'hr_rgbn', 'lr_rgbn', 'l1c', 'l2a'. + split: The dataset split to load, choose from 'train', 'val', 'test'. + transforms: A function/transform that takes in a dictionary of tensors + and returns a transformed version. + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: if ``split`` or ``modalities``arguments are invalid + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + assert all(modality in self.all_modalities for modality in modalities), ( + f'Invalid modality: {modalities}, please choose from {self.all_modalities}' + ) + assert split in self.valid_splits, ( + f'Invalid split: {split}, please choose from {self.valid_splits}' + ) + + self.root = root + self.modalities = modalities + self.split = split + self.transforms = transforms + self.download = download + self.checksum = checksum + + self._verify() + + self.file_path_df = pd.read_csv( + os.path.join( + self.root, self.file_info_dict['train_val_test_split']['filename'] + ) + ) + + self.file_path_df = self.file_path_df[ + self.file_path_df['split'] == self.split + ].reset_index(drop=True) + self.metadata_df = pd.read_csv( + os.path.join(self.root, self.file_info_dict['metadata']['filename']) + ) + self.metadata_df.rename(columns={'Unnamed: 0': 'tile_id'}, inplace=True) + + def __getitem__(self, idx: int) -> dict[str, Tensor]: + """Retrieve a sample from the dataset. + + Args: + idx: Index of the sample to retrieve. + + Returns: + Selected modalities of low and high resolution images and metadata. + """ + file_entry = self.file_path_df.iloc[idx] + aoi = file_entry['tile'] + data_dir = os.path.join(self.root, aoi) + + sample: dict[str, Tensor] = {} + + modality_loaders: dict[str, Callable[[], Tensor]] = { + 'l1c': lambda: self._load_sentinel_data(os.path.join(data_dir, 'L1C')), + 'l2a': lambda: self._load_sentinel_data(os.path.join(data_dir, 'L2A')), + 'lr_rgbn': lambda: self._load_tiff( + os.path.join(data_dir, f'{aoi}_rgbn.tiff') + ), + 'hr_ps': lambda: self._load_tiff(os.path.join(data_dir, f'{aoi}_ps.tiff')), + 'hr_pan': lambda: self._load_tiff( + os.path.join(data_dir, f'{aoi}_pan.tiff') + ), + 'hr_rgbn': lambda: torch.from_numpy( + np.array( + Image.open(os.path.join(data_dir, f'{aoi}_rgb.png')) + ).transpose(2, 0, 1) + ).float(), + } + + # Load only selected modalities + for modality in self.modalities: + sample[f'image_{modality}'] = modality_loaders[modality]() + + # Add metadata + metadata = self.metadata_df[self.metadata_df['tile_id'] == aoi].reset_index( + drop=True + ) + sample.update( + { + 'lon': metadata['lon'][0], + 'lat': metadata['lat'][0], + 'low_res_date': metadata['lowres_date'][0], + 'high_res_date': metadata['highres_date'][0], + } + ) + + return sample + + def _load_sentinel_data(self, data_dir: str) -> Tensor: + """Load Sentinel data for a given AOI in a data directory. + + Args: + data_dir: Directory containing the Sentinel data, in the dataset + this is either the L1C or L2A directory with time-series. + + Returns: + Loaded Sentinel data stacked as tensor of shape [T, C, H, W]. + """ + tiff_paths = glob( + os.path.join(data_dir, f'*{os.path.basename(data_dir)}_data.tiff'), + recursive=True, + ) + + # load and stack the data + data = [] + for tiff_path in tiff_paths: + data.append(self._load_tiff(tiff_path)) + + return torch.stack(data).float() + + def _load_tiff(self, tiff_path: str) -> Tensor: + """Load a tiff file as a tensor.""" + with rasterio.open(tiff_path) as src: + data = src.read() + tensor = array_to_tensor(data) + return tensor + + def __len__(self) -> int: + """Return the number of samples in the dataset.""" + return len(self.file_path_df) + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # check if directories are present + exists = [] + split_info_path = os.path.join( + self.root, self.file_info_dict['train_val_test_split']['filename'] + ) + if os.path.exists(split_info_path): + df = pd.read_csv(split_info_path) + df = df[df['split'] == self.split] + # check that all tiles are present + for tile in df['tile']: + exists.append(os.path.exists(os.path.join(self.root, tile))) + else: + exists.append(False) + + if all(exists): + return + + # check if downloaded files are present + exists = [] + for file in self.file_info_dict.values(): + path = os.path.join(self.root, file['filename']) + if os.path.exists(path): + if self.checksum: + md5 = file['md5'] + if not check_integrity(path, md5): + raise RuntimeError(f'Archive {file["filename"]} corrupted') + exists.append(True) + else: + exists.append(False) + + if all(exists): + # extract files + self._extract() + return + + if not self.download: + raise DatasetNotFoundError(self) + + # download + self._download() + + def _extract(self) -> None: + """Extract tar balls to root directory.""" + for file in self.file_info_dict.values(): + if 'tar.gz' in file['filename']: + extract_archive(os.path.join(self.root, file['filename']), self.root) + + def _download(self) -> None: + """Download the dataset and extract it.""" + for _, metadata in self.file_info_dict.items(): + if 'tar.gz' in metadata['filename']: + download_and_extract_archive( + metadata['url'], + self.root, + filename=metadata['filename'], + md5=metadata['md5'] if self.checksum else None, + ) + else: + download_url( + metadata['url'], + self.root, + filename=metadata['filename'], + md5=metadata['md5'] if self.checksum else None, + ) + + def plot( + self, + sample: dict[str, Tensor], + show_titles: bool = True, + suptitle: str | None = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + n_panels = len([k for k in sample.keys() if k.startswith('image_')]) + n_panels += 'prediction' in sample + + fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False) + + for panel, modality in enumerate(self.modalities): + key = f'image_{modality}' + if key in sample: + img = sample[key].numpy() + + # Select and normalize image data + if modality in ['hr_ps', 'hr_pan']: + img = img[0, ...] + elif modality == 'hr_rgbn': + img = img[0:3, ...] + elif modality in ['l1c', 'l2a']: + img = img[0, [4, 3, 2], ...] + + # Apply percentile normalization + img = percentile_normalization(img) + + # Handle channel ordering + if img.ndim == 3: + img = img.transpose(1, 2, 0) + + axs[0, panel].imshow(img) + axs[0, panel].axis('off') + if show_titles: + axs[0, panel].set_title(self.modality_titles[modality]) + + if 'prediction' in sample: + pred = sample['prediction'].numpy().transpose(1, 2, 0) + if pred.shape[-1] == 4: + pred = pred[..., :3] + axs[0, -1].imshow(pred) + axs[0, -1].axis('off') + if show_titles: + axs[0, -1].set_title('Prediction') + + if suptitle: + fig.suptitle(suptitle) + + return fig