Skip to content

Commit

Permalink
[Project] Medical semantic seg dataset: orvs (#2728)
Browse files Browse the repository at this point in the history
  • Loading branch information
tianbinli committed Jun 21, 2023
1 parent 6333dc1 commit 78e036c
Show file tree
Hide file tree
Showing 7 changed files with 315 additions and 0 deletions.
140 changes: 140 additions & 0 deletions projects/medical/2d_image/fundus_photography/orvs/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# ORVS (Online Retinal image for Vessel Segmentation (ORVS))

## Description

This project supports **`ORVS (Online Retinal image for Vessel Segmentation (ORVS))`**, which can be downloaded from [here](https://opendatalab.org.cn/ORVS).

### Dataset Overview

The ORVS dataset is a newly established collaboration between the Department of Computer Science and the Department of Vision Science at the University of Calgary. The dataset contains 49 images collected from a clinic in Calgary, Canada, consisting of 42 training images and 7 testing images. All images were obtained using a Zeiss Visucam 200 with a 30-degree field of view (FOV). The image size is 1444×1444 pixels with 24 bits per pixel. The images are stored in JPEG format with low compression, which is common in ophthalmic practice. All images were manually traced by an expert who has been working in the field of retinal image analysis and has been trained to mark all pixels belonging to retinal vessels. The Windows Paint 3D tool was used for manual image annotation.

### Original Statistic Information

| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
| ------------------------------------------------------ | ----------------- | ------------ | ------------------ | ------------ | --------------------- | ---------------------- | ------------ | ------- |
| [Bactteria detection](https://opendatalab.org.cn/ORVS) | bacteria | segmentation | fundus photography | 2 | 130/-/72 | yes/-/yes | 2020 | - |

| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
| background | 130 | 94.83 | - | - | 72 | 94.25 |
| vessel | 130 | 5.17 | - | - | 72 | 5.75 |

Note:

- `Pct` means percentage of pixels in this category in all pixels.

### Visualization

![bac](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/fundus_photography/orvs/ORVS_dataset.png)

### Prerequisites

- Python v3.8
- PyTorch v1.10.0
- [MIM](https://github.com/open-mmlab/mim) v0.3.4
- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5

All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `orvs/` root directory, run the following line to add the current directory to `PYTHONPATH`:

```shell
export PYTHONPATH=`pwd`:$PYTHONPATH
```

### Dataset preparing

- Clone this [repository](https://github.com/AbdullahSarhan/ICPRVessels), then move `Vessels-Datasets` to `data/`.
- run script `"python tools/prepare_dataset.py"` to format data and change folder structure as below.
- run script `"python ../../tools/split_seg_dataset.py"` to split dataset and generate `train.txt`, `val.txt` and `test.txt`. If the label of official validation set and test set can't be obtained, we generate `train.txt` and `val.txt` from the training set randomly.

```none
mmsegmentation
├── mmseg
├── projects
│ ├── medical
│ │ ├── 2d_image
│ │ │ ├── fundus_photography
│ │ │ │ ├── orvs
│ │ │ │ │ ├── configs
│ │ │ │ │ ├── datasets
│ │ │ │ │ ├── tools
│ │ │ │ │ ├── data
│ │ │ │ │ │ ├── train.txt
│ │ │ │ │ │ ├── test.txt
│ │ │ │ │ │ ├── images
│ │ │ │ │ │ │ ├── train
│ │ │ │ | │ │ │ ├── xxx.png
│ │ │ │ | │ │ │ ├── ...
│ │ │ │ | │ │ │ └── xxx.png
│ │ │ │ │ │ ├── masks
│ │ │ │ │ │ │ ├── train
│ │ │ │ | │ │ │ ├── xxx.png
│ │ │ │ | │ │ │ ├── ...
│ │ │ │ | │ │ │ └── xxx.png
```

### Training commands

Train models on a single server with one GPU.

```shell
mim train mmseg ./configs/${CONFIG_FILE}
```

### Testing commands

Test models on a single server with one GPU.

```shell
mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
```

<!-- List the results as usually done in other model's README. [Example](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/configs/fcn#results-and-models)
You should claim whether this is based on the pre-trained weights, which are converted from the official release; or it's a reproduced result obtained from retraining the model in this project. -->

## Dataset Citation

If this work is helpful for your research, please consider citing the below paper.

```
@inproceedings{sarhan2021transfer,
title={Transfer learning through weighted loss function and group normalization for vessel segmentation from retinal images},
author={Sarhan, Abdullah and Rokne, Jon and Alhajj, Reda and Crichton, Andrew},
booktitle={2020 25th International Conference on Pattern Recognition (ICPR)},
pages={9211--9218},
year={2021},
organization={IEEE}
}
```

## Checklist

- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.

- [x] Finish the code

- [x] Basic docstrings & proper citation

- [ ] Test-time correctness

- [x] A full README

- [ ] Milestone 2: Indicates a successful model implementation.

- [ ] Training-time correctness

- [ ] Milestone 3: Good to be a part of our core package!

- [ ] Type hints and docstrings

- [ ] Unit tests

- [ ] Code polishing

- [ ] Metafile.yml

- [ ] Move your modules into the core package following the codebase's file hierarchy structure.

- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
_base_ = [
'./orvs_512x512.py', 'mmseg::_base_/models/fcn_unet_s5-d16.py',
'mmseg::_base_/default_runtime.py',
'mmseg::_base_/schedules/schedule_20k.py'
]
custom_imports = dict(imports='datasets.orvs_dataset')
img_scale = (512, 512)
data_preprocessor = dict(size=img_scale)
optimizer = dict(lr=0.0001)
optim_wrapper = dict(optimizer=optimizer)
model = dict(
data_preprocessor=data_preprocessor,
decode_head=dict(num_classes=2),
auxiliary_head=None,
test_cfg=dict(mode='whole', _delete_=True))
vis_backends = None
visualizer = dict(vis_backends=vis_backends)
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
_base_ = [
'./orvs_512x512.py', 'mmseg::_base_/models/fcn_unet_s5-d16.py',
'mmseg::_base_/default_runtime.py',
'mmseg::_base_/schedules/schedule_20k.py'
]
custom_imports = dict(imports='datasets.orvs_dataset')
img_scale = (512, 512)
data_preprocessor = dict(size=img_scale)
optimizer = dict(lr=0.001)
optim_wrapper = dict(optimizer=optimizer)
model = dict(
data_preprocessor=data_preprocessor,
decode_head=dict(num_classes=2),
auxiliary_head=None,
test_cfg=dict(mode='whole', _delete_=True))
vis_backends = None
visualizer = dict(vis_backends=vis_backends)
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
_base_ = [
'./orvs_512x512.py', 'mmseg::_base_/models/fcn_unet_s5-d16.py',
'mmseg::_base_/default_runtime.py',
'mmseg::_base_/schedules/schedule_20k.py'
]
custom_imports = dict(imports='datasets.orvs_dataset')
img_scale = (512, 512)
data_preprocessor = dict(size=img_scale)
optimizer = dict(lr=0.01)
optim_wrapper = dict(optimizer=optimizer)
model = dict(
data_preprocessor=data_preprocessor,
decode_head=dict(num_classes=2),
auxiliary_head=None,
test_cfg=dict(mode='whole', _delete_=True))
vis_backends = None
visualizer = dict(vis_backends=vis_backends)
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
dataset_type = 'ORVSDataset'
data_root = 'data/'
img_scale = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', scale=img_scale, keep_ratio=False),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=img_scale, keep_ratio=False),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
train_dataloader = dict(
batch_size=16,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='train.txt',
data_prefix=dict(img_path='images/', seg_map_path='masks/'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='test.txt',
data_prefix=dict(img_path='images/', seg_map_path='masks/'),
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from mmseg.datasets import BaseSegDataset
from mmseg.registry import DATASETS


@DATASETS.register_module()
class ORVSDataset(BaseSegDataset):
"""ORVSDataset dataset.
In segmentation map annotation for ORVSDataset,
``reduce_zero_label`` is fixed to False. The ``img_suffix``
is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
Args:
img_suffix (str): Suffix of images. Default: '.png'
seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
"""
METAINFO = dict(classes=('background', 'vessel'))

def __init__(self,
img_suffix='.png',
seg_map_suffix='.png',
**kwargs) -> None:
super().__init__(
img_suffix=img_suffix,
seg_map_suffix=seg_map_suffix,
reduce_zero_label=False,
**kwargs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import glob
import os

import numpy as np
from PIL import Image

root_path = 'data/'
img_suffix = '.jpg'
seg_map_suffix_list = ['.jpg', '.png', '.tif']
save_img_suffix = '.png'
save_seg_map_suffix = '.png'

x_train = glob.glob(
os.path.join('data/Vessels-Datasets/*/Train/Original/Images/*' +
img_suffix))
x_test = glob.glob(
os.path.join('data/Vessels-Datasets/*/Test/Original/Images/*' +
img_suffix))

os.system('mkdir -p ' + root_path + 'images/train/')
os.system('mkdir -p ' + root_path + 'images/test/')
os.system('mkdir -p ' + root_path + 'masks/train/')
os.system('mkdir -p ' + root_path + 'masks/test/')

part_dir_dict = {0: 'train/', 1: 'test/'}
for ith, part in enumerate([x_train, x_test]):
part_dir = part_dir_dict[ith]
for img in part:
type_name = img.split('/')[-5]
basename = type_name + '_' + os.path.basename(img)
save_img_path = root_path + 'images/' + part_dir + basename.split(
'.')[0] + save_img_suffix
Image.open(img).save(save_img_path)

for seg_map_suffix in seg_map_suffix_list:
if os.path.exists('/'.join(img.split('/')[:-1]).replace(
'Images', 'Labels')):
mask_path = img.replace('Images', 'Labels').replace(
img_suffix, seg_map_suffix)
else:
mask_path = img.replace('Images', 'labels').replace(
img_suffix, seg_map_suffix)
if os.path.exists(mask_path):
break
save_mask_path = root_path + 'masks/' + part_dir + basename.split(
'.')[0] + save_seg_map_suffix
masks = np.array(Image.open(mask_path).convert('L')).astype(np.uint8)
if len(np.unique(masks)) == 2 and 1 in np.unique(masks):
print(np.unique(masks))
pass
else:
masks[masks < 128] = 0
masks[masks >= 128] = 1
masks = Image.fromarray(masks)
masks.save(save_mask_path)

0 comments on commit 78e036c

Please sign in to comment.