From e4db1f20c918118216255756edf3a04874f71d85 Mon Sep 17 00:00:00 2001 From: masaaki Date: Sun, 25 Jun 2023 15:57:40 +0800 Subject: [PATCH] [Project] Medical semantic seg dataset: Pannuke (#2683) --- .../2d_image/histopathology/pannuke/README.md | 152 ++++++++++++++++++ ...16-0.01-20k_bactteria-detection-512x512.py | 18 +++ ...6_unet_1xb16-0.0001-20k_pannuke-512x512.py | 17 ++ ...16_unet_1xb16-0.001-20k_pannuke-512x512.py | 17 ++ ...d16_unet_1xb16-0.01-20k_pannuke-512x512.py | 17 ++ .../pannuke/configs/pannuke_512x512.py | 42 +++++ .../pannuke/datasets/pannuke_dataset.py | 33 ++++ .../pannuke/tools/prepare_dataset.py | 49 ++++++ 8 files changed, 345 insertions(+) create mode 100644 projects/medical/2d_image/histopathology/pannuke/README.md create mode 100644 projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_bactteria-detection-512x512.py create mode 100644 projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_pannuke-512x512.py create mode 100644 projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_pannuke-512x512.py create mode 100644 projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_pannuke-512x512.py create mode 100644 projects/medical/2d_image/histopathology/pannuke/configs/pannuke_512x512.py create mode 100644 projects/medical/2d_image/histopathology/pannuke/datasets/pannuke_dataset.py create mode 100644 projects/medical/2d_image/histopathology/pannuke/tools/prepare_dataset.py diff --git a/projects/medical/2d_image/histopathology/pannuke/README.md b/projects/medical/2d_image/histopathology/pannuke/README.md new file mode 100644 index 0000000000..701adace92 --- /dev/null +++ b/projects/medical/2d_image/histopathology/pannuke/README.md @@ -0,0 +1,152 @@ +# Pan-Cancer Histology Dataset for Nuclei Instance Segmentation and Classification (PanNuke) + +## Description + +This project supports **`Pan-Cancer Histology Dataset for Nuclei Instance Segmentation and Classification (PanNuke)`**, which can be downloaded from [here](https://academictorrents.com/details/99f2c7b57b95500711e33f2ee4d14c9fd7c7366c). + +### Dataset Overview + +Semi automatically generated nuclei instance segmentation and classification dataset with exhaustive nuclei labels across 19 different tissue types. The dataset consists of 481 visual fields, of which 312 are randomly sampled from more than 20K whole slide images at different magnifications, from multiple data sources. In total the dataset contains 205,343 labeled nuclei, each with an instance segmentation mask. Models trained on pannuke can aid in whole slide image tissue type segmentation, and generalise to new tissues. PanNuke demonstrates one of the first successfully semi-automatically generated datasets. + +### Statistic Information + +| Dataset Name | Anatomical Region | Task Type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License | +| ---------------------------------------------------------------------------------------- | ----------------- | ------------ | -------------- | ------------ | --------------------- | ---------------------- | ------------ | --------------------------------------------------------------- | +| [Pannuke](https://academictorrents.com/details/99f2c7b57b95500711e33f2ee4d14c9fd7c7366c) | full_body | segmentation | histopathology | 6 | 7901/-/- | yes/-/- | 2019 | [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-sa/4.0/) | + +| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test | +| :-----------------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: | +| background | 7901 | 83.32 | - | - | - | - | +| neoplastic | 4190 | 8.64 | - | - | - | - | +| non-neoplastic epithelial | 4126 | 1.77 | - | - | - | - | +| inflammatory | 6137 | 3.73 | - | - | - | - | +| connective | 232 | 0.07 | - | - | - | - | +| dead | 1528 | 2.47 | - | - | - | - | + +Note: + +- `Pct` means percentage of pixels in this category in all pixels. + +### Visualization + +![pannuke](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/histopathology/pannuke/pannuke_dataset.png?raw=true) + +### Dataset Citation + +``` +@inproceedings{gamper2019pannuke, + title={PanNuke: an open pan-cancer histology dataset for nuclei instance segmentation and classification}, + author={Gamper, Jevgenij and Koohbanani, Navid Alemi and Benet, Ksenija and Khuram, Ali and Rajpoot, Nasir}, + booktitle={European Congress on Digital Pathology}, + pages={11--19}, + year={2019}, +} +``` + +### Prerequisites + +- Python v3.8 +- PyTorch v1.10.0 +- pillow(PIL) v9.3.0 9.3.0 +- scikit-learn(sklearn) v1.2.0 1.2.0 +- [MIM](https://github.com/open-mmlab/mim) v0.3.4 +- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4 +- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher +- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5 + +All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `pannuke/` root directory, run the following line to add the current directory to `PYTHONPATH`: + +```shell +export PYTHONPATH=`pwd`:$PYTHONPATH +``` + +### Dataset Preparing + +- download dataset from [here](https://academictorrents.com/details/99f2c7b57b95500711e33f2ee4d14c9fd7c7366c) and decompress data to path `'data/'`. +- run script `"python tools/prepare_dataset.py"` to format data and change folder structure as below. +- run script `"python ../../tools/split_seg_dataset.py"` to split dataset and generate `train.txt`, `val.txt` and `test.txt`. If the label of official validation set and test set cannot be obtained, we generate `train.txt` and `val.txt` from the training set randomly. + +```none + mmsegmentation + ├── mmseg + ├── projects + │ ├── medical + │ │ ├── 2d_image + │ │ │ ├── histopathology + │ │ │ │ ├── pannuke + │ │ │ │ │ ├── configs + │ │ │ │ │ ├── datasets + │ │ │ │ │ ├── tools + │ │ │ │ │ ├── data + │ │ │ │ │ │ ├── train.txt + │ │ │ │ │ │ ├── val.txt + │ │ │ │ │ │ ├── images + │ │ │ │ │ │ │ ├── train + │ │ │ │ | │ │ │ ├── xxx.png + │ │ │ │ | │ │ │ ├── ... + │ │ │ │ | │ │ │ └── xxx.png + │ │ │ │ │ │ ├── masks + │ │ │ │ │ │ │ ├── train + │ │ │ │ | │ │ │ ├── xxx.png + │ │ │ │ | │ │ │ ├── ... + │ │ │ │ | │ │ │ └── xxx.png +``` + +### Divided Dataset Information + +***Note: The table information below is divided by ourselves.*** + +| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test | +| :-----------------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: | +| background | 6320 | 83.38 | 1581 | 83.1 | - | - | +| neoplastic | 3339 | 8.55 | 851 | 9.0 | - | - | +| non-neoplastic epithelial | 3293 | 1.77 | 833 | 1.76 | - | - | +| inflammatory | 4914 | 3.72 | 1223 | 3.76 | - | - | +| connective | 170 | 0.06 | 62 | 0.09 | - | - | +| dead | 1235 | 2.51 | 293 | 2.29 | - | - | + +### Training commands + +To train models on a single server with one GPU. (default) + +```shell +mim train mmseg ./configs/${CONFIG_FILE} +``` + +### Testing commands + +To test models on a single server with one GPU. (default) + +```shell +mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH} +``` + + + +12x512 | 0.0001 | 58.87 | 62.42 | [config](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_pannuke-512x512.py) | + +## Checklist + +- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`. + + - [x] Finish the code + - [x] Basic docstrings & proper citation + - [ ] Test-time correctness + - [x] A full README + +- [ ] Milestone 2: Indicates a successful model implementation. + + - [ ] Training-time correctness + +- [ ] Milestone 3: Good to be a part of our core package! + + - [ ] Type hints and docstrings + - [ ] Unit tests + - [ ] Code polishing + - [ ] Metafile.yml + +- [ ] Move your modules into the core package following the codebase's file hierarchy structure. + +- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure. diff --git a/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_bactteria-detection-512x512.py b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_bactteria-detection-512x512.py new file mode 100644 index 0000000000..92584e9a68 --- /dev/null +++ b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_bactteria-detection-512x512.py @@ -0,0 +1,18 @@ +_base_ = [ + 'mmseg::_base_/models/fcn_unet_s5-d16.py', + './bactteria-detection_512x512.py', 'mmseg::_base_/default_runtime.py', + 'mmseg::_base_/schedules/schedule_20k.py' +] +custom_imports = dict(imports='datasets.bactteria-detection_dataset') +img_scale = (512, 512) +data_preprocessor = dict(size=img_scale) +optimizer = dict(lr=0.01) +optim_wrapper = dict(optimizer=optimizer) +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict( + num_classes=2, loss_decode=dict(use_sigmoid=True), out_channels=1), + auxiliary_head=None, + test_cfg=dict(mode='whole', _delete_=True)) +vis_backends = None +visualizer = dict(vis_backends=vis_backends) diff --git a/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_pannuke-512x512.py b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_pannuke-512x512.py new file mode 100644 index 0000000000..042a08ce00 --- /dev/null +++ b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_pannuke-512x512.py @@ -0,0 +1,17 @@ +_base_ = [ + 'mmseg::_base_/models/fcn_unet_s5-d16.py', './pannuke_512x512.py', + 'mmseg::_base_/default_runtime.py', + 'mmseg::_base_/schedules/schedule_20k.py' +] +custom_imports = dict(imports='datasets.pannuke_dataset') +img_scale = (512, 512) +data_preprocessor = dict(size=img_scale) +optimizer = dict(lr=0.0001) +optim_wrapper = dict(optimizer=optimizer) +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict(num_classes=6), + auxiliary_head=None, + test_cfg=dict(mode='whole', _delete_=True)) +vis_backends = None +visualizer = dict(vis_backends=vis_backends) diff --git a/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_pannuke-512x512.py b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_pannuke-512x512.py new file mode 100644 index 0000000000..e92514c913 --- /dev/null +++ b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_pannuke-512x512.py @@ -0,0 +1,17 @@ +_base_ = [ + 'mmseg::_base_/models/fcn_unet_s5-d16.py', './pannuke_512x512.py', + 'mmseg::_base_/default_runtime.py', + 'mmseg::_base_/schedules/schedule_20k.py' +] +custom_imports = dict(imports='datasets.pannuke_dataset') +img_scale = (512, 512) +data_preprocessor = dict(size=img_scale) +optimizer = dict(lr=0.001) +optim_wrapper = dict(optimizer=optimizer) +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict(num_classes=6), + auxiliary_head=None, + test_cfg=dict(mode='whole', _delete_=True)) +vis_backends = None +visualizer = dict(vis_backends=vis_backends) diff --git a/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_pannuke-512x512.py b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_pannuke-512x512.py new file mode 100644 index 0000000000..a9403c849f --- /dev/null +++ b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_pannuke-512x512.py @@ -0,0 +1,17 @@ +_base_ = [ + 'mmseg::_base_/models/fcn_unet_s5-d16.py', './pannuke_512x512.py', + 'mmseg::_base_/default_runtime.py', + 'mmseg::_base_/schedules/schedule_20k.py' +] +custom_imports = dict(imports='datasets.pannuke_dataset') +img_scale = (512, 512) +data_preprocessor = dict(size=img_scale) +optimizer = dict(lr=0.01) +optim_wrapper = dict(optimizer=optimizer) +model = dict( + data_preprocessor=data_preprocessor, + decode_head=dict(num_classes=6), + auxiliary_head=None, + test_cfg=dict(mode='whole', _delete_=True)) +vis_backends = None +visualizer = dict(vis_backends=vis_backends) diff --git a/projects/medical/2d_image/histopathology/pannuke/configs/pannuke_512x512.py b/projects/medical/2d_image/histopathology/pannuke/configs/pannuke_512x512.py new file mode 100644 index 0000000000..316ac1ac44 --- /dev/null +++ b/projects/medical/2d_image/histopathology/pannuke/configs/pannuke_512x512.py @@ -0,0 +1,42 @@ +dataset_type = 'PanNukeDataset' +data_root = 'data/' +img_scale = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', scale=img_scale, keep_ratio=False), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='PackSegInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=img_scale, keep_ratio=False), + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') +] +train_dataloader = dict( + batch_size=16, + num_workers=4, + persistent_workers=True, + sampler=dict(type='InfiniteSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='train.txt', + data_prefix=dict(img_path='images/', seg_map_path='masks/'), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='val.txt', + data_prefix=dict(img_path='images/', seg_map_path='masks/'), + pipeline=test_pipeline)) +test_dataloader = val_dataloader +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice']) +test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice']) diff --git a/projects/medical/2d_image/histopathology/pannuke/datasets/pannuke_dataset.py b/projects/medical/2d_image/histopathology/pannuke/datasets/pannuke_dataset.py new file mode 100644 index 0000000000..4d3c687ff3 --- /dev/null +++ b/projects/medical/2d_image/histopathology/pannuke/datasets/pannuke_dataset.py @@ -0,0 +1,33 @@ +from mmseg.datasets import BaseSegDataset +from mmseg.registry import DATASETS + + +@DATASETS.register_module() +class PanNukeDataset(BaseSegDataset): + """PanNukeDataset dataset. + + In segmentation map annotation for PanNukeDataset, + 0 stands for background, which is included in 6 categories. + ``reduce_zero_label`` is fixed to False. The ``img_suffix`` + is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'. + + Args: + img_suffix (str): Suffix of images. Default: '.png' + seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' + reduce_zero_label (bool): Whether to mark label zero as ignored. + Default to False. + """ + METAINFO = dict( + classes=('background', 'neoplastic', 'non-neoplastic epithelial', + 'inflammatory', 'connective', 'dead')) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) diff --git a/projects/medical/2d_image/histopathology/pannuke/tools/prepare_dataset.py b/projects/medical/2d_image/histopathology/pannuke/tools/prepare_dataset.py new file mode 100644 index 0000000000..7213b181f4 --- /dev/null +++ b/projects/medical/2d_image/histopathology/pannuke/tools/prepare_dataset.py @@ -0,0 +1,49 @@ +import os + +import numpy as np +from PIL import Image + +root_path = 'data/' + +tgt_img_dir = os.path.join(root_path, 'images/train') +tgt_mask_dir = os.path.join(root_path, 'masks/train') +os.system('mkdir -p ' + tgt_img_dir) +os.system('mkdir -p ' + tgt_mask_dir) + +fold_img_paths = sorted([ + os.path.join(root_path, 'pannuke/Fold 1/images/fold1/images.npy'), + os.path.join(root_path, 'pannuke/Fold 2/images/fold2/images.npy'), + os.path.join(root_path, 'pannuke/Fold 3/images/fold3/images.npy') +]) + +fold_mask_paths = sorted([ + os.path.join(root_path, 'pannuke/Fold 1/masks/fold1/masks.npy'), + os.path.join(root_path, 'pannuke/Fold 2/masks/fold2/masks.npy'), + os.path.join(root_path, 'pannuke/Fold 3/masks/fold3/masks.npy') +]) + +for n, (img_path, + mask_path) in enumerate(zip(fold_img_paths, fold_mask_paths)): + fold_name = str(n + 1) + imgs = np.load(img_path) + masks = np.load(mask_path) + + for i in range(imgs.shape[0]): + img = np.uint8(imgs[i]) + mask_multichannel = np.minimum(np.uint8(masks[i]), 1) + mask = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8) + for j in range(mask_multichannel.shape[-1]): + factor = (j + 1) % mask_multichannel.shape[-1] + # convert [0,1,2,3,4,5] to [1,2,3,4,5,0], + # with the last label being background + mask[mask_multichannel[..., j] == 1] = factor + + file_name = 'fold' + fold_name + '_' + str(i).rjust(4, '0') + '.png' + print('Processing: ', file_name) + tgt_img_path = os.path.join(tgt_img_dir, file_name) + tgt_mask_path = os.path.join(tgt_mask_dir, file_name) + Image.fromarray(img).save(tgt_img_path) + Image.fromarray(mask).save(tgt_mask_path) + + del imgs + del masks