6
6
import glob
7
7
import os
8
8
import random
9
+ import re
9
10
from collections .abc import Callable
10
- from typing import ClassVar , TypedDict
11
+ from typing import Any , ClassVar , TypedDict
11
12
12
13
import matplotlib .pyplot as plt
13
14
import numpy as np
18
19
19
20
from .errors import DatasetNotFoundError
20
21
from .geo import NonGeoDataset
21
- from .utils import Path , check_integrity , download_url , extract_archive
22
+ from .sentinel import Sentinel , Sentinel1 , Sentinel2
23
+ from .utils import (
24
+ BoundingBox ,
25
+ Path ,
26
+ check_integrity ,
27
+ disambiguate_timestamp ,
28
+ download_url ,
29
+ extract_archive ,
30
+ )
22
31
23
32
24
33
class SSL4EO (NonGeoDataset ):
@@ -321,7 +330,7 @@ def plot(
321
330
return fig
322
331
323
332
324
- class SSL4EOS12 (NonGeoDataset ):
333
+ class SSL4EOS12 (SSL4EO ):
325
334
"""SSL4EO-S12 dataset.
326
335
327
336
`Sentinel-1/2 <https://github.com/zhu-xlab/SSL4EO-S12>`_ version of SSL4EO.
@@ -362,6 +371,7 @@ class _Metadata(TypedDict):
362
371
'filename' : 's1.tar.gz' ,
363
372
'md5' : '51ee23b33eb0a2f920bda25225072f3a' ,
364
373
'bands' : ['VV' , 'VH' ],
374
+ 'filename_regex' : r'^S1[AB]_(?P<mode>SM|IW|EW|WV)_.{9}_(?P<date>\d{8}T\d{6})' ,
365
375
},
366
376
's2c' : {
367
377
'filename' : 's2_l1c.tar.gz' ,
@@ -381,6 +391,7 @@ class _Metadata(TypedDict):
381
391
'B11' ,
382
392
'B12' ,
383
393
],
394
+ 'filename_regex' : r'^(?P<date>\d{8}T\d{6})' ,
384
395
},
385
396
's2a' : {
386
397
'filename' : 's2_l2a.tar.gz' ,
@@ -399,6 +410,7 @@ class _Metadata(TypedDict):
399
410
'B11' ,
400
411
'B12' ,
401
412
],
413
+ 'filename_regex' : r'^(?P<date>\d{8}T\d{6})' ,
402
414
},
403
415
}
404
416
@@ -439,7 +451,7 @@ def __init__(
439
451
440
452
self ._verify ()
441
453
442
- def __getitem__ (self , index : int ) -> dict [str , Tensor ]:
454
+ def __getitem__ (self , index : int ) -> dict [str , Any ]:
443
455
"""Return an index within the dataset.
444
456
445
457
Args:
@@ -451,17 +463,37 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
451
463
root = os .path .join (self .root , self .split , f'{ index :07} ' )
452
464
subdirs = os .listdir (root )
453
465
subdirs = random .sample (subdirs , self .seasons )
466
+ filename_regex = self .metadata [self .split ]['filename_regex' ]
454
467
455
468
images = []
469
+ bounds = []
470
+ wavelengths = []
456
471
for subdir in subdirs :
457
472
directory = os .path .join (root , subdir )
458
- for band in self .bands :
459
- filename = os .path .join (directory , f'{ band } .tif' )
460
- with rasterio .open (filename ) as f :
461
- image = f .read (out_shape = (1 , self .size , self .size ))
462
- images .append (torch .from_numpy (image .astype (np .float32 )))
463
-
464
- sample = {'image' : torch .cat (images )}
473
+ if match := re .match (filename_regex , subdir ):
474
+ date_str = match .group ('date' )
475
+ mint , maxt = disambiguate_timestamp (date_str , Sentinel .date_format )
476
+ for band in self .bands :
477
+ match self .split :
478
+ case 's1' :
479
+ wavelengths .append (Sentinel1 .wavelength )
480
+ case 's2c' | 's2a' :
481
+ wavelengths .append (Sentinel2 .wavelengths [band ])
482
+
483
+ filename = os .path .join (directory , f'{ band } .tif' )
484
+ with rasterio .open (filename ) as f :
485
+ minx , maxx = f .bounds .left , f .bounds .right
486
+ miny , maxy = f .bounds .bottom , f .bounds .top
487
+ image = f .read (out_shape = (1 , self .size , self .size ))
488
+ images .append (torch .from_numpy (image .astype (np .float32 )))
489
+ bounds .append (BoundingBox (minx , maxx , miny , maxy , mint , maxt ))
490
+
491
+ sample = {
492
+ 'image' : torch .cat (images ),
493
+ 'bounds' : bounds ,
494
+ 'wavelengths' : wavelengths ,
495
+ 'gsd' : 10 ,
496
+ }
465
497
466
498
if self .transforms is not None :
467
499
sample = self .transforms (sample )
0 commit comments