Skip to content

Commit 551b7ee

Browse files
committed
SSL4EO-S12: add additional metadata
1 parent 8a3b60d commit 551b7ee

File tree

2 files changed

+79
-19
lines changed

2 files changed

+79
-19
lines changed

torchgeo/datasets/sentinel.py

+36-8
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""Sentinel datasets."""
55

66
from collections.abc import Callable, Iterable, Sequence
7-
from typing import Any
7+
from typing import Any, ClassVar
88

99
import matplotlib.pyplot as plt
1010
import torch
@@ -28,6 +28,9 @@ class Sentinel(RasterDataset):
2828
* https://asf.alaska.edu/datasets/daac/sentinel-1/
2929
"""
3030

31+
date_format = '%Y%m%dT%H%M%S'
32+
separate_files = True
33+
3134

3235
class Sentinel1(Sentinel):
3336
r"""Sentinel-1 dataset.
@@ -136,9 +139,11 @@ class Sentinel1(Sentinel):
136139
_(?P<band>[VH]{2})
137140
\.
138141
"""
139-
date_format = '%Y%m%dT%H%M%S'
142+
143+
# https://sentiwiki.copernicus.eu/web/s1-mission
140144
all_bands = ('HH', 'HV', 'VV', 'VH')
141-
separate_files = True
145+
# Central wavelength (μm)
146+
wavelength = 55500
142147

143148
def __init__(
144149
self,
@@ -274,10 +279,9 @@ class Sentinel2(Sentinel):
274279
(?:_(?P<resolution>{}m))?
275280
\..*$
276281
"""
277-
date_format = '%Y%m%dT%H%M%S'
278282

279-
# https://gisgeography.com/sentinel-2-bands-combinations/
280-
all_bands: tuple[str, ...] = (
283+
# https://sentiwiki.copernicus.eu/web/s2-mission
284+
all_bands = (
281285
'B01',
282286
'B02',
283287
'B03',
@@ -293,8 +297,32 @@ class Sentinel2(Sentinel):
293297
'B12',
294298
)
295299
rgb_bands = ('B04', 'B03', 'B02')
296-
297-
separate_files = True
300+
# Central wavelength (μm)
301+
wavelengths: ClassVar[dict[str, float]] = {
302+
'B01': 0.4427,
303+
'B02': 0.4927,
304+
'B03': 0.5598,
305+
'B04': 0.6646,
306+
'B05': 0.7041,
307+
'B06': 0.7405,
308+
'B07': 0.7828,
309+
'B08': 0.8328,
310+
'B8A': 0.8647,
311+
'B09': 0.9451,
312+
'B10': 1.3735,
313+
'B11': 1.6137,
314+
'B12': 2.2024,
315+
# For compatibility with other dataset naming conventions
316+
'B1': 0.4427,
317+
'B2': 0.4927,
318+
'B3': 0.5598,
319+
'B4': 0.6646,
320+
'B5': 0.7041,
321+
'B6': 0.7405,
322+
'B7': 0.7828,
323+
'B8': 0.8328,
324+
'B9': 0.9451,
325+
}
298326

299327
def __init__(
300328
self,

torchgeo/datasets/ssl4eo.py

+43-11
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
import glob
77
import os
88
import random
9+
import re
910
from collections.abc import Callable
10-
from typing import ClassVar, TypedDict
11+
from typing import Any, ClassVar, TypedDict
1112

1213
import matplotlib.pyplot as plt
1314
import numpy as np
@@ -18,7 +19,15 @@
1819

1920
from .errors import DatasetNotFoundError
2021
from .geo import NonGeoDataset
21-
from .utils import Path, check_integrity, download_url, extract_archive
22+
from .sentinel import Sentinel, Sentinel1, Sentinel2
23+
from .utils import (
24+
BoundingBox,
25+
Path,
26+
check_integrity,
27+
disambiguate_timestamp,
28+
download_url,
29+
extract_archive,
30+
)
2231

2332

2433
class SSL4EO(NonGeoDataset):
@@ -321,7 +330,7 @@ def plot(
321330
return fig
322331

323332

324-
class SSL4EOS12(NonGeoDataset):
333+
class SSL4EOS12(SSL4EO):
325334
"""SSL4EO-S12 dataset.
326335
327336
`Sentinel-1/2 <https://github.com/zhu-xlab/SSL4EO-S12>`_ version of SSL4EO.
@@ -362,6 +371,7 @@ class _Metadata(TypedDict):
362371
'filename': 's1.tar.gz',
363372
'md5': '51ee23b33eb0a2f920bda25225072f3a',
364373
'bands': ['VV', 'VH'],
374+
'filename_regex': r'^S1[AB]_(?P<mode>SM|IW|EW|WV)_.{9}_(?P<date>\d{8}T\d{6})',
365375
},
366376
's2c': {
367377
'filename': 's2_l1c.tar.gz',
@@ -381,6 +391,7 @@ class _Metadata(TypedDict):
381391
'B11',
382392
'B12',
383393
],
394+
'filename_regex': r'^(?P<date>\d{8}T\d{6})',
384395
},
385396
's2a': {
386397
'filename': 's2_l2a.tar.gz',
@@ -399,6 +410,7 @@ class _Metadata(TypedDict):
399410
'B11',
400411
'B12',
401412
],
413+
'filename_regex': r'^(?P<date>\d{8}T\d{6})',
402414
},
403415
}
404416

@@ -439,7 +451,7 @@ def __init__(
439451

440452
self._verify()
441453

442-
def __getitem__(self, index: int) -> dict[str, Tensor]:
454+
def __getitem__(self, index: int) -> dict[str, Any]:
443455
"""Return an index within the dataset.
444456
445457
Args:
@@ -451,17 +463,37 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
451463
root = os.path.join(self.root, self.split, f'{index:07}')
452464
subdirs = os.listdir(root)
453465
subdirs = random.sample(subdirs, self.seasons)
466+
filename_regex = self.metadata[self.split]['filename_regex']
454467

455468
images = []
469+
bounds = []
470+
wavelengths = []
456471
for subdir in subdirs:
457472
directory = os.path.join(root, subdir)
458-
for band in self.bands:
459-
filename = os.path.join(directory, f'{band}.tif')
460-
with rasterio.open(filename) as f:
461-
image = f.read(out_shape=(1, self.size, self.size))
462-
images.append(torch.from_numpy(image.astype(np.float32)))
463-
464-
sample = {'image': torch.cat(images)}
473+
if match := re.match(filename_regex, subdir):
474+
date_str = match.group('date')
475+
mint, maxt = disambiguate_timestamp(date_str, Sentinel.date_format)
476+
for band in self.bands:
477+
match self.split:
478+
case 's1':
479+
wavelengths.append(Sentinel1.wavelength)
480+
case 's2c' | 's2a':
481+
wavelengths.append(Sentinel2.wavelengths[band])
482+
483+
filename = os.path.join(directory, f'{band}.tif')
484+
with rasterio.open(filename) as f:
485+
minx, maxx = f.bounds.left, f.bounds.right
486+
miny, maxy = f.bounds.bottom, f.bounds.top
487+
image = f.read(out_shape=(1, self.size, self.size))
488+
images.append(torch.from_numpy(image.astype(np.float32)))
489+
bounds.append(BoundingBox(minx, maxx, miny, maxy, mint, maxt))
490+
491+
sample = {
492+
'image': torch.cat(images),
493+
'bounds': bounds,
494+
'wavelengths': wavelengths,
495+
'gsd': 10,
496+
}
465497

466498
if self.transforms is not None:
467499
sample = self.transforms(sample)

0 commit comments

Comments
 (0)