[Project] Medical semantic seg dataset: breast_cancer_cell_seg (#2726)
tianbinli committed Jun 20, 2023
1 parent b5fc5ab commit 041f1f0
Showing 7 changed files with 319 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Breast Cancer Cell Segmentation

## Description

This project supports **`Breast Cancer Cell Segmentation`**. The dataset used in this project can be downloaded from [here](https://tianchi.aliyun.com/dataset/dataDetail?dataId=90152).

### Dataset Overview

This dataset contains 58 H&E stained histopathology images used for breast cancer cell detection, with associated ground truth data. Routine histology uses the stain combination of hematoxylin and eosin, commonly referred to as H&E. The images are stained because most cells are essentially transparent, with little or no intrinsic pigment. Certain special stains, which bind selectively to particular components, are used to identify biological structures such as cells. The challenging problem in these images is cell segmentation, which precedes the classification of cells into benign and malignant.

### Original Statistic Information

| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
| --------------------------------------------------------------------------------------------- | ----------------- | ------------ | -------------- | ------------ | --------------------- | ---------------------- | ------------ | ------------------------------------------------------------------------------------------------------ |
| [Breast Cancer Cell Segmentation](https://tianchi.aliyun.com/dataset/dataDetail?dataId=90152) | thorax | segmentation | histopathology | 2 | 58/-/- | yes/-/- | 2021 | [CC-BY-SA-NC 4.0](http://creativecommons.org/licenses/by-sa/4.0/?spm=5176.12282016.0.0.3f5b5291ypBxb2) |

| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
| :----------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
| normal | 58 | 98.37 | - | - | - | - |
| breast cancer cell | 58 | 1.63 | - | - | - | - |

Note:

- `Pct` is the percentage of all pixels in the split that belong to the given class.
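
As a rough illustration of how such a pixel percentage can be computed, here is a minimal sketch. It assumes the masks are already stored as class-index arrays (0 for `normal`, 1 for `breast cancer cell`); the helper name `class_pixel_percentages` and the tiny example masks are hypothetical and not part of this project.

```python
import numpy as np


def class_pixel_percentages(masks, num_classes=2):
    """Return the percentage of pixels per class over a list of label masks."""
    counts = np.zeros(num_classes, dtype=np.int64)
    for mask in masks:
        # bincount tallies how many pixels carry each class index.
        counts += np.bincount(mask.ravel(), minlength=num_classes)[:num_classes]
    return 100.0 * counts / counts.sum()


# Two tiny hypothetical masks: 0 = normal, 1 = breast cancer cell.
masks = [
    np.array([[0, 0], [0, 1]], dtype=np.uint8),
    np.array([[0, 0], [0, 0]], dtype=np.uint8),
]
pct = class_pixel_percentages(masks)
print(pct)
```

The percentages are taken over the pooled pixels of all masks rather than averaged per image, which matches the wording of the note above.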

### Visualization

![bac](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/histopathology/breast_cancer_cell_seg/breast_cancer_cell_seg_dataset.png)

## Dataset Citation

```
@inproceedings{gelasca2008evaluation,
title={Evaluation and benchmark for biological image segmentation},
author={Gelasca, Elisa Drelie and Byun, Jiyun and Obara, Boguslaw and Manjunath, BS},
booktitle={2008 15th IEEE international conference on image processing},
pages={1816--1819},
year={2008},
organization={IEEE}
}
```

### Prerequisites

- Python v3.8
- PyTorch v1.10.0
- [MIM](https://github.com/open-mmlab/mim) v0.3.4
- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5

All the commands below rely on a correctly configured `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In the `breast_cancer_cell_seg/` root directory, run the following line to add the current directory to `PYTHONPATH`:

```shell
export PYTHONPATH=`pwd`:$PYTHONPATH
```
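
The configs in this project load the dataset class through `custom_imports = dict(imports='datasets.breast-cancer-cell-seg_dataset')`, which is why the project root has to be importable. The sketch below illustrates the mechanism with the standard library only. The package and module names (`demo_datasets`, `demo-cell-seg_dataset`) are hypothetical, and the sketch assumes, rather than shows, that such dotted strings are resolved with `importlib.import_module`.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Build a throwaway project layout with a hyphenated module file name,
# mirroring the style of `datasets/breast-cancer-cell-seg_dataset.py`.
project_root = Path(tempfile.mkdtemp())
pkg_dir = project_root / 'demo_datasets'  # hypothetical package name
pkg_dir.mkdir()
(pkg_dir / '__init__.py').write_text('')
(pkg_dir / 'demo-cell-seg_dataset.py').write_text('LOADED = True\n')

# Similar in effect to exporting PYTHONPATH from the project root:
# the directory becomes a place where Python searches for modules.
sys.path.insert(0, str(project_root))
importlib.invalidate_caches()

# A dotted string is resolved against sys.path. Hyphens are accepted here,
# even though a plain `import` statement could not spell this module name.
module = importlib.import_module('demo_datasets.demo-cell-seg_dataset')
print(module.LOADED)
```

If the project root is missing from the search path, the same call raises `ModuleNotFoundError`, which is the typical symptom of a forgotten `export`.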

### Dataset preparing

- Download the dataset from [here](https://tianchi.aliyun.com/dataset/dataDetail?dataId=90152) and decompress it into `data/`.
- Run `python tools/prepare_dataset.py` to convert the data and arrange the folder structure as shown below.
- Run `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. If the labels of the official validation and test sets are unavailable, `train.txt` and `val.txt` are generated by randomly splitting the training set.

```none
mmsegmentation
├── mmseg
├── projects
│   ├── medical
│   │   ├── 2d_image
│   │   │   ├── histopathology
│   │   │   │   ├── breast_cancer_cell_seg
│   │   │   │   │   ├── configs
│   │   │   │   │   ├── datasets
│   │   │   │   │   ├── tools
│   │   │   │   │   ├── data
│   │   │   │   │   │   ├── train.txt
│   │   │   │   │   │   ├── val.txt
│   │   │   │   │   │   ├── images
│   │   │   │   │   │   │   ├── train
│   │   │   │   │   │   │   │   ├── xxx.png
│   │   │   │   │   │   │   │   ├── ...
│   │   │   │   │   │   │   │   └── xxx.png
│   │   │   │   │   │   ├── masks
│   │   │   │   │   │   │   ├── train
│   │   │   │   │   │   │   │   ├── xxx.png
│   │   │   │   │   │   │   │   ├── ...
│   │   │   │   │   │   │   │   └── xxx.png
```
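
The random train/val split described in the last preparation step can be sketched as follows. This is only an illustration under stated assumptions: the helper `split_names`, the 20% validation fraction, the fixed seed and the sample names are all hypothetical, and the actual behavior and output format of `split_seg_dataset.py` are not shown in this README.

```python
import random


def split_names(names, val_fraction=0.2, seed=0):
    """Shuffle sample names deterministically and split into train/val."""
    names = sorted(names)
    random.Random(seed).shuffle(names)
    num_val = round(len(names) * val_fraction)
    return names[num_val:], names[:num_val]


# 58 hypothetical sample names standing in for the files in images/train/.
names = [f'sample_{i:02d}' for i in range(58)]
train, val = split_names(names)
print(len(train), len(val))
```

Sorting before shuffling keeps the split reproducible regardless of the order in which the file system lists the images.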

### Divided Dataset Information

***Note: The split summarized in the table below was made by us, not by the dataset authors.***

|     Class Name     | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
| :----------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
|       normal       |     46     |   98.36    |    12    |  98.41   |     -     |     -     |
| breast cancer cell |     46     |    1.64    |    12    |   1.59   |     -     |     -     |

### Training commands

Train models on a single server with one GPU.

```shell
mim train mmseg ./configs/${CONFIG_FILE}
```

### Testing commands

Test models on a single server with one GPU.

```shell
mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
```

<!-- List the results as usually done in other model's README. [Example](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/configs/fcn#results-and-models)
You should claim whether this is based on the pre-trained weights, which are converted from the official release; or it's a reproduced result obtained from retraining the model in this project. -->

## Checklist

- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.

- [x] Finish the code

- [x] Basic docstrings & proper citation

- [x] Test-time correctness

- [x] A full README

- [ ] Milestone 2: Indicates a successful model implementation.

- [ ] Training-time correctness

- [ ] Milestone 3: Good to be a part of our core package!

- [ ] Type hints and docstrings

- [ ] Unit tests

- [ ] Code polishing

- [ ] Metafile.yml

- [ ] Move your modules into the core package following the codebase's file hierarchy structure.

- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
@@ -0,0 +1,42 @@
dataset_type = 'BreastCancerCellSegDataset'
data_root = 'data/'
img_scale = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', scale=img_scale, keep_ratio=False),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=img_scale, keep_ratio=False),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs')
]
train_dataloader = dict(
    batch_size=16,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='train.txt',
        data_prefix=dict(img_path='images/', seg_map_path='masks/'),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='val.txt',
        data_prefix=dict(img_path='images/', seg_map_path='masks/'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
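
The evaluators above request `mIoU` and `mDice`. As a minimal sketch of what these two metrics measure, the snippet below computes per-class IoU and Dice for a single pair of label maps and averages them over classes. The helper `iou_and_dice` and the toy arrays are hypothetical. This is not the `IoUMetric` implementation, which may, for example, accumulate areas over the whole dataset before dividing.

```python
import numpy as np


def iou_and_dice(pred, target, class_id):
    """Compute IoU and Dice for one class from two label maps."""
    pred_c = pred == class_id
    target_c = target == class_id
    intersection = np.logical_and(pred_c, target_c).sum()
    union = np.logical_or(pred_c, target_c).sum()
    total = pred_c.sum() + target_c.sum()
    # A class absent from both maps has no defined score.
    iou = intersection / union if union else float('nan')
    dice = 2 * intersection / total if total else float('nan')
    return float(iou), float(dice)


# Toy 2x2 prediction and ground truth with classes 0 and 1.
pred = np.array([[0, 1], [1, 1]])
target = np.array([[0, 0], [1, 1]])
per_class = [iou_and_dice(pred, target, c) for c in (0, 1)]
miou = sum(iou for iou, _ in per_class) / len(per_class)
mdice = sum(dice for _, dice in per_class) / len(per_class)
print(per_class, miou, mdice)
```

With roughly 1.6% foreground pixels in this dataset, the class-averaged scores are worth reading alongside the per-class values, since the `normal` class alone can score very high.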
@@ -0,0 +1,18 @@
_base_ = [
    './breast-cancer-cell-seg_512x512.py',
    'mmseg::_base_/models/fcn_unet_s5-d16.py',
    'mmseg::_base_/default_runtime.py',
    'mmseg::_base_/schedules/schedule_20k.py'
]
custom_imports = dict(imports='datasets.breast-cancer-cell-seg_dataset')
img_scale = (512, 512)
data_preprocessor = dict(size=img_scale)
optimizer = dict(lr=0.0001)
optim_wrapper = dict(optimizer=optimizer)
model = dict(
    data_preprocessor=data_preprocessor,
    decode_head=dict(num_classes=2),
    auxiliary_head=None,
    test_cfg=dict(mode='whole', _delete_=True))
vis_backends = None
visualizer = dict(vis_backends=vis_backends)
@@ -0,0 +1,18 @@
_base_ = [
    './breast-cancer-cell-seg_512x512.py',
    'mmseg::_base_/models/fcn_unet_s5-d16.py',
    'mmseg::_base_/default_runtime.py',
    'mmseg::_base_/schedules/schedule_20k.py'
]
custom_imports = dict(imports='datasets.breast-cancer-cell-seg_dataset')
img_scale = (512, 512)
data_preprocessor = dict(size=img_scale)
optimizer = dict(lr=0.001)
optim_wrapper = dict(optimizer=optimizer)
model = dict(
    data_preprocessor=data_preprocessor,
    decode_head=dict(num_classes=2),
    auxiliary_head=None,
    test_cfg=dict(mode='whole', _delete_=True))
vis_backends = None
visualizer = dict(vis_backends=vis_backends)
@@ -0,0 +1,18 @@
_base_ = [
    './breast-cancer-cell-seg_512x512.py',
    'mmseg::_base_/models/fcn_unet_s5-d16.py',
    'mmseg::_base_/default_runtime.py',
    'mmseg::_base_/schedules/schedule_20k.py'
]
custom_imports = dict(imports='datasets.breast-cancer-cell-seg_dataset')
img_scale = (512, 512)
data_preprocessor = dict(size=img_scale)
optimizer = dict(lr=0.01)
optim_wrapper = dict(optimizer=optimizer)
model = dict(
    data_preprocessor=data_preprocessor,
    decode_head=dict(num_classes=2),
    auxiliary_head=None,
    test_cfg=dict(mode='whole', _delete_=True))
vis_backends = None
visualizer = dict(vis_backends=vis_backends)
@@ -0,0 +1,29 @@
from mmseg.datasets import BaseSegDataset
from mmseg.registry import DATASETS


@DATASETS.register_module()
class BreastCancerCellSegDataset(BaseSegDataset):
    """BreastCancerCellSegDataset dataset.

    In segmentation map annotation for BreastCancerCellSegDataset,
    ``reduce_zero_label`` is fixed to False. The ``img_suffix``
    is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.

    Args:
        img_suffix (str): Suffix of images. Default: '.png'
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Default to False.
    """
    METAINFO = dict(classes=('normal', 'breast cancer cell'))

    def __init__(self,
                 img_suffix='.png',
                 seg_map_suffix='.png',
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=False,
            **kwargs)
@@ -0,0 +1,47 @@
import glob
import os

import numpy as np
from PIL import Image

root_path = 'data/'
img_suffix = '.tif'
seg_map_suffix = '.TIF'
save_img_suffix = '.png'
save_seg_map_suffix = '.png'

x_train = glob.glob(
    os.path.join('data/Breast Cancer Cell Segmentation_datasets/Images/*' +
                 img_suffix))

# Create the output folders (portable equivalent of `mkdir -p`).
os.makedirs(root_path + 'images/train/', exist_ok=True)
os.makedirs(root_path + 'masks/train/', exist_ok=True)

# Map raw mask values to class indices: 0 stays 0, 255 becomes 1.
D2_255_convert_dict = {0: 0, 255: 1}


def convert_2d(img, convert_dict=D2_255_convert_dict):
    arr_2d = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
    for c, i in convert_dict.items():
        arr_2d[img == c] = i
    return arr_2d


part_dir_dict = {0: 'train/'}
for ith, part in enumerate([x_train]):
    part_dir = part_dir_dict[ith]
    for img in part:
        basename = os.path.basename(img)
        # Save the image as PNG under images/train/.
        img_save_path = root_path + 'images/' + part_dir + basename.split(
            '.')[0] + save_img_suffix
        Image.open(img).save(img_save_path)
        # The mask name is the image name without its last '_' token.
        mask_path = root_path + 'Breast Cancer Cell Segmentation_datasets/Masks/' + '_'.join(  # noqa
            basename.split('_')[:-1]) + seg_map_suffix
        label = np.array(Image.open(mask_path))

        save_mask_path = root_path + 'masks/' + part_dir + basename.split(
            '.')[0] + save_seg_map_suffix
        # Expect a single-channel mask that uses 255 for foreground.
        assert len(label.shape) == 2 and 255 in label and 1 not in label
        mask = convert_2d(label)
        mask = Image.fromarray(mask.astype(np.uint8))
        mask.save(save_mask_path)
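
The script above pairs each image with its mask by dropping the last underscore-separated token of the image file name and appending the mask suffix. The standalone sketch below isolates that rule. The helper `mask_name_for` and the example file name are hypothetical; the real naming scheme comes from the downloaded dataset.

```python
import os


def mask_name_for(image_path, seg_map_suffix='.TIF'):
    """Drop the last underscore-separated token of the image name."""
    basename = os.path.basename(image_path)
    return '_'.join(basename.split('_')[:-1]) + seg_map_suffix


# Hypothetical file name, used only to show the transformation.
example = mask_name_for('Images/sample_001_ccd.tif')
print(example)  # sample_001.TIF

# A name without any underscore loses its whole stem.
no_underscore = mask_name_for('Images/sample.tif')
print(no_underscore)  # .TIF
```

The second call shows a limitation of this rule: an image name without an underscore maps to a bare suffix, so the script implicitly relies on every image name containing at least one underscore.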
