Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSL4EO-S12: add additional metadata #2533

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions torchgeo/datasets/sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""Sentinel datasets."""

from collections.abc import Callable, Iterable, Sequence
from typing import Any
from typing import Any, ClassVar

import matplotlib.pyplot as plt
import torch
Expand All @@ -28,6 +28,9 @@ class Sentinel(RasterDataset):
* https://asf.alaska.edu/datasets/daac/sentinel-1/
"""

date_format = '%Y%m%dT%H%M%S'
separate_files = True


class Sentinel1(Sentinel):
r"""Sentinel-1 dataset.
Expand Down Expand Up @@ -136,9 +139,11 @@ class Sentinel1(Sentinel):
_(?P<band>[VH]{2})
\.
"""
date_format = '%Y%m%dT%H%M%S'

# https://sentiwiki.copernicus.eu/web/s1-mission
all_bands = ('HH', 'HV', 'VV', 'VH')
separate_files = True
# Central wavelength (μm)
wavelength = 55500

def __init__(
self,
Expand Down Expand Up @@ -274,9 +279,8 @@ class Sentinel2(Sentinel):
(?:_(?P<resolution>{}m))?
\..*$
"""
date_format = '%Y%m%dT%H%M%S'

# https://gisgeography.com/sentinel-2-bands-combinations/
# https://sentiwiki.copernicus.eu/web/s2-mission
all_bands: tuple[str, ...] = (
'B01',
'B02',
Expand All @@ -293,8 +297,32 @@ class Sentinel2(Sentinel):
'B12',
)
rgb_bands = ('B04', 'B03', 'B02')

separate_files = True
# Central wavelength of each band (μm), keyed by band name
wavelengths: ClassVar[dict[str, float]] = {
'B01': 0.4427,
'B02': 0.4927,
'B03': 0.5598,
'B04': 0.6646,
'B05': 0.7041,
'B06': 0.7405,
'B07': 0.7828,
'B08': 0.8328,
'B8A': 0.8647,
'B09': 0.9451,
'B10': 1.3735,
'B11': 1.6137,
'B12': 2.2024,
# Aliases without zero padding (same values as B01-B09), for compatibility
# with other dataset naming conventions
'B1': 0.4427,
'B2': 0.4927,
'B3': 0.5598,
'B4': 0.6646,
'B5': 0.7041,
'B6': 0.7405,
'B7': 0.7828,
'B8': 0.8328,
'B9': 0.9451,
}

def __init__(
self,
Expand Down
56 changes: 47 additions & 9 deletions torchgeo/datasets/ssl4eo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import glob
import os
import random
import re
from collections.abc import Callable
from typing import ClassVar, TypedDict

Expand All @@ -18,7 +19,14 @@

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, check_integrity, download_url, extract_archive
from .sentinel import Sentinel, Sentinel1, Sentinel2
from .utils import (
Path,
check_integrity,
disambiguate_timestamp,
download_url,
extract_archive,
)


class SSL4EO(NonGeoDataset):
Expand Down Expand Up @@ -321,7 +329,7 @@ def plot(
return fig


class SSL4EOS12(NonGeoDataset):
class SSL4EOS12(SSL4EO):
"""SSL4EO-S12 dataset.

`Sentinel-1/2 <https://github.com/zhu-xlab/SSL4EO-S12>`_ version of SSL4EO.
Expand Down Expand Up @@ -356,12 +364,14 @@ class _Metadata(TypedDict):
filename: str
md5: str
bands: list[str]
filename_regex: str

metadata: ClassVar[dict[str, _Metadata]] = {
's1': {
'filename': 's1.tar.gz',
'md5': '51ee23b33eb0a2f920bda25225072f3a',
'bands': ['VV', 'VH'],
'filename_regex': r'^.{16}_(?P<date>\d{8}T\d{6})',
},
's2c': {
'filename': 's2_l1c.tar.gz',
Expand All @@ -381,6 +391,7 @@ class _Metadata(TypedDict):
'B11',
'B12',
],
'filename_regex': r'^(?P<date>\d{8}T\d{6})',
},
's2a': {
'filename': 's2_l2a.tar.gz',
Expand All @@ -399,6 +410,7 @@ class _Metadata(TypedDict):
'B11',
'B12',
],
'filename_regex': r'^(?P<date>\d{8}T\d{6})',
},
}

Expand Down Expand Up @@ -451,17 +463,43 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
root = os.path.join(self.root, self.split, f'{index:07}')
subdirs = os.listdir(root)
subdirs = random.sample(subdirs, self.seasons)
filename_regex = self.metadata[self.split]['filename_regex']

images = []
xs = []
ys = []
ts = []
wavelengths: list[float] = []
for subdir in subdirs:
directory = os.path.join(root, subdir)
for band in self.bands:
filename = os.path.join(directory, f'{band}.tif')
with rasterio.open(filename) as f:
image = f.read(out_shape=(1, self.size, self.size))
images.append(torch.from_numpy(image.astype(np.float32)))

sample = {'image': torch.cat(images)}
if match := re.match(filename_regex, subdir):
date_str = match.group('date')
mint, maxt = disambiguate_timestamp(date_str, Sentinel.date_format)
for band in self.bands:
match self.split:
case 's1':
wavelengths.append(Sentinel1.wavelength)
case 's2c' | 's2a':
wavelengths.append(Sentinel2.wavelengths[band])

filename = os.path.join(directory, f'{band}.tif')
with rasterio.open(filename) as f:
minx, maxx = f.bounds.left, f.bounds.right
miny, maxy = f.bounds.bottom, f.bounds.top
image = f.read(out_shape=(1, self.size, self.size))
images.append(torch.from_numpy(image.astype(np.float32)))
xs.append((minx + maxx) / 2)
ys.append((miny + maxy) / 2)
ts.append((mint + maxt) / 2)

sample = {
'image': torch.cat(images),
'x': torch.tensor(xs),
'y': torch.tensor(ys),
't': torch.tensor(ts),
'wavelength': torch.tensor(wavelengths),
'res': torch.tensor(10),
}

if self.transforms is not None:
sample = self.transforms(sample)
Expand Down
Loading