From 2ea4784012d790ba012df6ce8dcebe766c8e73f5 Mon Sep 17 00:00:00 2001 From: masaaki Date: Sun, 25 Jun 2023 11:50:46 +0800 Subject: [PATCH] [Project] Medical semantic seg dataset: Kvasir seg (#2677) --- .../2d_image/endoscopy/kvasir_seg/README.md | 145 ++++++++++++++++++ ...moid}_1xb16-0.01-20k_kvasir-seg-512x512.py | 18 +++ ...net_1xb16-0.0001-20k_kvasir-seg-512x512.py | 17 ++ ...unet_1xb16-0.001-20k_kvasir-seg-512x512.py | 17 ++ ..._unet_1xb16-0.01-20k_kvasir-seg-512x512.py | 17 ++ .../kvasir_seg/configs/kvasir-seg_512x512.py | 42 +++++ .../kvasir_seg/datasets/kvasir-seg_dataset.py | 30 ++++ .../kvasir_seg/tools/prepare_dataset.py | 87 +++++++++++ 8 files changed, 373 insertions(+) create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg/README.md create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_kvasir-seg-512x512.py create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_kvasir-seg-512x512.py create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_kvasir-seg-512x512.py create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_kvasir-seg-512x512.py create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg/configs/kvasir-seg_512x512.py create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg/datasets/kvasir-seg_dataset.py create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg/tools/prepare_dataset.py diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/README.md b/projects/medical/2d_image/endoscopy/kvasir_seg/README.md new file mode 100644 index 0000000000..ea597bc440 --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg/README.md @@ -0,0 +1,145 @@ +# Kvasir-Sessile Dataset (Kvasir SEG) + +## Description + +This project supports **`Kvasir-Sessile Dataset (Kvasir SEG) `**, which can be downloaded from [here](https://opendatalab.com/Kvasir-Sessile_dataset). + +## Dataset Overview + +The Kvasir-SEG dataset contains polyp images and their corresponding ground truth from the Kvasir Dataset v2. The resolution of the images contained in Kvasir-SEG varies from 332x487 to 1920x1072 pixels. + + + +### Information Statistics + +| Dataset Name | Anatomical Region | Task Type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License | +| ------------------------------------------------------------- | ----------------- | ------------ | --------- | ------------ | --------------------- | ---------------------- | ------------ | --------------------------------------------------------- | +| [Kvarsir-SEG](https://opendatalab.com/Kvasir-Sessile_dataset) | abdomen | segmentation | endoscopy | 2 | 196/-/- | yes/-/- | 2020 | [CC-BY 4.0](https://creativecommons.org/licenses/by/4.0/) | + +| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test | +| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: | +| background | 196 | 92.31 | - | - | - | - | +| polyp | 196 | 7.69 | - | - | - | - | + +Note: + +- `Pct` means percentage of pixels in this category in all pixels. + +### Visualization + +![kvasir-seg](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/endoscopy_images/kvasir_seg/kvasir_seg_dataset.png?raw=true) + +### Dataset Citation + +``` +@inproceedings{jha2020kvasir, + title={Kvasir-seg: A segmented polyp dataset}, + author={Jha, Debesh and Smedsrud, Pia H and Riegler, Michael A and Halvorsen, P{\aa}l and Lange, Thomas de and Johansen, Dag and Johansen, H{\aa}vard D}, + booktitle={International Conference on Multimedia Modeling}, + pages={451--462}, + year={2020}, + organization={Springer} + } +``` + +### Prerequisites + +- Python v3.8 +- PyTorch v1.10.0 +- pillow(PIL) v9.3.0 +- scikit-learn(sklearn) v1.2.0 +- [MIM](https://github.com/open-mmlab/mim) v0.3.4 +- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4 +- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher +- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5 + +All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `kvasir_seg/` root directory, run the following line to add the current directory to `PYTHONPATH`: + +```shell +export PYTHONPATH=`pwd`:$PYTHONPATH +``` + +### Dataset Preparing + +- download dataset from [here](https://opendatalab.com/Kvasir-Sessile_dataset) and decompress data to path `'data/'`. +- run script `"python tools/prepare_dataset.py"` to format data and change folder structure as below. +- run script `"python ../../tools/split_seg_dataset.py"` to split dataset and generate `train.txt`, `val.txt` and `test.txt`. If the label of official validation set and test set cannot be obtained, we generate `train.txt` and `val.txt` from the training set randomly. + +```none + mmsegmentation + ├── mmseg + ├── projects + │ ├── medical + │ │ ├── 2d_image + │ │ │ ├── endoscopy + │ │ │ │ ├── kvasir_seg + │ │ │ │ │ ├── configs + │ │ │ │ │ ├── datasets + │ │ │ │ │ ├── tools + │ │ │ │ │ ├── data + │ │ │ │ │ │ ├── train.txt + │ │ │ │ │ │ ├── val.txt + │ │ │ │ │ │ ├── images + │ │ │ │ │ │ │ ├── train + │ │ │ │ | │ │ │ ├── xxx.png + │ │ │ │ | │ │ │ ├── ... + │ │ │ │ | │ │ │ └── xxx.png + │ │ │ │ │ │ ├── masks + │ │ │ │ │ │ │ ├── train + │ │ │ │ | │ │ │ ├── xxx.png + │ │ │ │ | │ │ │ ├── ... + │ │ │ │ | │ │ │ └── xxx.png +``` + +### Divided Dataset Information + +***Note: The table information below is divided by ourselves.*** + +| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test | +| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: | +| background | 156 | 92.28 | 40 | 92.41 | - | - | +| polyp | 156 | 7.72 | 40 | 7.59 | - | - | + +### Training commands + +To train models on a single server with one GPU. (default) + +```shell +mim train mmseg .configs/${CONFIG_FILE} +``` + +### Testing commands + +To test models on a single server with one GPU. (default) + +```shell +mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH} +``` + + + +## Checklist + +- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`. + + - [x] Finish the code + - [x] Basic docstrings & proper citation + - [ ] Test-time correctness + - [x] A full README + +- [x] Milestone 2: Indicates a successful model implementation. + + - [x] Training-time correctness + +- [ ] Milestone 3: Good to be a part of our core package! + + - [ ] Type hints and docstrings + - [ ] Unit tests + - [ ] Code polishing + - [ ] Metafile.yml + +- [ ] Move your modules into the core package following the codebase's file hierarchy structure. + +- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure. diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_kvasir-seg-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_kvasir-seg-512x512.py new file mode 100644 index 0000000000..145d5a7a17 --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_kvasir-seg-512x512.py @@ -0,0 +1,18 @@ +_base_ = [ + 'mmseg::_base_/models/fcn_unet_s5-d16.py', './kvasir-seg_512x512.py', + 'mmseg::_base_/default_runtime.py', + 'mmseg::_base_/schedules/schedule_20k.py' +] +custom_imports = dict(imports='datasets.kvasir-seg_dataset') +img_scale = (512, 512) +data_preprocessor = dict(size=img_scale) +optimizer = dict(lr=0.01) +optim_wrapper = dict(optimizer=optimizer) +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict( + num_classes=2, loss_decode=dict(use_sigmoid=True), out_channels=1), + auxiliary_head=None, + test_cfg=dict(mode='whole', _delete_=True)) +vis_backends = None +visualizer = dict(vis_backends=vis_backends) diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_kvasir-seg-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_kvasir-seg-512x512.py new file mode 100644 index 0000000000..3ea05c5109 --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_kvasir-seg-512x512.py @@ -0,0 +1,17 @@ +_base_ = [ + 'mmseg::_base_/models/fcn_unet_s5-d16.py', './kvasir-seg_512x512.py', + 'mmseg::_base_/default_runtime.py', + 'mmseg::_base_/schedules/schedule_20k.py' +] +custom_imports = dict(imports='datasets.kvasir-seg_dataset') +img_scale = (512, 512) +data_preprocessor = dict(size=img_scale) +optimizer = dict(lr=0.0001) +optim_wrapper = dict(optimizer=optimizer) +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict(num_classes=2), + auxiliary_head=None, + test_cfg=dict(mode='whole', _delete_=True)) +vis_backends = None +visualizer = dict(vis_backends=vis_backends) diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_kvasir-seg-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_kvasir-seg-512x512.py new file mode 100644 index 0000000000..7e064a716a --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_kvasir-seg-512x512.py @@ -0,0 +1,17 @@ +_base_ = [ + 'mmseg::_base_/models/fcn_unet_s5-d16.py', './kvasir-seg_512x512.py', + 'mmseg::_base_/default_runtime.py', + 'mmseg::_base_/schedules/schedule_20k.py' +] +custom_imports = dict(imports='datasets.kvasir-seg_dataset') +img_scale = (512, 512) +data_preprocessor = dict(size=img_scale) +optimizer = dict(lr=0.001) +optim_wrapper = dict(optimizer=optimizer) +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict(num_classes=2), + auxiliary_head=None, + test_cfg=dict(mode='whole', _delete_=True)) +vis_backends = None +visualizer = dict(vis_backends=vis_backends) diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_kvasir-seg-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_kvasir-seg-512x512.py new file mode 100644 index 0000000000..0fc1d6e99d --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_kvasir-seg-512x512.py @@ -0,0 +1,17 @@ +_base_ = [ + 'mmseg::_base_/models/fcn_unet_s5-d16.py', './kvasir-seg_512x512.py', + 'mmseg::_base_/default_runtime.py', + 'mmseg::_base_/schedules/schedule_20k.py' +] +custom_imports = dict(imports='datasets.kvasir-seg_dataset') +img_scale = (512, 512) +data_preprocessor = dict(size=img_scale) +optimizer = dict(lr=0.01) +optim_wrapper = dict(optimizer=optimizer) +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict(num_classes=2), + auxiliary_head=None, + test_cfg=dict(mode='whole', _delete_=True)) +vis_backends = None +visualizer = dict(vis_backends=vis_backends) diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/configs/kvasir-seg_512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/kvasir-seg_512x512.py new file mode 100644 index 0000000000..e8b2467f8c --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/kvasir-seg_512x512.py @@ -0,0 +1,42 @@ +dataset_type = 'KvasirSEGDataset' +data_root = 'data/' +img_scale = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', scale=img_scale, keep_ratio=False), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='PackSegInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=img_scale, keep_ratio=False), + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') +] +train_dataloader = dict( + batch_size=16, + num_workers=4, + persistent_workers=True, + sampler=dict(type='InfiniteSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='train.txt', + data_prefix=dict(img_path='images/', seg_map_path='masks/'), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='val.txt', + data_prefix=dict(img_path='images/', seg_map_path='masks/'), + pipeline=test_pipeline)) +test_dataloader = val_dataloader +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice']) +test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice']) diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/datasets/kvasir-seg_dataset.py b/projects/medical/2d_image/endoscopy/kvasir_seg/datasets/kvasir-seg_dataset.py new file mode 100644 index 0000000000..9d601328eb --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg/datasets/kvasir-seg_dataset.py @@ -0,0 +1,30 @@ +from mmseg.datasets import BaseSegDataset +from mmseg.registry import DATASETS + + +@DATASETS.register_module() +class KvasirSEGDataset(BaseSegDataset): + """KvasirSEGDataset dataset. + + In segmentation map annotation for KvasirSEGDataset, 0 stands for + background, which is included in 2 categories. + ``reduce_zero_label`` is fixed to False. The ``img_suffix`` is + fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'. + Args: + img_suffix (str): Suffix of images. Default: '.png' + seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' + reduce_zero_label (bool): Whether to mark label zero as ignored. + Default to False.. + """ + METAINFO = dict(classes=('background', 'polyp')) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/tools/prepare_dataset.py b/projects/medical/2d_image/endoscopy/kvasir_seg/tools/prepare_dataset.py new file mode 100644 index 0000000000..74c43e9635 --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg/tools/prepare_dataset.py @@ -0,0 +1,87 @@ +import glob +import os + +import numpy as np +from PIL import Image + +root_path = 'data/' +img_suffix = '.jpg' +seg_map_suffix = '.jpg' +save_img_suffix = '.png' +save_seg_map_suffix = '.png' +tgt_img_dir = os.path.join(root_path, 'images/train/') +tgt_mask_dir = os.path.join(root_path, 'masks/train/') +os.system('mkdir -p ' + tgt_img_dir) +os.system('mkdir -p ' + tgt_mask_dir) + + +def filter_suffix_recursive(src_dir, suffix): + # filter out file names and paths in source directory + suffix = '.' + suffix if '.' not in suffix else suffix + file_paths = glob.glob( + os.path.join(src_dir, '**', '*' + suffix), recursive=True) + file_names = [_.split('/')[-1] for _ in file_paths] + return sorted(file_paths), sorted(file_names) + + +def convert_label(img, convert_dict): + arr = np.zeros_like(img, dtype=np.uint8) + for c, i in convert_dict.items(): + arr[img == c] = i + return arr + + +def convert_pics_into_pngs(src_dir, tgt_dir, suffix, convert='RGB'): + if not os.path.exists(tgt_dir): + os.makedirs(tgt_dir) + + src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix) + for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)): + tgt_name = src_name.replace(suffix, save_img_suffix) + tgt_path = os.path.join(tgt_dir, tgt_name) + num = len(src_paths) + img = np.array(Image.open(src_path)) + if len(img.shape) == 2: + pil = Image.fromarray(img).convert(convert) + elif len(img.shape) == 3: + pil = Image.fromarray(img) + else: + raise ValueError('Input image not 2D/3D: ', img.shape) + + pil.save(tgt_path) + print(f'processed {i+1}/{num}.') + + +def convert_label_pics_into_pngs(src_dir, + tgt_dir, + suffix, + convert_dict={ + 0: 0, + 255: 1 + }): + if not os.path.exists(tgt_dir): + os.makedirs(tgt_dir) + + src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix) + num = len(src_paths) + for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)): + tgt_name = src_name.replace(suffix, save_seg_map_suffix) + tgt_path = os.path.join(tgt_dir, tgt_name) + + img = np.array(Image.open(src_path)) + img = convert_label(img, convert_dict) + Image.fromarray(img).save(tgt_path) + print(f'processed {i+1}/{num}.') + + +if __name__ == '__main__': + + convert_pics_into_pngs( + os.path.join(root_path, 'sessile-main-Kvasir-SEG/images'), + tgt_img_dir, + suffix=img_suffix) + + convert_label_pics_into_pngs( + os.path.join(root_path, 'sessile-main-Kvasir-SEG/masks'), + tgt_mask_dir, + suffix=seg_map_suffix)