Commit 7a392ad

jonGuti13 and xiexinch authored
[Feature] add HSI-Drive dataset (#3365)
## Motivation

The motivation is to add the hyperspectral dataset [HSI Drive 2.0](https://ipaccess.ehu.eus/HSI-Drive/) to the dataset registry. As far as I know, it would be the first hyperspectral dataset in mmsegmentation. The dataset was presented in [HSI-Drive v2.0: More Data for New Challenges in Scene Understanding for Autonomous Driving](https://ieeexplore.ieee.org/document/10371793), and the initial v1 was presented in [HSI-Drive: A Dataset for the Research of Hyperspectral Image Processing Applied to Autonomous Driving Systems](https://ieeexplore.ieee.org/document/9575298).

## Modification

I have created or modified the following:

- READMEs: `README.md` and `README_zh-CN.md` (sorry if the translation is not accurate).
- Example project: `projects/hsidrive20_dataset` has been created and filled in so that users know how to work with this dataset.
- Documentation: `docs/en/user_guides/2_dataset_prepare.md` and `docs/zh_cn/user_guides/2_dataset_prepare.md` (sorry if the translation is not accurate) have been updated so that users know how to download and configure the dataset.
- Dataset-related files: `mmseg/datasets/__init__.py`, `mmseg/datasets/hsi_drive.py` and `configs/_base_/datasets/hsi_drive.py`, where the dataset is described and prepared for training/validation/test.
- Transform-related files: `mmseg/datasets/transforms/loading.py`, to *include support for loading images from .npy files* such as the hyperspectral images of this dataset.
- Training config with a well-known neural network: `configs/unet/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py`, so that people can train a standard neural network with this dataset.
- Tests: added the necessary files under `tests/data/pseudo_hsidrive20_dataset`.

**Important:** I have also modified `.pre-commit-config.yaml` so that codespell ignores the word "HSI".

## BC-breaking (Optional)

No.

## Use cases (Optional)

A training example has been added under `projects/hsidrive20_dataset`, and the documentation has been updated as explained in the Modification section.

## Checklist

1. Pre-commit or other linting tools are used to fix the potential lint issues.
2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D.
4. The documentation has been modified accordingly, like docstring or example tutorials.

Regarding item 1: I don't know how to solve this problem. Could you help me, please? It causes 2 checks to fail.

---------

Co-authored-by: xiexinch <[email protected]>
1 parent 6a709be commit 7a392ad

23 files changed: +576 -2 lines changed

Diff for: .pre-commit-config.yaml

+1
@@ -37,6 +37,7 @@ repos:
     rev: v2.2.1
     hooks:
       - id: codespell
+        args: [--ignore-words-list=hsi]
   - repo: https://github.com/myint/docformatter
     rev: v1.3.1
     hooks:

Diff for: README.md

+1
@@ -339,6 +339,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
 <li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#levir-cd">LEVIR-CD</a></li>
 <li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#bdd100K">BDD100K</a></li>
 <li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#nyu">NYU</a></li>
+<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#hsi-drive-2.0">HSIDrive20</a></li>
 </ul>
 </td>
 <td>

Diff for: README_zh-CN.md

+1
@@ -328,6 +328,7 @@ MMSegmentation v1.x 在 0.x 版本的基础上有了显著的提升,提供了
 <li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#levir-cd">LEVIR-CD</a></li>
 <li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/user_guides/2_dataset_prepare.md#bdd100K">BDD100K</a></li>
 <li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#nyu">NYU</a></li>
+<li><a href="https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#hsi-drive-2.0">HSIDrive20</a></li>
 </ul>
 </td>
 <td>

Diff for: configs/_base_/datasets/hsi_drive.py

+53
@@ -0,0 +1,53 @@
+train_pipeline = [
+    dict(type='LoadImageFromNpyFile'),
+    dict(type='LoadAnnotations'),
+    dict(type='RandomCrop', crop_size=(192, 384)),
+    dict(type='PackSegInputs')
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromNpyFile'),
+    dict(type='RandomCrop', crop_size=(192, 384)),
+    dict(type='LoadAnnotations'),
+    dict(type='PackSegInputs')
+]
+
+train_dataloader = dict(
+    batch_size=4,
+    num_workers=1,
+    persistent_workers=True,
+    sampler=dict(type='InfiniteSampler', shuffle=True),
+    dataset=dict(
+        type='HSIDrive20Dataset',
+        data_root='data/HSIDrive20',
+        data_prefix=dict(
+            img_path='images/training', seg_map_path='annotations/training'),
+        pipeline=train_pipeline))
+
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=1,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type='HSIDrive20Dataset',
+        data_root='data/HSIDrive20',
+        data_prefix=dict(
+            img_path='images/validation',
+            seg_map_path='annotations/validation'),
+        pipeline=test_pipeline))
+
+test_dataloader = dict(
+    batch_size=1,
+    num_workers=1,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type='HSIDrive20Dataset',
+        data_root='data/HSIDrive20',
+        data_prefix=dict(
+            img_path='images/test', seg_map_path='annotations/test'),
+        pipeline=test_pipeline))
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'], ignore_index=0)
+test_evaluator = val_evaluator
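
The evaluator in this config passes `ignore_index=0`, and index 0 is the `'unlabelled'` class in the dataset's class tuple. The following is a simplified NumPy sketch of what ignoring a label index means for mIoU. It is not mmseg's `IoUMetric` implementation; the function name `mean_iou` and the toy arrays are hypothetical.

```python
import numpy as np


def mean_iou(pred, label, num_classes, ignore_index):
    """Simplified mIoU: drop ignored pixels, then average per-class IoU."""
    pred = np.asarray(pred).ravel()
    label = np.asarray(label).ravel()
    # Pixels whose ground truth equals `ignore_index` do not count at all.
    keep = label != ignore_index
    pred, label = pred[keep], label[keep]

    ious = []
    for cls in range(num_classes):
        intersection = np.sum((pred == cls) & (label == cls))
        union = np.sum((pred == cls) | (label == cls))
        # A class absent from both prediction and ground truth has no IoU.
        ious.append(intersection / union if union > 0 else np.nan)
    # nanmean skips the classes that never appeared.
    return float(np.nanmean(ious))


# Toy 2x3 maps: 0 = 'unlabelled' (ignored), 1 = 'road', 2 = 'road marks'.
label = np.array([[0, 1, 1], [2, 2, 0]])
pred = np.array([[1, 1, 2], [2, 2, 1]])
score = mean_iou(pred, label, num_classes=3, ignore_index=0)
```

In this toy case the two unlabelled pixels are discarded before any counting, so predictions made on them neither help nor hurt the score.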

Diff for: configs/unet/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py

+36
@@ -0,0 +1,36 @@
+_base_ = [
+    '../_base_/models/fcn_unet_s5-d16.py', '../_base_/datasets/hsi_drive.py',
+    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
+]
+crop_size = (192, 384)
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    size=crop_size,
+    mean=None,
+    std=None,
+    bgr_to_rgb=None,
+    pad_val=0,
+    seg_pad_val=255)
+
+model = dict(
+    data_preprocessor=data_preprocessor,
+    backbone=dict(in_channels=25),
+    decode_head=dict(
+        ignore_index=0,
+        num_classes=11,
+        loss_decode=dict(
+            type='CrossEntropyLoss',
+            use_sigmoid=False,
+            loss_weight=1.0,
+            avg_non_ignore=True)),
+    auxiliary_head=dict(
+        ignore_index=0,
+        num_classes=11,
+        loss_decode=dict(
+            type='CrossEntropyLoss',
+            use_sigmoid=False,
+            loss_weight=1.0,
+            avg_non_ignore=True)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
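
The config above only lists the keys that differ from its `_base_` files, for example `backbone=dict(in_channels=25)`. A rough pure-Python sketch of that kind of nested override follows. It is an illustration of the idea, not mmengine's actual merge logic; `merge_config` and the contents of `base_model` are hypothetical.

```python
from copy import deepcopy


def merge_config(base, override):
    """Overlay `override` on `base`; nested dicts are merged key by key."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


# Hypothetical excerpt of a base model config (values are illustrative).
base_model = dict(
    backbone=dict(type='UNet', in_channels=3),
    decode_head=dict(type='FCNHead', num_classes=2),
)
# Overrides copied from the HSI-Drive config in this commit.
hsi_override = dict(
    backbone=dict(in_channels=25),
    decode_head=dict(ignore_index=0, num_classes=11),
)
model = merge_config(base_model, hsi_override)
```

Under these assumptions only `in_channels`, `ignore_index` and `num_classes` change, while keys the override does not mention (such as `type`) are kept from the base.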

Diff for: docs/en/user_guides/2_dataset_prepare.md

+52
@@ -205,6 +205,15 @@ mmsegmentation
 │ │ ├── annotations
 │ │ │ ├── train
 │ │ │ ├── test
+│ ├── HSIDrive20
+│ │ ├── images
+│ │ │ ├── train
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── annotations
+│ │ │ ├── train
+│ │ │ ├── validation
+│ │ │ ├── test
 ```
 
 ## Download dataset via MIM
@@ -752,3 +761,46 @@ mmsegmentation
 ```bash
 python tools/dataset_converters/nyu.py nyu.zip
 ```
+
+## HSI Drive 2.0
+
+- You could download HSI Drive 2.0 dataset from [here](https://ipaccess.ehu.eus/HSI-Drive/#download) after just sending an email to [email protected] with the subject "download HSI-Drive". You will receive a password to uncompress the files.
+
+- After download, unzip by the following instructions:
+
+```bash
+7z x -p"password" ./HSI_Drive_v2_0_Phyton.zip
+
+mv ./HSIDrive20 path_to_mmsegmentation/data
+mv ./HSI_Drive_v2_0_release_notes_Python_version.md path_to_mmsegmentation/data
+mv ./image_numbering.pdf path_to_mmsegmentation/data
+```
+
+- After unzip, you get
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│ ├── HSIDrive20
+│ │ ├── images
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── annotations
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── images_MF
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── RGB
+│ │ ├── training_filenames.txt
+│ │ ├── validation_filenames.txt
+│ │ ├── test_filenames.txt
+│ ├── HSI_Drive_v2_0_release_notes_Python_version.md
+│ ├── image_numbering.pdf
+```
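
After moving the files, it may help to confirm that the split folders referenced by the dataset config (`images/training`, `annotations/validation`, and so on) are actually present. A small sketch using only the standard library; the helper name `missing_dirs` is hypothetical and not part of mmsegmentation.

```python
from pathlib import Path

SPLITS = ('training', 'validation', 'test')


def missing_dirs(data_root):
    """Return the expected image/annotation split folders that do not exist."""
    root = Path(data_root)
    expected = [root / kind / split
                for kind in ('images', 'annotations')
                for split in SPLITS]
    # Report paths relative to the data root, with forward slashes.
    return [p.relative_to(root).as_posix() for p in expected if not p.is_dir()]
```

Calling `missing_dirs('data/HSIDrive20')` from the repository root should return an empty list when the layout matches the tree shown above.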

Diff for: docs/zh_cn/user_guides/2_dataset_prepare.md

+52
@@ -205,6 +205,15 @@ mmsegmentation
 │ │ ├── annotations
 │ │ │ ├── train
 │ │ │ ├── test
+│ ├── HSIDrive20
+│ │ ├── images
+│ │ │ ├── train
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── annotations
+│ │ │ ├── train
+│ │ │ ├── validation
+│ │ │ ├── test
 ```
 
 ## 用 MIM 下载数据集
@@ -748,3 +757,46 @@ mmsegmentation
 ```bash
 python tools/dataset_converters/nyu.py nyu.zip
 ```
+
+## HSI Drive 2.0
+
+- 您可以从以下位置下载 HSI Drive 2.0 数据集 [here](https://ipaccess.ehu.eus/HSI-Drive/#download) 刚刚向 [email protected] 发送主题为“下载 HSI-Drive”的电子邮件后 您将收到解压缩文件的密码.
+
+- 下载后,按照以下说明解压:
+
+```bash
+7z x -p"password" ./HSI_Drive_v2_0_Phyton.zip
+
+mv ./HSIDrive20 path_to_mmsegmentation/data
+mv ./HSI_Drive_v2_0_release_notes_Python_version.md path_to_mmsegmentation/data
+mv ./image_numbering.pdf path_to_mmsegmentation/data
+```
+
+- 解压后得到:
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│ ├── HSIDrive20
+│ │ ├── images
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── annotations
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── images_MF
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── RGB
+│ │ ├── training_filenames.txt
+│ │ ├── validation_filenames.txt
+│ │ ├── test_filenames.txt
+│ ├── HSI_Drive_v2_0_release_notes_Python_version.md
+│ ├── image_numbering.pdf
+```

Diff for: mmseg/datasets/__init__.py

+2 -1
@@ -12,6 +12,7 @@
 from .drive import DRIVEDataset
 from .dsdl import DSDLSegDataset
 from .hrf import HRFDataset
+from .hsi_drive import HSIDrive20Dataset
 from .isaid import iSAIDDataset
 from .isprs import ISPRSDataset
 from .levir import LEVIRCDDataset
@@ -60,5 +61,5 @@
     'MapillaryDataset_v2', 'Albu', 'LEVIRCDDataset',
     'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile',
     'ConcatCDInput', 'BaseCDDataset', 'DSDLSegDataset', 'BDD100KDataset',
-    'NYUDataset'
+    'NYUDataset', 'HSIDrive20Dataset'
 ]

Diff for: mmseg/datasets/hsi_drive.py

+42
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+classes_exp = ('unlabelled', 'road', 'road marks', 'vegetation',
+               'painted metal', 'sky', 'concrete', 'pedestrian', 'water',
+               'unpainted metal', 'glass')
+palette_exp = [[0, 0, 0], [77, 77, 77], [255, 255, 255], [0, 255, 0],
+               [255, 0, 0], [0, 0, 255], [102, 51, 0], [255, 255, 0],
+               [0, 207, 250], [255, 166, 0], [0, 204, 204]]
+
+
+@DATASETS.register_module()
+class HSIDrive20Dataset(BaseSegDataset):
+    """HSI-Drive v2.0 (https://ieeexplore.ieee.org/document/10371793), the
+    updated version of HSI-Drive
+    (https://ieeexplore.ieee.org/document/9575298), is a structured dataset for
+    the research and development of automated driving systems (ADS) supported
+    by hyperspectral imaging (HSI). It contains per-pixel manually annotated
+    images selected from videos recorded in real driving conditions and has
+    been organized according to four parameters: season, daytime, road type,
+    and weather conditions.
+
+    The video sequences have been captured with a small-size 25-band VNIR
+    (Visible-NearlnfraRed) snapshot hyperspectral camera mounted on a driving
+    automobile. As a consequence, you need to modify the in_channels parameter
+    of your model from 3 (RGB images) to 25 (HSI images) as it is done in
+    configs/unet/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py
+
+    Apart from the abovementioned articles, additional information is provided
+    in the website (https://ipaccess.ehu.eus/HSI-Drive/) from where you can
+    download the dataset and also visualize some examples of segmented videos.
+    """
+
+    METAINFO = dict(classes=classes_exp, palette=palette_exp)
+
+    def __init__(self,
+                 img_suffix='.npy',
+                 seg_map_suffix='.png',
+                 **kwargs) -> None:
+        super().__init__(
+            img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
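
The dataset class only sets two suffixes: `.npy` for the hyperspectral cubes and `.png` for the label maps. The sketch below illustrates how such suffixes can relate the two folders by file stem. It is a simplified stand-in, not `BaseSegDataset`'s own loading code, and `pair_samples` is a hypothetical helper.

```python
from pathlib import Path


def pair_samples(img_dir, ann_dir, img_suffix='.npy', seg_map_suffix='.png'):
    """Pair each image with the annotation that shares its file stem."""
    samples = []
    for img_path in sorted(Path(img_dir).glob(f'*{img_suffix}')):
        # Strip the image suffix and attach the annotation suffix instead.
        stem = img_path.name[:-len(img_suffix)]
        samples.append(dict(
            img_path=str(img_path),
            seg_map_path=str(Path(ann_dir) / f'{stem}{seg_map_suffix}')))
    return samples
```

Files in the image folder that do not end in `img_suffix` are skipped, which is why the suffix defaults matter for a dataset whose inputs are not ordinary image files.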

Diff for: mmseg/datasets/transforms/loading.py

+67
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
+from pathlib import Path
 from typing import Dict, Optional, Union
 
 import mmcv
@@ -702,3 +703,69 @@ def __repr__(self):
                     f'to_float32={self.to_float32}, '
                     f'backend_args={self.backend_args})')
         return repr_str
+
+
+@TRANSFORMS.register_module()
+class LoadImageFromNpyFile(LoadImageFromFile):
+    """Load an image from ``results['img_path']``.
+
+    Required Keys:
+
+    - img_path
+
+    Modified Keys:
+
+    - img
+    - img_shape
+    - ori_shape
+
+    Args:
+        to_float32 (bool): Whether to convert the loaded image to a float32
+            numpy array. If set to False, the loaded image is an uint8 array.
+            Defaults to False.
+    """
+
+    def transform(self, results: dict) -> Optional[dict]:
+        """Functions to load image.
+
+        Args:
+            results (dict): Result dict from
+                :class:`mmengine.dataset.BaseDataset`.
+
+        Returns:
+            dict: The dict contains loaded image and meta information.
+        """
+
+        filename = results['img_path']
+
+        try:
+            if Path(filename).suffix in ['.npy', '.npz']:
+                img = np.load(filename)
+            else:
+                if self.file_client_args is not None:
+                    file_client = fileio.FileClient.infer_client(
+                        self.file_client_args, filename)
+                    img_bytes = file_client.get(filename)
+                else:
+                    img_bytes = fileio.get(
+                        filename, backend_args=self.backend_args)
+                img = mmcv.imfrombytes(
+                    img_bytes,
+                    flag=self.color_type,
+                    backend=self.imdecode_backend)
+        except Exception as e:
+            if self.ignore_empty:
+                return None
+            else:
+                raise e
+
+        # in some cases, images are not read successfully, the img would be
+        # `None`, refer to https://github.com/open-mmlab/mmpretrain/issues/1427
+        assert img is not None, f'failed to load image: {filename}'
+        if self.to_float32:
+            img = img.astype(np.float32)
+
+        results['img'] = img
+        results['img_shape'] = img.shape[:2]
+        results['ori_shape'] = img.shape[:2]
+        return results
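
The transform above routes both `.npy` and `.npz` paths through `np.load`. For `.npz`, `np.load` returns an archive object instead of an array, so as far as I can tell the later `img.astype` and `img.shape` calls would only work for `.npy`. A standalone sketch of the same loading step that handles both cases follows. `load_cube` is a hypothetical helper, and the `(H, W, bands)` layout of the toy cube is an assumption based on the transform taking `img.shape[:2]` as the image shape.

```python
import tempfile
from pathlib import Path

import numpy as np


def load_cube(path, to_float32=False):
    """Load a cube from .npy, or the first array stored in an .npz archive."""
    loaded = np.load(path)
    if isinstance(loaded, np.ndarray):
        img = loaded
    else:
        # .npz yields an archive; read one array out before closing it.
        with loaded:
            img = loaded[loaded.files[0]]
    if to_float32:
        img = img.astype(np.float32)
    # Mirror the keys the transform writes into `results`.
    return dict(img=img, img_shape=img.shape[:2], ori_shape=img.shape[:2])


with tempfile.TemporaryDirectory() as tmp:
    # Toy 25-band cube; the (H, W, bands) order is an assumption.
    cube = np.zeros((4, 6, 25), dtype=np.uint8)
    np.save(Path(tmp) / 'cube.npy', cube)
    result = load_cube(Path(tmp) / 'cube.npy', to_float32=True)
```

Here `result['img_shape']` holds the first two dimensions of the cube, and the band count stays in the last axis of `result['img']`.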
