From 5a9cfa919384d9f17ec03f930f38fad60fa7fe9f Mon Sep 17 00:00:00 2001 From: masaaki Date: Sun, 25 Jun 2023 12:57:18 +0800 Subject: [PATCH] [Project] Medical dataset: Kvasir seg aliyun (#2678) --- .../endoscopy/kvasir_seg_aliyun/README.md | 145 ++++++++++++++++++ ...16-0.0001-20k_kvasir-seg-aliyun-512x512.py | 17 ++ ...b16-0.001-20k_kvasir-seg-aliyun-512x512.py | 17 ++ ...xb16-0.01-20k_kvasir-seg-aliyun-512x512.py | 17 ++ ...r-sigmoid-20k_kvasir-seg-aliyun-512x512.py | 18 +++ .../configs/kvasir-seg-aliyun_512x512.py | 42 +++++ .../datasets/kvasir-seg-aliyun_dataset.py | 30 ++++ .../tools/prepare_dataset.py | 86 +++++++++++ 8 files changed, 372 insertions(+) create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/README.md create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_kvasir-seg-aliyun-512x512.py create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_kvasir-seg-aliyun-512x512.py create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_kvasir-seg-aliyun-512x512.py create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_kvasir-seg-aliyun-512x512.py create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/kvasir-seg-aliyun_512x512.py create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/datasets/kvasir-seg-aliyun_dataset.py create mode 100644 projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/tools/prepare_dataset.py diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/README.md b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/README.md new file mode 100644 index 0000000000..80eb00f51b --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/README.md @@ -0,0 +1,145 @@ +# Kvasir-SEG Segmented Polyp Dataset from Aliyun (Kvasir SEG Aliyun) + +## Description + +This project supports **`Kvasir-SEG Segmented Polyp Dataset from Aliyun (Kvasir SEG Aliyun) `**, which can be downloaded from [here](https://tianchi.aliyun.com/dataset/84385). + +### Dataset Overview + +Colorectal cancer is the second most common cancer type among women and third most common among men. Polyps are precursors to colorectal cancer and therefore important to detect and remove at an early stage. Polyps are found in nearly half of the individuals at age 50 that undergo a colonoscopy screening, and their frequency increase with age.Polyps are abnormal tissue growth from the mucous membrane, which is lining the inside of the GI tract, and can sometimes be cancerous. Colonoscopy is the gold standard for detection and assessment of these polyps with subsequent biopsy and removal of the polyps. Early disease detection has a huge impact on survival from colorectal cancer. Increasing the detection of polyps has been shown to decrease risk of colorectal cancer. Thus, automatic detection of more polyps at an early stage can play a crucial role in prevention and survival from colorectal cancer. + +The Kvasir-SEG dataset is based on the previous Kvasir dataset, which is the first multi-class dataset for gastrointestinal (GI) tract disease detection and classification. It contains annotated polyp images and their corresponding masks. The pixels depicting polyp tissue, the ROI, are represented by the foreground (white mask), while the background (in black) does not contain positive pixels. These images were collected and verified by experienced gastroenterologists from Vestre Viken Health Trust in Norway. The classes include anatomical landmarks, pathological findings and endoscopic procedures. + +### Information Statistics + +| Dataset Name | Anatomical Region | Task Type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License | +| ------------------------------------------------------ | ----------------- | ------------ | --------- | ------------ | --------------------- | ---------------------- | ------------ | --------------------------------------------------------- | +| [kvasir-seg](https://tianchi.aliyun.com/dataset/84385) | abdomen | segmentation | endoscopy | 2 | 1000/-/- | yes/-/- | 2020 | [CC-BY 4.0](https://creativecommons.org/licenses/by/4.0/) | + +| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test | +| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: | +| background | 1000 | 84.72 | - | - | - | - | +| polyp | 1000 | 15.28 | - | - | - | - | + +Note: + +- `Pct` means percentage of pixels in this category in all pixels. + +### Visualization + +![kvasir_seg_aliyun](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/endoscopy_images/kvasir_seg_aliyun/kvasir_seg_aliyun_dataset.png?raw=true) + +### Dataset Citation + +``` +@inproceedings{jha2020kvasir, + title={Kvasir-seg: A segmented polyp dataset}, + author={Jha, Debesh and Smedsrud, Pia H and Riegler, Michael A and Halvorsen, P{\aa}l and Lange, Thomas de and Johansen, Dag and Johansen, H{\aa}vard D}, + booktitle={International Conference on Multimedia Modeling}, + pages={451--462}, + year={2020}, + organization={Springer} + } +``` + +### Prerequisites + +- Python v3.8 +- PyTorch v1.10.0 +- pillow(PIL) v9.3.0 +- scikit-learn(sklearn) v1.2.0 +- [MIM](https://github.com/open-mmlab/mim) v0.3.4 +- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4 +- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher +- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5 + +All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `kvasir_seg_aliyun/` root directory, run the following line to add the current directory to `PYTHONPATH`: + +```shell +export PYTHONPATH=`pwd`:$PYTHONPATH +``` + +### Dataset Preparing + +- download dataset from [here](https://tianchi.aliyun.com/dataset/84385) and decompression data to path 'data/.'. +- run script `"python tools/prepare_dataset.py"` to format data and change folder structure as below. +- run script `"python ../../tools/split_seg_dataset.py"` to split dataset and generate `train.txt`, `val.txt` and `test.txt`. If the label of official validation set and test set cannot be obtained, we generate `train.txt` and `val.txt` from the training set randomly. + +```none + mmsegmentation + ├── mmseg + ├── projects + │ ├── medical + │ │ ├── 2d_image + │ │ │ ├── endoscopy + │ │ │ │ ├── kvasir_seg_aliyun + │ │ │ │ │ ├── configs + │ │ │ │ │ ├── datasets + │ │ │ │ │ ├── tools + │ │ │ │ │ ├── data + │ │ │ │ │ │ ├── train.txt + │ │ │ │ │ │ ├── val.txt + │ │ │ │ │ │ ├── images + │ │ │ │ │ │ │ ├── train + │ │ │ │ | │ │ │ ├── xxx.png + │ │ │ │ | │ │ │ ├── ... + │ │ │ │ | │ │ │ └── xxx.png + │ │ │ │ │ │ ├── masks + │ │ │ │ │ │ │ ├── train + │ │ │ │ | │ │ │ ├── xxx.png + │ │ │ │ | │ │ │ ├── ... + │ │ │ │ | │ │ │ └── xxx.png +``` + +### Divided Dataset Information + +***Note: The table information below is divided by ourselves.*** + +| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test | +| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: | +| background | 800 | 84.66 | 200 | 84.94 | - | - | +| polyp | 800 | 15.34 | 200 | 15.06 | - | - | + +### Training commands + +To train models on a single server with one GPU. (default) + +```shell +mim train mmseg ./configs/${CONFIG_FILE} +``` + +### Testing commands + +To test models on a single server with one GPU. (default) + +```shell +mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH} +``` + + + +## Checklist + +- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`. + + - [x] Finish the code + - [x] Basic docstrings & proper citation + - [ ] Test-time correctness + - [x] A full README + +- [ ] Milestone 2: Indicates a successful model implementation. + + - [ ] Training-time correctness + +- [ ] Milestone 3: Good to be a part of our core package! + + - [ ] Type hints and docstrings + - [ ] Unit tests + - [ ] Code polishing + - [ ] Metafile.yml + +- [ ] Move your modules into the core package following the codebase's file hierarchy structure. + +- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure. diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_kvasir-seg-aliyun-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_kvasir-seg-aliyun-512x512.py new file mode 100644 index 0000000000..b59db95232 --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_kvasir-seg-aliyun-512x512.py @@ -0,0 +1,17 @@ +_base_ = [ + 'mmseg::_base_/models/fcn_unet_s5-d16.py', + './kvasir-seg-aliyun_512x512.py', 'mmseg::_base_/default_runtime.py', + 'mmseg::_base_/schedules/schedule_20k.py' +] +custom_imports = dict(imports='datasets.kvasir-seg-aliyun_dataset') +img_scale = (512, 512) +data_preprocessor = dict(size=img_scale) +optimizer = dict(lr=0.0001) +optim_wrapper = dict(optimizer=optimizer) +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict(num_classes=2), + auxiliary_head=None, + test_cfg=dict(mode='whole', _delete_=True)) +vis_backends = None +visualizer = dict(vis_backends=vis_backends) diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_kvasir-seg-aliyun-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_kvasir-seg-aliyun-512x512.py new file mode 100644 index 0000000000..6c526680cd --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_kvasir-seg-aliyun-512x512.py @@ -0,0 +1,17 @@ +_base_ = [ + 'mmseg::_base_/models/fcn_unet_s5-d16.py', + './kvasir-seg-aliyun_512x512.py', 'mmseg::_base_/default_runtime.py', + 'mmseg::_base_/schedules/schedule_20k.py' +] +custom_imports = dict(imports='datasets.kvasir-seg-aliyun_dataset') +img_scale = (512, 512) +data_preprocessor = dict(size=img_scale) +optimizer = dict(lr=0.001) +optim_wrapper = dict(optimizer=optimizer) +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict(num_classes=2), + auxiliary_head=None, + test_cfg=dict(mode='whole', _delete_=True)) +vis_backends = None +visualizer = dict(vis_backends=vis_backends) diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_kvasir-seg-aliyun-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_kvasir-seg-aliyun-512x512.py new file mode 100644 index 0000000000..a192a5bd24 --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_kvasir-seg-aliyun-512x512.py @@ -0,0 +1,17 @@ +_base_ = [ + 'mmseg::_base_/models/fcn_unet_s5-d16.py', + './kvasir-seg-aliyun_512x512.py', 'mmseg::_base_/default_runtime.py', + 'mmseg::_base_/schedules/schedule_20k.py' +] +custom_imports = dict(imports='datasets.kvasir-seg-aliyun_dataset') +img_scale = (512, 512) +data_preprocessor = dict(size=img_scale) +optimizer = dict(lr=0.01) +optim_wrapper = dict(optimizer=optimizer) +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict(num_classes=2), + auxiliary_head=None, + test_cfg=dict(mode='whole', _delete_=True)) +vis_backends = None +visualizer = dict(vis_backends=vis_backends) diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_kvasir-seg-aliyun-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_kvasir-seg-aliyun-512x512.py new file mode 100644 index 0000000000..5325e1f080 --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_kvasir-seg-aliyun-512x512.py @@ -0,0 +1,18 @@ +_base_ = [ + 'mmseg::_base_/models/fcn_unet_s5-d16.py', + './kvasir-seg-aliyun_512x512.py', 'mmseg::_base_/default_runtime.py', + 'mmseg::_base_/schedules/schedule_20k.py' +] +custom_imports = dict(imports='datasets.kvasir-seg-aliyun_dataset') +img_scale = (512, 512) +data_preprocessor = dict(size=img_scale) +optimizer = dict(lr=0.01) +optim_wrapper = dict(optimizer=optimizer) +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict( + num_classes=2, loss_decode=dict(use_sigmoid=True), out_channels=1), + auxiliary_head=None, + test_cfg=dict(mode='whole', _delete_=True)) +vis_backends = None +visualizer = dict(vis_backends=vis_backends) diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/kvasir-seg-aliyun_512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/kvasir-seg-aliyun_512x512.py new file mode 100644 index 0000000000..5f86880467 --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/kvasir-seg-aliyun_512x512.py @@ -0,0 +1,42 @@ +dataset_type = 'KvasirSEGAliyunDataset' +data_root = 'data/' +img_scale = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', scale=img_scale, keep_ratio=False), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='PackSegInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=img_scale, keep_ratio=False), + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') +] +train_dataloader = dict( + batch_size=16, + num_workers=4, + persistent_workers=True, + sampler=dict(type='InfiniteSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='train.txt', + data_prefix=dict(img_path='images/', seg_map_path='masks/'), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='val.txt', + data_prefix=dict(img_path='images/', seg_map_path='masks/'), + pipeline=test_pipeline)) +test_dataloader = val_dataloader +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice']) +test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice']) diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/datasets/kvasir-seg-aliyun_dataset.py b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/datasets/kvasir-seg-aliyun_dataset.py new file mode 100644 index 0000000000..198caf07bc --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/datasets/kvasir-seg-aliyun_dataset.py @@ -0,0 +1,30 @@ +from mmseg.datasets import BaseSegDataset +from mmseg.registry import DATASETS + + +@DATASETS.register_module() +class KvasirSEGAliyunDataset(BaseSegDataset): + """KvasirSEGAliyunDataset dataset. + + In segmentation map annotation for KvasirSEGAliyunDataset, + 0 stands for background,which is included in 2 categories. + ``reduce_zero_label`` is fixed to False. The ``img_suffix`` + is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'. + Args: + img_suffix (str): Suffix of images. Default: '.png' + seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' + reduce_zero_label (bool): Whether to mark label zero as ignored. + Default to False.. + """ + METAINFO = dict(classes=('background', 'polyp')) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/tools/prepare_dataset.py b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/tools/prepare_dataset.py new file mode 100644 index 0000000000..b230e7fef5 --- /dev/null +++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/tools/prepare_dataset.py @@ -0,0 +1,86 @@ +import glob +import os + +import numpy as np +from PIL import Image + +root_path = 'data/' +img_suffix = '.jpg' +seg_map_suffix = '.jpg' +save_img_suffix = '.png' +save_seg_map_suffix = '.png' +tgt_img_dir = os.path.join(root_path, 'images/train/') +tgt_mask_dir = os.path.join(root_path, 'masks/train/') +os.system('mkdir -p ' + tgt_img_dir) +os.system('mkdir -p ' + tgt_mask_dir) + + +def filter_suffix_recursive(src_dir, suffix): + # filter out file names and paths in source directory + suffix = '.' + suffix if '.' not in suffix else suffix + file_paths = glob.glob( + os.path.join(src_dir, '**', '*' + suffix), recursive=True) + file_names = [_.split('/')[-1] for _ in file_paths] + return sorted(file_paths), sorted(file_names) + + +def convert_label(img, convert_dict): + arr = np.zeros_like(img, dtype=np.uint8) + for c, i in convert_dict.items(): + arr[img == c] = i + return arr + + +def convert_pics_into_pngs(src_dir, tgt_dir, suffix, convert='RGB'): + if not os.path.exists(tgt_dir): + os.makedirs(tgt_dir) + + src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix) + for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)): + tgt_name = src_name.replace(suffix, save_img_suffix) + tgt_path = os.path.join(tgt_dir, tgt_name) + num = len(src_paths) + img = np.array(Image.open(src_path)) + if len(img.shape) == 2: + pil = Image.fromarray(img).convert(convert) + elif len(img.shape) == 3: + pil = Image.fromarray(img) + else: + raise ValueError('Input image not 2D/3D: ', img.shape) + + pil.save(tgt_path) + print(f'processed {i+1}/{num}.') + + +def convert_label_pics_into_pngs(src_dir, + tgt_dir, + suffix, + convert_dict={ + 0: 0, + 255: 1 + }): + if not os.path.exists(tgt_dir): + os.makedirs(tgt_dir) + + src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix) + num = len(src_paths) + for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)): + tgt_name = src_name.replace(suffix, save_seg_map_suffix) + tgt_path = os.path.join(tgt_dir, tgt_name) + + img = np.array(Image.open(src_path).convert('L')) + img = convert_label(img, convert_dict) + Image.fromarray(img).save(tgt_path) + print(f'processed {i+1}/{num}.') + + +if __name__ == '__main__': + convert_pics_into_pngs( + os.path.join(root_path, 'Kvasir-SEG/images'), + tgt_img_dir, + suffix=img_suffix) + + convert_label_pics_into_pngs( + os.path.join(root_path, 'Kvasir-SEG/masks'), + tgt_mask_dir, + suffix=seg_map_suffix)