Skip to content

Commit

Permalink
[Project] add Gamma Task3 dataset project in dev-1.x (#2695)
Browse files Browse the repository at this point in the history
  • Loading branch information
Provable0816 committed Jun 20, 2023
1 parent 4feba31 commit f419f61
Show file tree
Hide file tree
Showing 8 changed files with 398 additions and 1 deletion.
167 changes: 167 additions & 0 deletions projects/medical/2d_image/fundus_photography/gamma3/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Glaucoma grAding from Multi-Modality imAges Task3

## Description

This project support **`Glaucoma grAding from Multi-Modality imAges Task3`**, and the dataset used in this project can be downloaded from [here](https://aistudio.baidu.com/aistudio/competition/detail/121/0/datasets).

### Dataset Overview

This regular-challenge dataset was provided by Sun Yat-sen Ophthalmic Center, Sun Yat-sen University, Guangzhou, China. The dataset contains 200 fundus color images: 100 pairs in the training set and 100 pairs in the test set.

### Original Statistic Information

| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
| ----------------------------------------------------------------------------------- | ----------------- | ------------ | --------------- | ------------ | --------------------- | ---------------------- | ------------ | --------------------------------------------------------------- |
| [GammaTask3](https://aistudio.baidu.com/aistudio/competition/detail/121/0/datasets) | eye | segmentation | fundus photophy | 3 | 100/-/100 | yes/-/- | 2021 | [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-sa/4.0/) |

| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
| background | 100 | 99.02 | - | - | - | - |
| optic disc | 100 | 0.67 | - | - | - | - |
| optic cup | 100 | 0.31 | - | - | - | - |

Note:

- `Pct` means percentage of pixels in this category in all pixels.

### Visualization

![bac](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/fundus_photography/gamma3/gamma3_dataset.png)

## Dataset Citation

```bibtex
@article{fu2018joint,
title={Joint optic disc and cup segmentation based on multi-label deep network and polar transformation},
author={Fu, Huazhu and Cheng, Jun and Xu, Yanwu and Wong, Damon Wing Kee and Liu, Jiang and Cao, Xiaochun},
journal={IEEE transactions on medical imaging},
volume={37},
number={7},
pages={1597--1605},
year={2018},
publisher={IEEE}
}
@article{sevastopolsky2017optic,
title={Optic disc and cup segmentation methods for glaucoma detection with modification of U-Net convolutional neural network},
author={Sevastopolsky, Artem},
journal={Pattern Recognition and Image Analysis},
volume={27},
pages={618--624},
year={2017},
publisher={Springer}
}
```

### Prerequisites

- Python v3.8
- PyTorch v1.10.0
- pillow(PIL) v9.3.0
- scikit-learn(sklearn) v1.2.0
- [MIM](https://github.com/open-mmlab/mim) v0.3.4
- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5

All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `gammm3/` root directory, run the following line to add the current directory to `PYTHONPATH`:

```shell
export PYTHONPATH=`pwd`:$PYTHONPATH
```

### Dataset preparing

- download dataset from [here](https://aistudio.baidu.com/aistudio/competition/detail/121/0/datasets) and decompression data to path `'data/'`.
- run script `"python tools/prepare_dataset.py"` to split dataset and change folder structure as below.
- run script `"python ../../tools/split_seg_dataset.py"` to split dataset and generate `train.txt`, `val.txt` and `test.txt`. If the label of official validation set and test set can't be obtained, we generate `train.txt` and `val.txt` from the training set randomly.

```none
mmsegmentation
├── mmseg
├── projects
│ ├── medical
│ │ ├── 2d_image
│ │ │ ├── fundus_photography
│ │ │ │ ├── gamma3
│ │ │ │ │ ├── configs
│ │ │ │ │ ├── datasets
│ │ │ │ │ ├── tools
│ │ │ │ │ ├── data
│ │ │ │ │ │ ├── train.txt
│ │ │ │ │ │ ├── val.txt
│ │ │ │ │ │ ├── images
│ │ │ │ │ │ │ ├── train
│ │ │ │ | │ │ │ ├── xxx.png
│ │ │ │ | │ │ │ ├── ...
│ │ │ │ | │ │ │ └── xxx.png
│ │ │ │ │ │ │ ├── test
│ │ │ │ | │ │ │ ├── yyy.png
│ │ │ │ | │ │ │ ├── ...
│ │ │ │ | │ │ │ └── yyy.png
│ │ │ │ │ │ ├── masks
│ │ │ │ │ │ │ ├── train
│ │ │ │ | │ │ │ ├── xxx.png
│ │ │ │ | │ │ │ ├── ...
│ │ │ │ | │ │ │ └── xxx.png
```

### Divided Dataset Information

***Note: The table information below is divided by ourselves.***

| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
| background | 80 | 99.01 | 20 | 99.07 | - | - |
| optic disc | 80 | 0.68 | 20 | 0.63 | - | - |
| optic cup | 80 | 0.32 | 20 | 0.31 | - | - |

### Training commands

To train models on a single server with one GPU. (default)

```shell
mim train mmseg ./configs/${CONFIG_PATH}
```

### Testing commands

To test models on a single server with one GPU. (default)

```shell
mim test mmseg ./configs/${CONFIG_PATH} --checkpoint ${CHECKPOINT_PATH}
```

<!-- List the results as usually done in other model's README. [Example](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/configs/fcn#results-and-models)
You should claim whether this is based on the pre-trained weights, which are converted from the official release; or it's a reproduced result obtained from retraining the model in this project. -->

## Checklist

- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.

- [x] Finish the code

- [x] Basic docstrings & proper citation

- [ ] Test-time correctness

- [x] A full README

- [ ] Milestone 2: Indicates a successful model implementation.

- [ ] Training-time correctness

- [ ] Milestone 3: Good to be a part of our core package!

- [ ] Type hints and docstrings

- [ ] Unit tests

- [ ] Code polishing

- [ ] Metafile.yml

- [ ] Move your modules into the core package following the codebase's file hierarchy structure.

- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
_base_ = [
'mmseg::_base_/models/fcn_unet_s5-d16.py', './gamma3_512x512.py',
'mmseg::_base_/default_runtime.py',
'mmseg::_base_/schedules/schedule_20k.py'
]
custom_imports = dict(imports='datasets.gamma3_dataset')
img_scale = (512, 512)
data_preprocessor = dict(size=img_scale)
optimizer = dict(lr=0.0001)
optim_wrapper = dict(optimizer=optimizer)
model = dict(
data_preprocessor=data_preprocessor,
decode_head=dict(num_classes=3),
auxiliary_head=None,
test_cfg=dict(mode='whole', _delete_=True))
vis_backends = None
visualizer = dict(vis_backends=vis_backends)
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
_base_ = [
'mmseg::_base_/models/fcn_unet_s5-d16.py', './gamma3_512x512.py',
'mmseg::_base_/default_runtime.py',
'mmseg::_base_/schedules/schedule_20k.py'
]
custom_imports = dict(imports='datasets.gamma3_dataset')
img_scale = (512, 512)
data_preprocessor = dict(size=img_scale)
optimizer = dict(lr=0.001)
optim_wrapper = dict(optimizer=optimizer)
model = dict(
data_preprocessor=data_preprocessor,
decode_head=dict(num_classes=3),
auxiliary_head=None,
test_cfg=dict(mode='whole', _delete_=True))
vis_backends = None
visualizer = dict(vis_backends=vis_backends)
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
_base_ = [
'mmseg::_base_/models/fcn_unet_s5-d16.py', './gamma3_512x512.py',
'mmseg::_base_/default_runtime.py',
'mmseg::_base_/schedules/schedule_20k.py'
]
custom_imports = dict(imports='datasets.gamma3_dataset')
img_scale = (512, 512)
data_preprocessor = dict(size=img_scale)
optimizer = dict(lr=0.01)
optim_wrapper = dict(optimizer=optimizer)
model = dict(
data_preprocessor=data_preprocessor,
decode_head=dict(num_classes=3),
auxiliary_head=None,
test_cfg=dict(mode='whole', _delete_=True))
vis_backends = None
visualizer = dict(vis_backends=vis_backends)
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
dataset_type = 'Gamma3Dataset'
data_root = 'data/'
img_scale = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', scale=img_scale, keep_ratio=False),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=img_scale, keep_ratio=False),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
train_dataloader = dict(
batch_size=16,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='train.txt',
data_prefix=dict(img_path='images/', seg_map_path='masks/'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='val.txt',
data_prefix=dict(img_path='images/', seg_map_path='masks/'),
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from mmseg.datasets import BaseSegDataset
from mmseg.registry import DATASETS


@DATASETS.register_module()
class Gamma3Dataset(BaseSegDataset):
"""Gamma3Dataset dataset.
In segmentation map annotation for Gamma3Dataset,
``reduce_zero_label`` is fixed to False. The ``img_suffix``
is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
Args:
img_suffix (str): Suffix of images. Default: '.png'
seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
reduce_zero_label (bool): Whether to mark label zero as ignored.
Default to False.
"""
METAINFO = dict(classes=('background', 'disc', 'cup'))

def __init__(self,
img_suffix='.png',
seg_map_suffix='.png',
reduce_zero_label=False,
**kwargs) -> None:
super().__init__(
img_suffix=img_suffix,
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)
Loading

0 comments on commit f419f61

Please sign in to comment.