diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index aa5942748a..337e90cdca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,6 +37,7 @@ repos: rev: v2.2.1 hooks: - id: codespell + args: [--ignore-words-list=hsi] - repo: https://github.com/myint/docformatter rev: v1.3.1 hooks: diff --git a/README.md b/README.md index fae3654f8e..79ecdb7111 100644 --- a/README.md +++ b/README.md @@ -339,6 +339,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
  • LEVIR-CD
  • BDD100K
  • NYU
  • +
  • HSIDrive20
  • diff --git a/README_zh-CN.md b/README_zh-CN.md index 7a3e3ada72..e047759b08 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -328,6 +328,7 @@ MMSegmentation v1.x 在 0.x 版本的基础上有了显著的提升,提供了
  • LEVIR-CD
  • BDD100K
  • NYU
  • +
  • HSIDrive20
  • diff --git a/configs/_base_/datasets/hsi_drive.py b/configs/_base_/datasets/hsi_drive.py new file mode 100644 index 0000000000..2d08e2d601 --- /dev/null +++ b/configs/_base_/datasets/hsi_drive.py @@ -0,0 +1,53 @@ +train_pipeline = [ + dict(type='LoadImageFromNpyFile'), + dict(type='LoadAnnotations'), + dict(type='RandomCrop', crop_size=(192, 384)), + dict(type='PackSegInputs') +] + +test_pipeline = [ + dict(type='LoadImageFromNpyFile'), + dict(type='RandomCrop', crop_size=(192, 384)), + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') +] + +train_dataloader = dict( + batch_size=4, + num_workers=1, + persistent_workers=True, + sampler=dict(type='InfiniteSampler', shuffle=True), + dataset=dict( + type='HSIDrive20Dataset', + data_root='data/HSIDrive20', + data_prefix=dict( + img_path='images/training', seg_map_path='annotations/training'), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type='HSIDrive20Dataset', + data_root='data/HSIDrive20', + data_prefix=dict( + img_path='images/validation', + seg_map_path='annotations/validation'), + pipeline=test_pipeline)) + +test_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type='HSIDrive20Dataset', + data_root='data/HSIDrive20', + data_prefix=dict( + img_path='images/test', seg_map_path='annotations/test'), + pipeline=test_pipeline)) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'], ignore_index=0) +test_evaluator = val_evaluator diff --git a/configs/unet/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py b/configs/unet/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py new file mode 100644 index 0000000000..a5768ba148 --- /dev/null +++ b/configs/unet/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py @@ -0,0 +1,36 @@ +_base_ = [ + '../_base_/models/fcn_unet_s5-d16.py', '../_base_/datasets/hsi_drive.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' +] +crop_size = (192, 384) +data_preprocessor = dict( + type='SegDataPreProcessor', + size=crop_size, + mean=None, + std=None, + bgr_to_rgb=None, + pad_val=0, + seg_pad_val=255) + +model = dict( + data_preprocessor=data_preprocessor, + backbone=dict(in_channels=25), + decode_head=dict( + ignore_index=0, + num_classes=11, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0, + avg_non_ignore=True)), + auxiliary_head=dict( + ignore_index=0, + num_classes=11, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0, + avg_non_ignore=True)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/docs/en/user_guides/2_dataset_prepare.md b/docs/en/user_guides/2_dataset_prepare.md index 2816a51f0d..3f94a94289 100644 --- a/docs/en/user_guides/2_dataset_prepare.md +++ b/docs/en/user_guides/2_dataset_prepare.md @@ -205,6 +205,15 @@ mmsegmentation │ │ ├── annotations │ │ │ ├── train │ │ │ ├── test +│ ├── HSIDrive20 +│ │ ├── images +│ │ │ ├── train +│ │ │ ├── validation +│ │ │ ├── test +│ │ ├── annotations +│ │ │ ├── train +│ │ │ ├── validation +│ │ │ ├── test ``` ## Download dataset via MIM @@ -752,3 +761,46 @@ mmsegmentation ```bash python tools/dataset_converters/nyu.py nyu.zip ``` + +## HSI Drive 2.0 + +- You could download HSI Drive 2.0 dataset from [here](https://ipaccess.ehu.eus/HSI-Drive/#download) after just sending an email to gded@ehu.eus with the subject "download HSI-Drive". You will receive a password to uncompress the files. + +- After download, unzip by the following instructions: + + ```bash + 7z x -p"password" ./HSI_Drive_v2_0_Phyton.zip + + mv ./HSIDrive20 path_to_mmsegmentation/data + mv ./HSI_Drive_v2_0_release_notes_Python_version.md path_to_mmsegmentation/data + mv ./image_numbering.pdf path_to_mmsegmentation/data + ``` + +- After unzip, you get + +```none +mmsegmentation +├── mmseg +├── tools +├── configs +├── data +│ ├── HSIDrive20 +│ │ ├── images +│ │ │ ├── training +│ │ │ ├── validation +│ │ │ ├── test +│ │ ├── annotations +│ │ │ ├── training +│ │ │ ├── validation +│ │ │ ├── test +│ │ ├── images_MF +│ │ │ ├── training +│ │ │ ├── validation +│ │ │ ├── test +│ │ ├── RGB +│ │ ├── training_filenames.txt +│ │ ├── validation_filenames.txt +│ │ ├── test_filenames.txt +│ ├── HSI_Drive_v2_0_release_notes_Python_version.md +│ ├── image_numbering.pdf +``` diff --git a/docs/zh_cn/user_guides/2_dataset_prepare.md b/docs/zh_cn/user_guides/2_dataset_prepare.md index 5532624bef..e32303a0bd 100644 --- a/docs/zh_cn/user_guides/2_dataset_prepare.md +++ b/docs/zh_cn/user_guides/2_dataset_prepare.md @@ -205,6 +205,15 @@ mmsegmentation │ │ ├── annotations │ │ │ ├── train │ │ │ ├── test +│ ├── HSIDrive20 +│ │ ├── images +│ │ │ ├── train +│ │ │ ├── validation +│ │ │ ├── test +│ │ ├── annotations +│ │ │ ├── train +│ │ │ ├── validation +│ │ │ ├── test ``` ## 用 MIM 下载数据集 @@ -748,3 +757,46 @@ mmsegmentation ```bash python tools/dataset_converters/nyu.py nyu.zip ``` + +## HSI Drive 2.0 + +- 您可以从以下位置下载 HSI Drive 2.0 数据集 [here](https://ipaccess.ehu.eus/HSI-Drive/#download) 刚刚向 gded@ehu.eus 发送主题为“下载 HSI-Drive”的电子邮件后 您将收到解压缩文件的密码. + +- 下载后,按照以下说明解压: + + ```bash + 7z x -p"password" ./HSI_Drive_v2_0_Phyton.zip + + mv ./HSIDrive20 path_to_mmsegmentation/data + mv ./HSI_Drive_v2_0_release_notes_Python_version.md path_to_mmsegmentation/data + mv ./image_numbering.pdf path_to_mmsegmentation/data + ``` + +- 解压后得到: + +```none +mmsegmentation +├── mmseg +├── tools +├── configs +├── data +│ ├── HSIDrive20 +│ │ ├── images +│ │ │ ├── training +│ │ │ ├── validation +│ │ │ ├── test +│ │ ├── annotations +│ │ │ ├── training +│ │ │ ├── validation +│ │ │ ├── test +│ │ ├── images_MF +│ │ │ ├── training +│ │ │ ├── validation +│ │ │ ├── test +│ │ ├── RGB +│ │ ├── training_filenames.txt +│ │ ├── validation_filenames.txt +│ │ ├── test_filenames.txt +│ ├── HSI_Drive_v2_0_release_notes_Python_version.md +│ ├── image_numbering.pdf +``` diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py index a2bdb63d01..f8ad750d76 100644 --- a/mmseg/datasets/__init__.py +++ b/mmseg/datasets/__init__.py @@ -12,6 +12,7 @@ from .drive import DRIVEDataset from .dsdl import DSDLSegDataset from .hrf import HRFDataset +from .hsi_drive import HSIDrive20Dataset from .isaid import iSAIDDataset from .isprs import ISPRSDataset from .levir import LEVIRCDDataset @@ -60,5 +61,5 @@ 'MapillaryDataset_v2', 'Albu', 'LEVIRCDDataset', 'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile', 'ConcatCDInput', 'BaseCDDataset', 'DSDLSegDataset', 'BDD100KDataset', - 'NYUDataset' + 'NYUDataset', 'HSIDrive20Dataset' ] diff --git a/mmseg/datasets/hsi_drive.py b/mmseg/datasets/hsi_drive.py new file mode 100644 index 0000000000..3d46a86629 --- /dev/null +++ b/mmseg/datasets/hsi_drive.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.datasets import BaseSegDataset +from mmseg.registry import DATASETS + +classes_exp = ('unlabelled', 'road', 'road marks', 'vegetation', + 'painted metal', 'sky', 'concrete', 'pedestrian', 'water', + 'unpainted metal', 'glass') +palette_exp = [[0, 0, 0], [77, 77, 77], [255, 255, 255], [0, 255, 0], + [255, 0, 0], [0, 0, 255], [102, 51, 0], [255, 255, 0], + [0, 207, 250], [255, 166, 0], [0, 204, 204]] + + +@DATASETS.register_module() +class HSIDrive20Dataset(BaseSegDataset): + """HSI-Drive v2.0 (https://ieeexplore.ieee.org/document/10371793), the + updated version of HSI-Drive + (https://ieeexplore.ieee.org/document/9575298), is a structured dataset for + the research and development of automated driving systems (ADS) supported + by hyperspectral imaging (HSI). It contains per-pixel manually annotated + images selected from videos recorded in real driving conditions and has + been organized according to four parameters: season, daytime, road type, + and weather conditions. + + The video sequences have been captured with a small-size 25-band VNIR + (Visible-NearlnfraRed) snapshot hyperspectral camera mounted on a driving + automobile. As a consequence, you need to modify the in_channels parameter + of your model from 3 (RGB images) to 25 (HSI images) as it is done in + configs/unet/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py + + Apart from the abovementioned articles, additional information is provided + in the website (https://ipaccess.ehu.eus/HSI-Drive/) from where you can + download the dataset and also visualize some examples of segmented videos. + """ + + METAINFO = dict(classes=classes_exp, palette=palette_exp) + + def __init__(self, + img_suffix='.npy', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/mmseg/datasets/transforms/loading.py b/mmseg/datasets/transforms/loading.py index 438b5527f0..c28937e55e 100644 --- a/mmseg/datasets/transforms/loading.py +++ b/mmseg/datasets/transforms/loading.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings +from pathlib import Path from typing import Dict, Optional, Union import mmcv @@ -702,3 +703,69 @@ def __repr__(self): f'to_float32={self.to_float32}, ' f'backend_args={self.backend_args})') return repr_str + + +@TRANSFORMS.register_module() +class LoadImageFromNpyFile(LoadImageFromFile): + """Load an image from ``results['img_path']``. + + Required Keys: + + - img_path + + Modified Keys: + + - img + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + """ + + def transform(self, results: dict) -> Optional[dict]: + """Functions to load image. + + Args: + results (dict): Result dict from + :class:`mmengine.dataset.BaseDataset`. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + filename = results['img_path'] + + try: + if Path(filename).suffix in ['.npy', '.npz']: + img = np.load(filename) + else: + if self.file_client_args is not None: + file_client = fileio.FileClient.infer_client( + self.file_client_args, filename) + img_bytes = file_client.get(filename) + else: + img_bytes = fileio.get( + filename, backend_args=self.backend_args) + img = mmcv.imfrombytes( + img_bytes, + flag=self.color_type, + backend=self.imdecode_backend) + except Exception as e: + if self.ignore_empty: + return None + else: + raise e + + # in some cases, images are not read successfully, the img would be + # `None`, refer to https://github.com/open-mmlab/mmpretrain/issues/1427 + assert img is not None, f'failed to load image: {filename}' + if self.to_float32: + img = img.astype(np.float32) + + results['img'] = img + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results diff --git a/mmseg/utils/class_names.py b/mmseg/utils/class_names.py index 5ab35f99dc..644e955966 100644 --- a/mmseg/utils/class_names.py +++ b/mmseg/utils/class_names.py @@ -473,6 +473,21 @@ def bdd100k_palette(): [0, 0, 230], [119, 11, 32]] +def hsidrive_classes(): + """HSI Drive 2.0 class names for external use.""" + return [ + 'unlabelled', 'road', 'road marks', 'vegetation', 'painted metal', + 'sky', 'concrete', 'pedestrian', 'water', 'unpainted metal', 'glass' + ] + + +def hsidrive_palette(): + """HSI Drive 2.0 palette for external use.""" + return [[0, 0, 0], [77, 77, 77], [255, 255, 255], [0, 255, 0], [255, 0, 0], + [0, 0, 255], [102, 51, 0], [255, 255, 0], [0, 207, 250], + [255, 166, 0], [0, 204, 204]] + + dataset_aliases = { 'cityscapes': ['cityscapes'], 'ade': ['ade', 'ade20k'], @@ -491,7 +506,11 @@ def bdd100k_palette(): 'lip': ['LIP', 'lip'], 'mapillary_v1': ['mapillary_v1'], 'mapillary_v2': ['mapillary_v2'], - 'bdd100k': ['bdd100k'] + 'bdd100k': ['bdd100k'], + 'hsidrive': [ + 'hsidrive', 'HSIDrive', 'HSI-Drive', 'hsidrive20', 'HSIDrive20', + 'HSI-Drive20' + ] } diff --git a/projects/hsidrive20_dataset/README.md b/projects/hsidrive20_dataset/README.md new file mode 100644 index 0000000000..7ee6e984fd --- /dev/null +++ b/projects/hsidrive20_dataset/README.md @@ -0,0 +1,34 @@ +# HSI Drive 2.0 Dataset + +Support **`HSI Drive 2.0 Dataset`** + +## Description + +Author: Jon Gutierrez + +This project implements **`HSI Drive 2.0 Dataset`** + +### Dataset preparing + +Preparing `HSI Drive 2.0 Dataset` dataset following [HSI Drive 2.0 Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#hsi-drive-2.0) + +```none +mmsegmentation/data +└── HSIDrive20 + ├── images + │ |── training [] + │ |── validation [] + │ |── test [] + └── labels + │ |── training [] + │ |── validation [] + │ |── test [] +``` + +### Training commands + +```bash +%cd mmsegmentation +!python tools/train.py projects/hsidrive20_dataset/configs/unet-s5-d16_fcn_4xb4-160k_hsidrive-208x400.py\ +--work-dir your_work_dir +``` diff --git a/projects/hsidrive20_dataset/configs/_base_/datasets/hsi_drive.py b/projects/hsidrive20_dataset/configs/_base_/datasets/hsi_drive.py new file mode 100644 index 0000000000..311426246c --- /dev/null +++ b/projects/hsidrive20_dataset/configs/_base_/datasets/hsi_drive.py @@ -0,0 +1,50 @@ +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') +] + +train_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + sampler=dict(type='InfiniteSampler', shuffle=True), + dataset=dict( + type='HSIDrive20', + data_root='data/HSIDrive20', + data_prefix=dict( + img_path='images/training', seg_map_path='annotations/training'), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type='HSIDrive20', + data_root='data/HSIDrive20', + data_prefix=dict( + img_path='images/validation', + seg_map_path='annotations/validation'), + pipeline=test_pipeline)) + +test_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type='HSIDrive20', + data_root='data/HSIDrive20', + data_prefix=dict( + img_path='images/test', seg_map_path='annotations/test'), + pipeline=test_pipeline)) + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'], ignore_index=0) +test_evaluator = val_evaluator diff --git a/projects/hsidrive20_dataset/configs/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py b/projects/hsidrive20_dataset/configs/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py new file mode 100644 index 0000000000..d5eab91747 --- /dev/null +++ b/projects/hsidrive20_dataset/configs/unet-s5-d16_fcn_4xb4-160k_hsidrive-192x384.py @@ -0,0 +1,58 @@ +_base_ = [ + '../../../configs/_base_/models/fcn_unet_s5-d16.py', + './_base_/datasets/hsi_drive.py', + '../../../configs/_base_/default_runtime.py', + '../../../configs/_base_/schedules/schedule_160k.py' +] + +custom_imports = dict( + imports=['projects.hsidrive20_dataset.mmseg.datasets.hsi_drive']) + +crop_size = (192, 384) +data_preprocessor = dict( + type='SegDataPreProcessor', + size=crop_size, + mean=None, + std=None, + bgr_to_rgb=None, + pad_val=0, + seg_pad_val=255) +model = dict( + data_preprocessor=data_preprocessor, + backbone=dict(in_channels=25), + decode_head=dict( + ignore_index=0, + num_classes=11, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0, + avg_non_ignore=True)), + auxiliary_head=dict( + ignore_index=0, + num_classes=11, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0, + avg_non_ignore=True)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomCrop', crop_size=crop_size), + dict(type='PackSegInputs') +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='RandomCrop', crop_size=crop_size), + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) +test_dataloader = dict(dataset=dict(pipeline=test_pipeline)) diff --git a/projects/hsidrive20_dataset/docs/en/user_guides/2_dataset_prepare.md b/projects/hsidrive20_dataset/docs/en/user_guides/2_dataset_prepare.md new file mode 100644 index 0000000000..1d4ac8c99c --- /dev/null +++ b/projects/hsidrive20_dataset/docs/en/user_guides/2_dataset_prepare.md @@ -0,0 +1,42 @@ +## HSI Drive 2.0 + +- You could download HSI Drive 2.0 dataset from [here](https://ipaccess.ehu.eus/HSI-Drive/#download) after just sending an email to gded@ehu.eus with the subject "download HSI-Drive". You will receive a password to uncompress the files. + +- After download, unzip by the following instructions: + + ```bash + 7z x -p"password" ./HSI_Drive_v2_0_Phyton.zip + + mv ./HSIDrive20 path_to_mmsegmentation/data + mv ./HSI_Drive_v2_0_release_notes_Python_version.md path_to_mmsegmentation/data + mv ./image_numbering.pdf path_to_mmsegmentation/data + ``` + +- After unzip, you get + +```none +mmsegmentation +├── mmseg +├── tools +├── configs +├── data +│ ├── HSIDrive20 +│ │ ├── images +│ │ │ ├── training +│ │ │ ├── validation +│ │ │ ├── test +│ │ ├── annotations +│ │ │ ├── training +│ │ │ ├── validation +│ │ │ ├── test +│ │ ├── images_MF +│ │ │ ├── training +│ │ │ ├── validation +│ │ │ ├── test +│ │ ├── RGB +│ │ ├── training_filenames.txt +│ │ ├── validation_filenames.txt +│ │ ├── test_filenames.txt +│ ├── HSI_Drive_v2_0_release_notes_Python_version.md +│ ├── image_numbering.pdf +``` diff --git a/projects/hsidrive20_dataset/docs/zh_cn/user_guides/2_dataset_prepare.md b/projects/hsidrive20_dataset/docs/zh_cn/user_guides/2_dataset_prepare.md new file mode 100644 index 0000000000..dbf704a9cf --- /dev/null +++ b/projects/hsidrive20_dataset/docs/zh_cn/user_guides/2_dataset_prepare.md @@ -0,0 +1,42 @@ +## HSI Drive 2.0 + +- 您可以从以下位置下载 HSI Drive 2.0 数据集 [here](https://ipaccess.ehu.eus/HSI-Drive/#download) 刚刚向 gded@ehu.eus 发送主题为“下载 HSI-Drive”的电子邮件后 您将收到解压缩文件的密码. + +- 下载后,按照以下说明解压: + + ```bash + 7z x -p"password" ./HSI_Drive_v2_0_Phyton.zip + + mv ./HSIDrive20 path_to_mmsegmentation/data + mv ./HSI_Drive_v2_0_release_notes_Python_version.md path_to_mmsegmentation/data + mv ./image_numbering.pdf path_to_mmsegmentation/data + ``` + +- 解压后得到: + +```none +mmsegmentation +├── mmseg +├── tools +├── configs +├── data +│ ├── HSIDrive20 +│ │ ├── images +│ │ │ ├── training +│ │ │ ├── validation +│ │ │ ├── test +│ │ ├── annotations +│ │ │ ├── training +│ │ │ ├── validation +│ │ │ ├── test +│ │ ├── images_MF +│ │ │ ├── training +│ │ │ ├── validation +│ │ │ ├── test +│ │ ├── RGB +│ │ ├── training_filenames.txt +│ │ ├── validation_filenames.txt +│ │ ├── test_filenames.txt +│ ├── HSI_Drive_v2_0_release_notes_Python_version.md +│ ├── image_numbering.pdf +``` diff --git a/projects/hsidrive20_dataset/mmseg/datasets/hsi_drive.py b/projects/hsidrive20_dataset/mmseg/datasets/hsi_drive.py new file mode 100644 index 0000000000..f8589b037b --- /dev/null +++ b/projects/hsidrive20_dataset/mmseg/datasets/hsi_drive.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.datasets import BaseSegDataset + +# from mmseg.registry import DATASETS + +classes_exp = ('unlabelled', 'road', 'road marks', 'vegetation', + 'painted metal', 'sky', 'concrete', 'pedestrian', 'water', + 'unpainted metal', 'glass') +palette_exp = [[0, 0, 0], [77, 77, 77], [255, 255, 255], [0, 255, 0], + [255, 0, 0], [0, 0, 255], [102, 51, 0], [255, 255, 0], + [0, 207, 250], [255, 166, 0], [0, 204, 204]] + + +# @DATASETS.register_module() +class HSIDrive20Dataset(BaseSegDataset): + METAINFO = dict(classes=classes_exp, palette=palette_exp) + + def __init__(self, + img_suffix='.npy', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1111_577_TC.png b/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1111_577_TC.png new file mode 100644 index 0000000000..b1301cb925 Binary files /dev/null and b/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1111_577_TC.png differ diff --git a/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1112_569_TC.png b/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1112_569_TC.png new file mode 100644 index 0000000000..4debaffcf8 Binary files /dev/null and b/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1112_569_TC.png differ diff --git a/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1113_557_TC.png b/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1113_557_TC.png new file mode 100644 index 0000000000..7e525b4f12 Binary files /dev/null and b/tests/data/pseudo_hsidrive20_dataset/annotations/test/nf1113_557_TC.png differ diff --git a/tests/data/pseudo_hsidrive20_dataset/images/test/nf1111_577_TC.npy b/tests/data/pseudo_hsidrive20_dataset/images/test/nf1111_577_TC.npy new file mode 100644 index 0000000000..850e4f0927 Binary files /dev/null and b/tests/data/pseudo_hsidrive20_dataset/images/test/nf1111_577_TC.npy differ diff --git a/tests/data/pseudo_hsidrive20_dataset/images/test/nf1112_569_TC.npy b/tests/data/pseudo_hsidrive20_dataset/images/test/nf1112_569_TC.npy new file mode 100644 index 0000000000..6482bbb7ba Binary files /dev/null and b/tests/data/pseudo_hsidrive20_dataset/images/test/nf1112_569_TC.npy differ diff --git a/tests/data/pseudo_hsidrive20_dataset/images/test/nf1113_557_TC.npy b/tests/data/pseudo_hsidrive20_dataset/images/test/nf1113_557_TC.npy new file mode 100644 index 0000000000..54f221afc7 Binary files /dev/null and b/tests/data/pseudo_hsidrive20_dataset/images/test/nf1113_557_TC.npy differ