diff --git a/docs/zh_cn/advanced_guides/add_datasets.md b/docs/zh_cn/advanced_guides/add_datasets.md
index 4ea14934ed2..22fbf3462fc 100644
--- a/docs/zh_cn/advanced_guides/add_datasets.md
+++ b/docs/zh_cn/advanced_guides/add_datasets.md
@@ -1,4 +1,62 @@
-# Add New Datasets (to be updated)
+# Add New Datasets
+
+## Add a new custom dataset
+
+Here we show how to develop a new dataset.
+
+1. Create a new file `mmseg/datasets/example.py`
+
+ ```python
+ from mmseg.registry import DATASETS
+ from .basesegdataset import BaseSegDataset
+
+
+ @DATASETS.register_module()
+ class ExampleDataset(BaseSegDataset):
+
+ METAINFO = dict(
+ classes=('xxx', 'xxx', ...),
+ palette=[[x, x, x], [x, x, x], ...])
+
+       def __init__(self, arg1, arg2):
+ pass
+ ```
+
+2. Import the module in `mmseg/datasets/__init__.py`
+
+ ```python
+ from .example import ExampleDataset
+ ```
+
+3. Use it by creating a new dataset config file `configs/_base_/datasets/example_dataset.py`
+
+ ```python
+ dataset_type = 'ExampleDataset'
+ data_root = 'data/example/'
+ ...
+ ```
+
+4. Add dataset meta information in `mmseg/utils/class_names.py`
+
+ ```python
+ def example_classes():
+ return [
+ 'xxx', 'xxx',
+ ...
+ ]
+
+ def example_palette():
+ return [
+ [x, x, x], [x, x, x],
+ ...
+ ]
+   dataset_aliases = {
+ 'example': ['example', ...],
+ ...
+ }
+ ```
+
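The alias table in step 4 lets utility helpers resolve any user-facing dataset name to its class list. Below is a minimal self-contained sketch of that lookup; the class names and helper wiring are toy stand-ins, not the real mmseg tables:

```python
# Toy stand-ins for the per-dataset helpers registered in class_names.py.
def example_classes():
    return ['background', 'foreground']

dataset_aliases = {
    'example': ['example', 'example_dataset'],
}

def get_classes(dataset):
    """Resolve an alias to its canonical key, then call `<key>_classes()`."""
    for key, aliases in dataset_aliases.items():
        if dataset in aliases:
            return globals()[key + '_classes']()
    raise ValueError(f'Unrecognized dataset: {dataset}')
```

Looking up either alias returns the same class list, e.g. `get_classes('example_dataset')`.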
+**Note:** If the new dataset does not meet mmseg's requirements, you need to prepare a dataset preprocessing script in `tools/dataset_converters/`
## Customize datasets by reorganizing data
@@ -26,30 +84,17 @@
A training pair consists of the files in img_dir/ann_dir that share the same prefix.
-If the `split` argument is given, only part of the files in img_dir/ann_dir will be loaded.
-We may specify the prefix of the files to be included in the split txt.
+Some datasets do not release their test set or the test-set annotations. Without the test-set annotations we cannot evaluate models locally, so the config files set the validation set as the default test set.
-More specifically, a split txt looks like the following:
-
-```none
-xxx
-zzz
-```
+For how to build your own dataset or implement a new dataset class, please refer to the [dataset guide](./datasets.md) for more detailed information.
-only
-
-`data/my_dataset/img_dir/train/xxx{img_suffix}`,
-`data/my_dataset/img_dir/train/zzz{img_suffix}`,
-`data/my_dataset/ann_dir/train/xxx{seg_map_suffix}`,
-`data/my_dataset/ann_dir/train/zzz{seg_map_suffix}` will be loaded.
-
-Note: the annotations have the same shape (H, W) as the images, and the pixel values fall in the range `[0, num_classes - 1]`.
+**Note:** the annotations have the same shape (H, W) as the images, and the pixel values fall in the range `[0, num_classes - 1]`.
You may also use the `'P'` mode of [pillow](https://pillow.readthedocs.io/en/stable/handbook/concepts.html#palette) to create annotations that contain color.
## Customize datasets by mixing datasets
MMSegmentation also supports mixing datasets for training.
-Currently it supports concat, repeat and multi-image mix dataset.
+Currently it supports concat, repeat and multi-image mix datasets.
### Repeat dataset
@@ -58,79 +103,29 @@ MMSegmentation also supports mixing datasets for training.
```python
dataset_A_train = dict(
- type='RepeatDataset',
- times=N,
-    dataset=dict( # This is the original config of Dataset_A
- type='Dataset_A',
- ...
- pipeline=train_pipeline
- )
+ type='RepeatDataset',
+ times=N,
+    dataset=dict( # This is the original config of Dataset_A
+ type='Dataset_A',
+ ...
+ pipeline=train_pipeline
)
+)
```
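Conceptually, `RepeatDataset` wraps a dataset and repeats it `times` times by remapping indices, so no data is copied. A simplified pure-python stand-in (not the mmengine implementation):

```python
class RepeatDataset:
    """Minimal sketch: length is times * len(dataset); indices wrap around."""

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times

    def __len__(self):
        return self.times * len(self.dataset)

    def __getitem__(self, idx):
        # Every index maps back into the underlying dataset.
        return self.dataset[idx % len(self.dataset)]

repeated = RepeatDataset(['img_a', 'img_b', 'img_c'], times=2)
```

Here `len(repeated)` is 6, and index 4 visits `'img_b'` a second time.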
### Concatenate dataset
-There are 2 ways to concatenate datasets.
-
-1. If the datasets you want to concatenate are of the same type but have different annotation files,
-   you can concatenate the dataset configs as follows:
-
-   1. You may concatenate two annotation folders `ann_dir`
-
- ```python
- dataset_A_train = dict(
- type='Dataset_A',
- img_dir = 'img_dir',
- ann_dir = ['anno_dir_1', 'anno_dir_2'],
- pipeline=train_pipeline
- )
- ```
-
-   2. You may also concatenate two `split` file lists
-
- ```python
- dataset_A_train = dict(
- type='Dataset_A',
- img_dir = 'img_dir',
- ann_dir = 'anno_dir',
- split = ['split_1.txt', 'split_2.txt'],
- pipeline=train_pipeline
- )
- ```
+If you want to concatenate different datasets, you can concatenate the dataset configs as follows.
-   3. You may concatenate both the `ann_dir` folders and the `split` file lists
-
- ```python
- dataset_A_train = dict(
- type='Dataset_A',
- img_dir = 'img_dir',
- ann_dir = ['anno_dir_1', 'anno_dir_2'],
- split = ['split_1.txt', 'split_2.txt'],
- pipeline=train_pipeline
- )
- ```
-
-   In this case, `ann_dir_1` and `ann_dir_2` correspond to `split_1.txt` and `split_2.txt`, respectively
-
-2. If you want to concatenate different datasets, you can concatenate the dataset configs as follows:
-
- ```python
- dataset_A_train = dict()
- dataset_B_train = dict()
-
- data = dict(
- imgs_per_gpu=2,
- workers_per_gpu=2,
- train = [
- dataset_A_train,
- dataset_B_train
- ],
- val = dataset_A_val,
- test = dataset_A_test
- )
- ```
+```python
+dataset_A_train = dict()
+dataset_B_train = dict()
+concatenate_dataset = dict(
+ type='ConcatDataset',
+ datasets=[dataset_A_train, dataset_B_train])
+```
-A more complex example is as follows: repeat `Dataset_A` and `Dataset_B` N and M times, respectively, and then concatenate the repeated datasets
+Below is a more complex example that repeats `Dataset_A` and `Dataset_B` N and M times, respectively, and then concatenates the repeated datasets.
```python
dataset_A_train = dict(
@@ -159,41 +154,36 @@ dataset_B_train = dict(
pipeline=train_pipeline
)
)
-data = dict(
- imgs_per_gpu=2,
- workers_per_gpu=2,
- train = [
- dataset_A_train,
- dataset_B_train
- ],
- val = dataset_A_val,
- test = dataset_A_test
-)
+train_dataloader = dict(
+ dataset=dict(
+ type='ConcatDataset',
+ datasets=[dataset_A_train, dataset_B_train]))
+
+val_dataloader = dict(dataset=dataset_A_val)
+test_dataloader = dict(dataset=dataset_A_test)
```
+You may refer to the mmengine base dataset [tutorial](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/basedataset.html) for more details
+
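The combined behaviour can be sketched in plain python: repeating `Dataset_A` N times and `Dataset_B` M times and then concatenating them yields a dataset of length N * len(A) + M * len(B). The wrapper below is a simplified stand-in for mmengine's `ConcatDataset`, not its real implementation:

```python
class ConcatDataset:
    """Minimal sketch: chain several datasets behind one index space."""

    def __init__(self, datasets):
        self.datasets = datasets

    def __len__(self):
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        # Walk the datasets until the index falls inside one of them.
        for d in self.datasets:
            if idx < len(d):
                return d[idx]
            idx -= len(d)
        raise IndexError(idx)

dataset_a = ['a0', 'a1']        # len(A) = 2, repeated N = 3 times
dataset_b = ['b0', 'b1', 'b2']  # len(B) = 3, repeated M = 2 times
train = ConcatDataset([dataset_a * 3, dataset_b * 2])
```

The concatenated dataset has 3 * 2 + 2 * 3 = 12 samples, with all repeats of A before all repeats of B.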
### Multi-image mix dataset
-We use `MultiImageMixDataset` as a wrapper to mix images from multiple datasets.
-`MultiImageMixDataset` can be used by multi-image mix data augmentations such as mosaic and mixup.
+We use `MultiImageMixDataset` as a wrapper to mix images from multiple datasets.
+`MultiImageMixDataset` can be used with multi-image mix data augmentations such as mosaic and mixup.
-An example of using `MultiImageMixDataset` with `Mosaic` data augmentation:
+An example of using `MultiImageMixDataset` with the `Mosaic` data augmentation:
```python
train_pipeline = [
dict(type='RandomMosaic', prob=1),
dict(type='Resize', img_scale=(1024, 512), keep_ratio=True),
dict(type='RandomFlip', prob=0.5),
- dict(type='Normalize', **img_norm_cfg),
- dict(type='DefaultFormatBundle'),
- dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+ dict(type='PackSegInputs')
]
train_dataset = dict(
type='MultiImageMixDataset',
dataset=dict(
- classes=classes,
- palette=palette,
type=dataset_type,
reduce_zero_label=False,
img_dir=data_root + "images/train",
diff --git a/docs/zh_cn/advanced_guides/add_metrics.md b/docs/zh_cn/advanced_guides/add_metrics.md
index 3a371e357e8..0637b447284 100644
--- a/docs/zh_cn/advanced_guides/add_metrics.md
+++ b/docs/zh_cn/advanced_guides/add_metrics.md
@@ -1 +1,81 @@
-# Add New Metrics (to be updated)
+# Add New Metrics
+
+## Develop with the source code of MMSegmentation
+
+Here we take `CustomMetric` as an example to show how to develop a new evaluation metric.
+
+1. Create a new file `mmseg/evaluation/metrics/custom_metric.py`.
+
+ ```python
+ from typing import List, Sequence
+
+ from mmengine.evaluator import BaseMetric
+
+ from mmseg.registry import METRICS
+
+
+ @METRICS.register_module()
+ class CustomMetric(BaseMetric):
+
+ def __init__(self, arg1, arg2):
+ """
+ The metric first processes each batch of data_samples and predictions,
+ and appends the processed results to the results list. Then it
+ collects all results together from all ranks if distributed training
+ is used. Finally, it computes the metrics of the entire dataset.
+ """
+
+ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
+ pass
+
+ def compute_metrics(self, results: list) -> dict:
+ pass
+
+ def evaluate(self, size: int) -> dict:
+ pass
+ ```
+
+   In the above example, `CustomMetric` is a subclass of `BaseMetric`. It has three methods: `process`, `compute_metrics` and `evaluate`.
+
+   - `process()` processes one batch of data samples and predictions. The processed results need to be stored explicitly in `self.results`, which will be used to compute the metrics after all the data samples are processed. Please refer to the [MMEngine documentation](https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/design/evaluation.md) for more details.
+
+   - `compute_metrics()` computes the metrics from the processed results.
+
+   - `evaluate()` is an interface to compute the metrics and return the results. It will be called by `ValLoop` or `TestLoop` in the `Runner`. In most cases, you don't need to override this method, but you can override it if you want to do some extra work.
+
+   **Note:** You can find the process by which the `Runner` calls the `evaluate()` method [here](https://github.com/open-mmlab/mmengine/blob/main/mmengine/runner/loops.py#L366). The `Runner` is the executor of the training and testing process; you can find more details about it in the [engine document](./engine.md).
+
+2. Import the new metric in `mmseg/evaluation/metrics/__init__.py`.
+
+ ```python
+ from .custom_metric import CustomMetric
+ __all__ = ['CustomMetric', ...]
+ ```
+
+3. Set the new metric in the config file
+
+ ```python
+ val_evaluator = dict(type='CustomMetric', arg1=xxx, arg2=xxx)
+ test_evaluator = dict(type='CustomMetric', arg1=xxx, arg2=xxx)
+ ```
+
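To make the skeleton above concrete, here is a pure-python sketch of the accumulate-then-compute pattern that `BaseMetric` subclasses follow: `process()` appends per-batch results to `self.results`, and `compute_metrics()` aggregates them. The pixel-accuracy metric and its data layout are illustrative assumptions, shown without the mmengine machinery:

```python
class PixelAccuracyMetric:
    """Illustrative stand-in for a BaseMetric subclass (not mmseg code)."""

    def __init__(self):
        self.results = []

    def process(self, data_batch, data_samples):
        # data_samples: per-image dicts with flat `pred` / `gt` label lists
        # (a hypothetical layout, chosen only for this sketch).
        for sample in data_samples:
            correct = sum(p == g for p, g in zip(sample['pred'], sample['gt']))
            self.results.append((correct, len(sample['gt'])))

    def compute_metrics(self, results):
        # Aggregate the per-batch counts collected by process().
        correct = sum(c for c, _ in results)
        total = sum(t for _, t in results)
        return {'pixel_accuracy': correct / total}

metric = PixelAccuracyMetric()
metric.process(None, [dict(pred=[0, 1, 1], gt=[0, 1, 0])])
metric.process(None, [dict(pred=[1, 1], gt=[1, 1])])
result = metric.compute_metrics(metric.results)
```

In the real class, `evaluate()` (inherited from `BaseMetric`) additionally gathers `self.results` across ranks before calling `compute_metrics()`.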
+## Develop with the released version of MMSegmentation
+
+The above example shows how to develop a new metric with the source code of MMSegmentation. If you want to develop a new metric with a released version of MMSegmentation, you can follow the steps below.
+
+1. Create a new file `/Path/to/metrics/custom_metric.py` and implement the `process`, `compute_metrics` and `evaluate` methods; the `evaluate` method is optional.
+
+2. Import the new metric in your code or config file.
+
+ ```python
+ from path.to.metrics import CustomMetric
+ ```
+
+   or
+
+ ```python
+ custom_imports = dict(imports=['/Path/to/metrics'], allow_failed_imports=False)
+
+ val_evaluator = dict(type='CustomMetric', arg1=xxx, arg2=xxx)
+ test_evaluator = dict(type='CustomMetric', arg1=xxx, arg2=xxx)
+ ```
diff --git a/docs/zh_cn/advanced_guides/add_models.md b/docs/zh_cn/advanced_guides/add_models.md
index 2f0a5af0d18..e05c07c8bac 100644
--- a/docs/zh_cn/advanced_guides/add_models.md
+++ b/docs/zh_cn/advanced_guides/add_models.md
@@ -166,7 +166,7 @@ loss_decode=dict(type='MyLoss', loss_weight=1.0))
### Add a new data preprocessor
-In MMSegmentation 1.x versions, we use [SegDataPreProcessor](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/data_preprocessor.py#L13) to copy data to the target device and preprocess the data into the default model input format. Here we will show how to develop a new data preprocessor.
+In MMSegmentation 1.x versions, we use [SegDataPreProcessor](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/data_preprocessor.py#L13) to copy data to the target device and preprocess the data into the default model input format. Here we will show how to develop a new data preprocessor.
1. Create a new file `mmseg/models/my_datapreprocessor.py`.
diff --git a/docs/zh_cn/advanced_guides/contribute_dataset.md b/docs/zh_cn/advanced_guides/contribute_dataset.md
new file mode 100644
index 00000000000..4222de32a6c
--- /dev/null
+++ b/docs/zh_cn/advanced_guides/contribute_dataset.md
@@ -0,0 +1,461 @@
+# Contribute a standard-format dataset in mmsegmentation projects
+
+- Before starting your contribution, please read the [OpenMMLab contribution guide](https://mmcv.readthedocs.io/zh_CN/latest/community/contributing.html) to learn the code contribution workflow of OpenMMLab repositories in detail.
+- This tutorial takes the [Gaofen Image Dataset (GID)](https://www.sciencedirect.com/science/article/pii/S0034425719303414), a remote-sensing semantic segmentation dataset captured by the Gaofen-2 satellite, as an example to demonstrate the dataset contribution workflow in mmsegmentation.
+
+## Step 1: Set up the environment required for mmsegmentation development
+
+- For installing the required development environment, please refer to the [Chinese get_started guide](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/zh_cn/get_started.md) or the [English get_started guide](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/get_started.md).
+
+- If you have already installed the latest versions of pytorch, mmcv and mmengine, you can skip Step 1 and go to [Step 2](#step-2-preparations-before-contributing-code).
+
+- **Note:** There is no need to install mmsegmentation itself here; only pytorch, mmcv, mmengine and the other packages required for developing mmsegmentation are needed.
+
+**Create a new virtual environment (skip if you already have a suitable development environment)**
+
+- Download and install Miniconda from the [official website](https://docs.conda.io/en/latest/miniconda.html)
+- Create and activate a conda environment
+
+```shell
+conda create --name openmmlab python=3.8 -y
+conda activate openmmlab
+```
+
+**Install pytorch (skip if pytorch is already installed in the environment)**
+
+- Install **PyTorch** following the [official instructions](https://pytorch.org/get-started/locally/)
+
+**Install mmcv and mmengine with mim**
+
+- Install [MMCV](https://github.com/open-mmlab/mmcv) with [MIM](https://github.com/open-mmlab/mim)
+
+```shell
+pip install -U openmim
+mim install mmengine
+mim install "mmcv>=2.0.0"
+```
+
+## Step 2: Preparations before contributing code
+
+### 2.1 Fork the mmsegmentation repository
+
+- Open the [official mmsegmentation repository](https://github.com/open-mmlab/mmsegmentation/tree/main) in your browser.
+- Log in to your GitHub account; all the following steps require you to be logged in to GitHub.
+- Fork the mmsegmentation repository
+  ![image](https://user-images.githubusercontent.com/50650583/233825567-b8bf273c-38f5-4487-b4c6-75ede1e283ee.png)
+- After forking, the mmsegmentation repository will appear in your personal repositories.
+
+### 2.2 Git clone mmsegmentation in your code editor
+
+Here we take VSCODE as an example
+
+- Open VSCODE, create a new terminal window and activate the virtual environment you installed in [Step 1](#step-1-set-up-the-environment-required-for-mmsegmentation-development).
+- Find your forked mmsegmentation repository in your personal GitHub repositories and copy its link.
+  ![image](https://github.com/AI-Tianlong/OpenMMLabCamp/assets/50650583/92ad555b-c5b2-4a7f-a800-ebee1e405ab6)
+- Run the command in the terminal
+  ```bash
+  git clone {the copied link of your personal repository}
+  ```
+  ![image](https://github.com/AI-Tianlong/OpenMMLabCamp/assets/50650583/23ba2636-e66f-4ea5-9077-9dd6b69deb1d)
+  **Note:** If the following message appears, please add an [SSH key](https://docs.github.com/en/authentication/connecting-to-github-with-ssh/generating-a-new-ssh-key-and-adding-it-to-the-ssh-agent) to GitHub
+  ![image](https://github.com/AI-Tianlong/OpenMMLabCamp/assets/50650583/6fcab213-0739-483c-b345-c59656027377)
+- Enter the mmsegmentation directory (all subsequent operations are performed in the mmsegmentation directory).
+  ```bash
+  cd mmsegmentation
+  ```
+- Run the following command in the terminal to add the official repository as the upstream repository.
+  ```bash
+  git remote add upstream git@github.com:open-mmlab/mmsegmentation.git
+  ```
+- Use the following command to check whether the remote has been added successfully.
+ ```bash
+ git remote -v
+ ```
+ ![image](https://github.com/AI-Tianlong/OpenMMLabCamp/assets/50650583/beec7e5e-2b00-4e49-ab38-f0c79e346594)
+
+### 2.3 Switch to the mmsegmentation directory and install mmsegmentation from source
+
+Run `pip install -v -e .` in the `mmsegmentation` directory to install the mmsegmentation library from source.
+After the installation completes, mmsegmentation is installed in editable (development) mode.
+
+### 2.4 Switch the branch to dev-1.x
+
+As you can see on the [official mmsegmentation repository](https://github.com/open-mmlab/mmsegmentation/tree/main), the repository has many branches. The default branch `main` is the stable release, and `dev-1.x` is the development branch for contributors. `dev-1.x` is the branch where contributors submit their ideas and PRs, and its content is periodically merged into the `main` branch.
+![image](https://user-images.githubusercontent.com/50650583/233826225-f4b7299d-de23-47db-900d-dfb01ba0efc3.png)
+
+Back in VSCODE, run the command in the terminal
+
+```bash
+git checkout dev-1.x
+```
+
+### 2.5 Create your own new branch
+
+Based on the `dev-1.x` branch, create your own branch with the following command.
+
+```bash
+# git checkout -b your_GitHubID/name_of_the_feature_your_branch_implements
+# git checkout -b AI-Tianlong/support_GID_dataset
+git checkout -b {your_GitHubID/name_of_the_feature_your_branch_implements}
+```
+
+### 2.6 Configure pre-commit
+
+OpenMMLab repositories have high requirements for code quality, and every submitted PR must pass the code format check. For the detailed configuration, see [Configure pre-commit](https://mmcv.readthedocs.io/zh_CN/latest/community/contributing.html#pre-commit).
+
+## Step 3: Contribute your code under `mmsegmentation/projects`
+
+**First, analyze the GID dataset**
+
+Here we take contributing the GID remote-sensing semantic segmentation dataset as an example. The GID dataset was created from optical remote-sensing images captured by China's independently developed Gaofen-2 satellite; after image preprocessing, it provides 150 RGB three-channel remote-sensing images of 6800x7200 pixels. Two annotation sets with different numbers of classes are provided: RGB labels with 5 valid object classes, and RGB labels with 15 valid object classes. This tutorial walks through the dataset contribution workflow for the 5-class labels.
+
+The 5 valid GID labels are (mask label value - label name - RGB label value): 0-background-\[0,0,0\], 1-built-up-\[255,0,0\], 2-farmland-\[0,255,0\], 3-forest-\[0,255,255\], 4-meadow-\[255,255,0\] and 5-water-\[0,0,255\]. In semantic segmentation, the label is a single-channel image of the same size as the original image, and each pixel value in the label is the class of the object at the corresponding pixel of the sample image. GID provides colored RGB three-channel labels, so the RGB labels have to be converted into mask labels for model training. Also, the 6800x7200 image size is rather too large for neural-network training, so each image is cropped into non-overlapping patches (256x256 by default in the conversion script below) for training.
+
+
+### 3.1 Create a new project folder under `mmsegmentation/projects`
+
+Create the folder `gid_dataset` under `mmsegmentation/projects`
+![image](https://user-images.githubusercontent.com/50650583/233829687-8f2b6600-bc9d-48ff-a865-d462af54d55a.png)
+
+### 3.2 Contribute your dataset code
+
+To make it smoother to eventually move the code you contribute in projects into the core library (which has higher code-quality requirements), it is strongly recommended to organize your dataset files following the directory structure of the core library.
+Four files are essential for a dataset:
+
+- **1** `mmseg/datasets/gid.py`, which defines the dataset suffixes, CLASSES, PALETTE, reduce_zero_label, etc.
+- **2** `configs/_base_/gid.py`, the GID dataset config file, which defines the dataset's `dataset_type` (the dataset type, i.e. the class name registered in `mmseg/datasets/gid.py`), `data_root` (the root directory of the dataset; it is recommended to place the dataset under `mmsegmentation/data` via a symbolic link), `train_pipeline` (the training data pipeline), `test_pipeline` (the data pipeline for testing and validation), `img_ratios` (the scales for multi-scale prediction), `tta_pipeline` (test-time augmentation), `train_dataloader` (the dataloader of the training set), `val_dataloader` (the dataloader of the validation set), `test_dataloader` (the dataloader of the test set), `val_evaluator` (the evaluator of the validation set) and `test_evaluator` (the evaluator of the test set).
+- **3** A model training config that uses the GID dataset
+  This is optional but strongly recommended. In the core library, a contributed dataset needs to match the accuracy reported in the reference paper so that your code can later be merged into the core library. If you have sufficient compute, it is best to provide the results of the corresponding model config verified on your contributed dataset, together with the weight files, and to write a fairly detailed README.md. [Example reference results](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/deeplabv3plus#mapillary-vistas-v12)
+  ![image](https://user-images.githubusercontent.com/50650583/233877682-eabe8723-bce9-40e4-a303-08c8385cb6b5.png)
+- **4** Write `docs/zh_cn/user_guides/2_dataset_prepare.md` to add an introduction to your dataset, including but not limited to how to download the dataset, the dataset directory structure, dataset generation, and other necessary descriptions and commands, to help users prepare the dataset faster.
+
+### 3.3 Contribute `tools/dataset_converters/gid.py`
+
+The GID dataset consists of uncropped 6800x7200 images without a train/val/test split, and its labels are RGB color labels that have to be converted into single-channel mask labels. To make training convenient, we first crop the GID images, convert the labels and split the dataset to build the format supported by mmsegmentation.
+
+```python
+# tools/dataset_converters/gid.py
+import argparse
+import glob
+import math
+import os
+import os.path as osp
+from PIL import Image
+
+import mmcv
+import numpy as np
+from mmengine.utils import ProgressBar, mkdir_or_exist
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description='Convert GID dataset to mmsegmentation format')
+ parser.add_argument('dataset_img_path', help='GID images folder path')
+ parser.add_argument('dataset_label_path', help='GID labels folder path')
+ parser.add_argument('--tmp_dir', help='path of the temporary directory')
+ parser.add_argument('-o', '--out_dir', help='output path', default='data/gid')
+ parser.add_argument(
+ '--clip_size',
+ type=int,
+ help='clipped size of image after preparation',
+ default=256)
+ parser.add_argument(
+ '--stride_size',
+ type=int,
+ help='stride of clipping original images',
+ default=256)
+ args = parser.parse_args()
+ return args
+
+GID_COLORMAP = dict(
+    Background=(0, 0, 0),  # 0-background-black
+    Building=(255, 0, 0),  # 1-built-up-red
+    Farmland=(0, 255, 0),  # 2-farmland-green
+    Forest=(0, 255, 255),  # 3-forest-cyan
+    Meadow=(255, 255, 0),  # 4-meadow-yellow
+    Water=(0, 0, 255)  # 5-water-blue
+)
+palette = list(GID_COLORMAP.values())
+classes = list(GID_COLORMAP.keys())
+
+# Build a lookup table that maps each RGB value to a class index
+def colormap2label(palette):
+ colormap2label_list = np.zeros(256**3, dtype = np.longlong)
+ for i, colormap in enumerate(palette):
+ colormap2label_list[(colormap[0] * 256 + colormap[1])*256+colormap[2]] = i
+ return colormap2label_list
+
+# Given the lookup table and an RGB label image, generate the mask image
+def label_indices(RGB_label, colormap2label_list):
+    RGB_label = RGB_label.astype('int32')
+    idx = (RGB_label[:, :, 0] * 256 + RGB_label[:, :, 1]) * 256 + RGB_label[:, :, 2]
+    return colormap2label_list[idx]
+
+def RGB2mask(RGB_label, colormap2label_list):
+    mask_label = label_indices(RGB_label, colormap2label_list)
+    return mask_label
+
+colormap2label_list = colormap2label(palette)
+
+def clip_big_image(image_path, clip_save_dir, args, to_label=False):
+ """
+ Original image of GID dataset is very large, thus pre-processing
+ of them is adopted. Given fixed clip size and stride size to generate
+ clipped image, the intersection of width and height is determined.
+ For example, given one 6800 x 7200 original image, the clip size is
+ 256 and stride size is 256, thus it would generate 29 x 27 = 783 images
+ whose size are all 256 x 256.
+
+ """
+
+ image = mmcv.imread(image_path, channel_order='rgb')
+ # image = mmcv.bgr2gray(image)
+
+ h, w, c = image.shape
+ clip_size = args.clip_size
+ stride_size = args.stride_size
+
+ num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
+ (h - clip_size) /
+ stride_size) * stride_size + clip_size >= h else math.ceil(
+ (h - clip_size) / stride_size) + 1
+ num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
+ (w - clip_size) /
+ stride_size) * stride_size + clip_size >= w else math.ceil(
+ (w - clip_size) / stride_size) + 1
+
+ x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
+ xmin = x * clip_size
+ ymin = y * clip_size
+
+ xmin = xmin.ravel()
+ ymin = ymin.ravel()
+ xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
+ np.zeros_like(xmin))
+ ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
+ np.zeros_like(ymin))
+ boxes = np.stack([
+ xmin + xmin_offset, ymin + ymin_offset,
+ np.minimum(xmin + clip_size, w),
+ np.minimum(ymin + clip_size, h)
+ ], axis=1)
+
+ if to_label:
+        image = RGB2mask(image, colormap2label_list)  # convert the RGB label to a mask
+
+ for count, box in enumerate(boxes):
+ start_x, start_y, end_x, end_y = box
+ clipped_image = image[start_y:end_y,
+ start_x:end_x] if to_label else image[
+ start_y:end_y, start_x:end_x, :]
+ img_name = osp.basename(image_path).replace('.tif', '')
+ img_name = img_name.replace('_label', '')
+        # every third patch goes to the validation set
+        if count % 3 == 0:
+ mmcv.imwrite(
+ clipped_image.astype(np.uint8),
+ osp.join(
+ clip_save_dir.replace('train', 'val'),
+ f'{img_name}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
+ else:
+ mmcv.imwrite(
+ clipped_image.astype(np.uint8),
+ osp.join(
+ clip_save_dir,
+ f'{img_name}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
+
+def main():
+ args = parse_args()
+
+ """
+ According to this paper: https://ieeexplore.ieee.org/document/9343296/
+ select 15 images contained in GID, , which cover the whole six
+ categories, to generate train set and validation set.
+
+ According to Paper: https://ieeexplore.ieee.org/document/9343296/
+
+ """
+
+ if args.out_dir is None:
+ out_dir = osp.join('data', 'gid')
+ else:
+ out_dir = args.out_dir
+
+ print('Making directories...')
+ mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
+ mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
+ mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
+ mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
+
+ src_path_list = glob.glob(os.path.join(args.dataset_img_path, '*.tif'))
+ print(f'Find {len(src_path_list)} pictures')
+
+ prog_bar = ProgressBar(len(src_path_list))
+
+ dst_img_dir = osp.join(out_dir, 'img_dir', 'train')
+ dst_label_dir = osp.join(out_dir, 'ann_dir', 'train')
+
+ for i, img_path in enumerate(src_path_list):
+ label_path = osp.join(args.dataset_label_path, osp.basename(img_path.replace('.tif', '_label.tif')))
+
+ clip_big_image(img_path, dst_img_dir, args, to_label=False)
+ clip_big_image(label_path, dst_label_dir, args, to_label=True)
+ prog_bar.update()
+
+ print('Done!')
+
+if __name__ == '__main__':
+ main()
+```
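The RGB-to-mask conversion above hashes each colour to `(r * 256 + g) * 256 + b` and looks it up in a precomputed table. Below is a dependency-free sketch of the same idea with a toy two-colour palette (not the GID palette), useful for sanity-checking the hashing:

```python
def colormap2label(palette):
    # Map each hashed RGB value to its class index.
    table = {}
    for i, (r, g, b) in enumerate(palette):
        table[(r * 256 + g) * 256 + b] = i
    return table

def rgb2mask(rgb_rows, table):
    # rgb_rows: nested lists of (r, g, b) tuples standing in for an image.
    return [[table[(r * 256 + g) * 256 + b] for (r, g, b) in row]
            for row in rgb_rows]

palette = [(0, 0, 0), (255, 0, 0)]  # 0: background, 1: building
label = [[(0, 0, 0), (255, 0, 0)],
         [(255, 0, 0), (0, 0, 0)]]
mask = rgb2mask(label, colormap2label(palette))
```

The script above does the same thing vectorized with numpy, using a flat array of size 256^3 instead of a dict.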
+
+### 3.4 Contribute `mmseg/datasets/gid.py`
+
+You can refer to [`projects/mapillary_dataset/mmseg/datasets/mapillary.py`](https://github.com/open-mmlab/mmsegmentation/blob/main/projects/mapillary_dataset/mmseg/datasets/mapillary.py) and modify the corresponding variables on top of it to adapt it to your dataset.
+
+```python
+# mmseg/datasets/gid.py
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmseg.datasets.basesegdataset import BaseSegDataset
+from mmseg.registry import DATASETS
+
+# Register the dataset class
+@DATASETS.register_module()
+class GID_Dataset(BaseSegDataset):
+ """Gaofen Image Dataset (GID)
+
+ Dataset paper link:
+ https://www.sciencedirect.com/science/article/pii/S0034425719303414
+ https://x-ytong.github.io/project/GID.html
+
+ GID 6 classes: background(others), built-up, farmland, forest, meadow, water
+
+    In this example, 10 images from the GID dataset are selected as the
+    training set, and 5 images as the validation set.
+ The selected images are listed as follows:
+
+ GF2_PMS1__L1A0000647767-MSS1
+ GF2_PMS1__L1A0001064454-MSS1
+ GF2_PMS1__L1A0001348919-MSS1
+ GF2_PMS1__L1A0001680851-MSS1
+ GF2_PMS1__L1A0001680853-MSS1
+ GF2_PMS1__L1A0001680857-MSS1
+ GF2_PMS1__L1A0001757429-MSS1
+ GF2_PMS2__L1A0000607681-MSS2
+ GF2_PMS2__L1A0000635115-MSS2
+ GF2_PMS2__L1A0000658637-MSS2
+ GF2_PMS2__L1A0001206072-MSS2
+ GF2_PMS2__L1A0001471436-MSS2
+ GF2_PMS2__L1A0001642620-MSS2
+ GF2_PMS2__L1A0001787089-MSS2
+ GF2_PMS2__L1A0001838560-MSS2
+
+    The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is
+    fixed to '.png' for GID, since the conversion script saves png patches.
+ """
+ METAINFO = dict(
+ classes=('Others', 'Built-up', 'Farmland', 'Forest',
+ 'Meadow', 'Water'),
+
+ palette=[[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 255, 255],
+ [255, 255, 0], [0, 0, 255]])
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=None,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
+```
+
+### 3.5 Contribute the training config file that uses GID
+
+```python
+_base_ = [
+ '../../../configs/_base_/models/deeplabv3plus_r50-d8.py',
+ './_base_/datasets/gid.py',
+ '../../../configs/_base_/default_runtime.py',
+ '../../../configs/_base_/schedules/schedule_240k.py'
+]
+custom_imports = dict(
+ imports=['projects.gid_dataset.mmseg.datasets.gid'])
+
+crop_size = (256, 256)
+data_preprocessor = dict(size=crop_size)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ pretrained='open-mmlab://resnet101_v1c',
+ backbone=dict(depth=101),
+ decode_head=dict(num_classes=6),
+ auxiliary_head=dict(num_classes=6))
+
+```
+
+### 3.6 Write `docs/zh_cn/user_guides/2_dataset_prepare.md`
+
+**Gaofen Image Dataset (GID)**
+
+- The GID dataset can be downloaded [here](https://x-ytong.github.io/project/Five-Billion-Pixels.html).
+- The GID dataset contains 150 large images of 6800x7200 pixels with RGB labels.
+- Here, 15 images covering all six classes are selected to generate the training set and validation set. The names of the selected images are as follows:
+
+```None
+ GF2_PMS1__L1A0000647767-MSS1
+ GF2_PMS1__L1A0001064454-MSS1
+ GF2_PMS1__L1A0001348919-MSS1
+ GF2_PMS1__L1A0001680851-MSS1
+ GF2_PMS1__L1A0001680853-MSS1
+ GF2_PMS1__L1A0001680857-MSS1
+ GF2_PMS1__L1A0001757429-MSS1
+ GF2_PMS2__L1A0000607681-MSS2
+ GF2_PMS2__L1A0000635115-MSS2
+ GF2_PMS2__L1A0000658637-MSS2
+ GF2_PMS2__L1A0001206072-MSS2
+ GF2_PMS2__L1A0001471436-MSS2
+ GF2_PMS2__L1A0001642620-MSS2
+ GF2_PMS2__L1A0001787089-MSS2
+ GF2_PMS2__L1A0001838560-MSS2
+```
+
+Run the following command to crop the images and convert the labels; replace the paths with where you store the 15 images and labels.
+
+```
+python projects/gid_dataset/tools/dataset_converters/gid.py [path of the 15 images] [path of the 15 labels]
+```
+
+The GID data structure after cropping is as follows:
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│ ├── gid
+│ │ ├── ann_dir
+| │ │ │ ├── train
+| │ │ │ ├── val
+│ │ ├── img_dir
+| │ │ │ ├── train
+| │ │ │ ├── val
+
+```
+
+### 3.7 Pass the `pre-commit` check with your contributed code and docs
+
+Use the commands
+
+```bash
+git add .
+git commit -m "your commit message"
+git push
+```
+
+### 3.8 Submit a PR to mmsegmentation on GitHub
+
+For the detailed steps, see the [OpenMMLab contribution guide](https://mmcv.readthedocs.io/zh_CN/latest/community/contributing.html)
diff --git a/docs/zh_cn/advanced_guides/data_flow.md b/docs/zh_cn/advanced_guides/data_flow.md
index 0716d36d1b4..20dbe07e75d 100644
--- a/docs/zh_cn/advanced_guides/data_flow.md
+++ b/docs/zh_cn/advanced_guides/data_flow.md
@@ -16,7 +16,7 @@ val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
```
-In the figure above, the red line indicates [train_step](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/advanced_guides/models.md#train_step) ***([Chinese link to be updated](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/models.md#train_step))***. In each training iteration, the dataloader loads images from storage and transfers them to the data preprocessor, which moves the images to the specific device and stacks the data into a batch; the model then takes the batched data as input, and finally the model output is sent to the optimizer. The blue line indicates [val_step](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/advanced_guides/models.md#val_step) and [test_step](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/advanced_guides/models.md#test_step) ***([Chinese link to be updated](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/models.md#test_step))***. The data flow of these two processes is similar to `train_step` except for the model output. Since the model parameters are frozen during evaluation, the model output is passed to the [Evaluator](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/advanced_guides/evaluation.md#ioumetric) ***([Chinese link to be updated](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/evaluation.md#ioumetric))***
+In the figure above, the red line indicates [train_step](./models.md#train_step). In each training iteration, the dataloader loads images from storage and transfers them to the data preprocessor, which moves the images to the specific device and stacks the data into a batch; the model then takes the batched data as input, and finally the model output is sent to the optimizer. The blue line indicates [val_step](./models.md#val_step) and [test_step](./models.md#test_step). The data flow of these two processes is similar to `train_step` except for the model output. Since the model parameters are frozen during evaluation, the model output is passed to the [Evaluator](./evaluation.md#ioumetric)
to compute the metrics.
## Data flow conventions in MMSegmentation
@@ -28,7 +28,7 @@ test_cfg = dict(type='TestLoop')
The DataLoader is an essential component in the training and testing pipelines of MMEngine.
Conceptually, it derives from and keeps consistent with [PyTorch](https://pytorch.org/). The DataLoader loads data from the file system, and the original data goes through the data preparation pipeline before being sent to the data preprocessor.
-MMSegmentation defines the default data format in [PackSegInputs](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/datasets/transforms/formatting.py#L12), which is the last component of `train_pipeline` and `test_pipeline`. Please refer to the [data transform documentation](https://mmsegmentation.readthedocs.io/en/dev-1.x/advanced_guides/transforms.html) for more information about the data transform `pipeline`. ***([Chinese link to be updated](https://mmsegmentation.readthedocs.io/zh_CN/dev-1.x/advanced_guides/transforms.html))***
+MMSegmentation defines the default data format in [PackSegInputs](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/datasets/transforms/formatting.py#L12), which is the last component of `train_pipeline` and `test_pipeline`. Please refer to the [data transform documentation](./transforms.md) for more information about the data transform `pipeline`.
Without any modification, the return value of PackSegInputs is usually a `dict` containing `inputs` and `data_samples`. The following pseudo-code shows the data types of the dataloader output in mmseg: a batch of data samples fetched from the dataset, which the dataloader packs into a list of dicts. `inputs` is the list of input tensors fed to the model, and `data_samples` contains the meta information of the input images and the corresponding ground truth.
@@ -39,11 +39,11 @@ dict(
)
```
-**Note:** [SegDataSample](https://github.com/open-mmlab/mmsegmentation/blob/1.x/mmseg/structures/seg_data_sample.py) is the data structure interface of MMSegmentation, used to connect different components. `SegDataSample` implements the abstract data element `mmengine.structures.BaseDataElement`; please refer to the [SegDataSample documentation](https://mmsegmentation.readthedocs.io/zh_CN/1.x/advanced_guides/structures.html) and the [data element documentation](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/data_element.html) in [MMEngine](https://github.com/open-mmlab/mmengine) for more information.
+**Note:** [SegDataSample](https://github.com/open-mmlab/mmsegmentation/blob/1.x/mmseg/structures/seg_data_sample.py) is the data structure interface of MMSegmentation, used to connect different components. `SegDataSample` implements the abstract data element `mmengine.structures.BaseDataElement`; please refer to the [SegDataSample documentation](./structures.md) and the [data element documentation](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/data_element.html) in [MMEngine](https://github.com/open-mmlab/mmengine) for more information.
### Data preprocessor to model
-Although the data preprocessor and the model are drawn separately in the [figure above](##数据流概述), the data preprocessor is part of the model, so the data preprocessor chapter can be found in the [model tutorial](https://mmsegmentation.readthedocs.io/en/dev-1.x/advanced_guides/models.html). ***([Chinese link to be updated](https://mmsegmentation.readthedocs.io/zh_CN/dev-1.x/advanced_guides/models.html))***
+Although the data preprocessor and the model are drawn separately in the [figure above](##数据流概述), the data preprocessor is part of the model, so the data preprocessor chapter can be found in the [model tutorial](./models.md).
The return value of the data preprocessor is a dict containing `inputs` and `data_samples`, where `inputs` is the 4D tensor of the batched images and `data_samples` carries some extra meta information added during data preprocessing. When passed to the network, the dict is unpacked into two values. The following pseudo-code shows the return value of the data preprocessor and the input values of the model.
@@ -61,22 +61,22 @@ class Network(BaseSegmentor):
pass
```
-**Note:** The model forward has 3 modes, controlled by the input argument mode; please refer to the [model tutorial](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/advanced_guides/models.md) for more details. ***([Chinese link to be updated](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/models.md))***
+**Note:** The model forward has 3 modes, controlled by the input argument mode; please refer to the [model tutorial](./models.md) for more details.
### Model output
-As mentioned in the [model tutorial](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/advanced_guides/models.md#forward) ***([Chinese link to be updated](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/models.md#forward))***, the 3 kinds of forward propagation have 3 kinds of outputs.
+As mentioned in the [model tutorial](./models.md#forward), the 3 kinds of forward propagation have 3 kinds of outputs.
`train_step` and `test_step` (or `val_step`) correspond to `'loss'` and `'predict'` respectively.
-In `test_step` or `val_step`, the inference results are passed to the `Evaluator`. You may refer to the [evaluation documentation](https://mmsegmentation.readthedocs.io/en/dev-1.x/advanced_guides/evaluation.html) ***([Chinese link to be updated](https://mmsegmentation.readthedocs.io/zh_CN/dev-1.x/advanced_guides/evaluation.html))*** for more information about the `Evaluator`.
+In `test_step` or `val_step`, the inference results are passed to the `Evaluator`. You may refer to the [evaluation documentation](./evaluation.md) for more information about the `Evaluator`.
-After inference, [BaseSegmentor](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/segmentors/base.py#L15) in MMSegmentation performs simple post-processing to pack the inference results. The segmentation logits produced by the neural network, the segmentation mask after the `argmax` operation, and the ground truth (if it exists) are packed into an instance similar to `SegDataSample`. The return value of [postprocess_result](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/segmentors/base.py#L132) is a **`List` of `SegDataSample`**. The figure below shows the key properties of these `SegDataSample` instances.
+After inference, [BaseSegmentor](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/segmentors/base.py#L15) in MMSegmentation performs simple post-processing to pack the inference results. The segmentation logits produced by the neural network, the segmentation mask after the `argmax` operation, and the ground truth (if it exists) are packed into an instance similar to `SegDataSample`. The return value of [postprocess_result](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/segmentors/base.py#L132) is a **`List` of `SegDataSample`**. The figure below shows the key properties of these `SegDataSample` instances.
![SegDataSample](https://user-images.githubusercontent.com/15952744/209912225-ab46a8d9-904a-43cb-8bf1-8bec4938ed29.png)
-Consistent with the data preprocessor, the loss function is also part of the model; it is one of the properties of the [decode head](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/decode_heads/decode_head.py#L142).
+Consistent with the data preprocessor, the loss function is also part of the model; it is one of the properties of the [decode head](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/decode_head.py#L142).
-In MMSegmentation, the [loss_by_feat](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/decode_heads/decode_head.py#L291) method of `decode_head` is the unified interface used to compute the loss.
+In MMSegmentation, the [loss_by_feat](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/decode_head.py#L291) method of `decode_head` is the unified interface used to compute the loss.
Parameters:
@@ -87,4 +87,4 @@ class Network(BaseSegmentor):
- dict\[str, Tensor\]: a dictionary of loss components
-**Note:** `train_step` passes the losses into the OptimWrapper to update the weights in the model; please refer to [train_step](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/advanced_guides/models.md#train_step) for more details. ***([Chinese link to be updated](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/models.md#train_step))***
+**Note:** `train_step` passes the losses into the OptimWrapper to update the weights in the model; please refer to [train_step](./models.md#train_step) for more details.
diff --git a/docs/zh_cn/advanced_guides/datasets.md b/docs/zh_cn/advanced_guides/datasets.md
index 546e97f70d6..b45f2d22bb6 100644
--- a/docs/zh_cn/advanced_guides/datasets.md
+++ b/docs/zh_cn/advanced_guides/datasets.md
@@ -1,10 +1,10 @@
# 数据集
-在 MMSegmentation 算法库中, 所有 Dataset 类的功能有两个: 加载[预处理](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/user_guides/2_dataset_prepare.md) 之后的数据集的信息, 和将数据送入[数据集变换流水线](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/datasets/basesegdataset.py#L141) 中, 进行[数据变换操作](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/transforms.md). 加载的数据集信息包括两类: 元信息 (meta information), 数据集本身的信息, 例如数据集总共的类别, 和它们对应调色盘信息: 数据信息 (data information) 是指每组数据中图片和对应标签的路径. 下文中介绍了 MMSegmentation 1.x 中数据集的常用接口, 和 mmseg 数据集基类中数据信息加载与修改数据集类别的逻辑, 以及数据集与数据变换流水线 (pipeline) 的关系.
+在 MMSegmentation 算法库中, 所有 Dataset 类的功能有两个: 加载[预处理](../user_guides/2_dataset_prepare.md)之后的数据集的信息, 和将数据送入[数据集变换流水线](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/datasets/basesegdataset.py#L141)中进行[数据变换操作](./transforms.md). 加载的数据集信息包括两类: 元信息 (meta information), 即数据集本身的信息, 例如数据集总共的类别和它们对应的调色盘信息; 数据信息 (data information), 是指每组数据中图片和对应标签的路径. 下文中介绍了 MMSegmentation 1.x 中数据集的常用接口, mmseg 数据集基类中数据信息加载与修改数据集类别的逻辑, 以及数据集与数据变换流水线 (pipeline) 的关系.
## 常用接口
-以 Cityscapes 为例, 介绍数据集常用接口. 如需运行以下示例, 请在当前工作目录下的 `data` 目录下载并[预处理](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/user_guides/2_dataset_prepare.md#cityscapes) Cityscapes 数据集.
+以 Cityscapes 为例, 介绍数据集常用接口. 如需运行以下示例, 请在当前工作目录下的 `data` 目录下载并[预处理](../user_guides/2_dataset_prepare.md#cityscapes) Cityscapes 数据集.
实例化 Cityscapes 训练数据集:
@@ -96,7 +96,7 @@ print(dataset.metainfo)
'reduce_zero_label': False}
```
-数据集 `__getitem__` 方法的返回值, 是经过数据增强的样本数据的输出, 同样也是一个字典, 包括两个字段, `'inputs'` 字段是当前样本经过数据增强操作的图像, 类型为 torch.Tensor, `'data_samples'` 字段存放的数据类型是 MMSegmentation 1.x 新添加的数据结构 [`Segdatasample`](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/structures.md), 其中`gt_sem_seg` 字段是经过数据增强的标签数据.
+数据集 `__getitem__` 方法的返回值, 是经过数据增强的样本数据的输出, 同样也是一个字典, 包括两个字段: `'inputs'` 字段是当前样本经过数据增强操作的图像, 类型为 torch.Tensor; `'data_samples'` 字段存放的数据类型是 MMSegmentation 1.x 新添加的数据结构 [`SegDataSample`](./structures.md), 其中 `gt_sem_seg` 字段是经过数据增强的标签数据.
```python
print(dataset[0])
@@ -166,13 +166,13 @@ print(dataset[0])
## BaseSegDataset
-由于 MMSegmentation 中的所有数据集的基本功能均包括(1) 加载[数据集预处理](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/user_guides/2_dataset_prepare.md) 之后的数据信息和 (2) 将数据送入数据变换流水线中进行数据变换, 因此在 MMSegmentation 中将其中的共同接口抽象成 [`BaseSegDataset`](https://mmsegmentation.readthedocs.io/en/dev-1.x/api.html?highlight=BaseSegDataset#mmseg.datasets.BaseSegDataset),它继承自 [MMEngine 的 `BaseDataset`](https://github.com/open-mmlab/mmengine/blob/main/docs/en/advanced_tutorials/basedataset.md), 遵循 OpenMMLab 数据集初始化统一流程, 支持高效的内部数据存储格式, 支持数据集拼接、数据集重复采样等功能.
+由于 MMSegmentation 中的所有数据集的基本功能均包括(1) 加载[数据集预处理](../user_guides/2_dataset_prepare.md) 之后的数据信息和 (2) 将数据送入数据变换流水线中进行数据变换, 因此在 MMSegmentation 中将其中的共同接口抽象成 [`BaseSegDataset`](https://mmsegmentation.readthedocs.io/zh_CN/latest/api.html?highlight=BaseSegDataset#mmseg.datasets.BaseSegDataset),它继承自 [MMEngine 的 `BaseDataset`](https://github.com/open-mmlab/mmengine/blob/main/docs/en/advanced_tutorials/basedataset.md), 遵循 OpenMMLab 数据集初始化统一流程, 支持高效的内部数据存储格式, 支持数据集拼接、数据集重复采样等功能.
在 MMSegmentation BaseSegDataset 中重新定义了**数据信息加载方法**(`load_data_list`), 并新增了 `get_label_map` 方法用来**修改数据集的类别信息**.
### 数据信息加载
-数据信息加载的内容是样本数据的图片路径和标签路径, 具体实现在 MMSegmentation 的 BaseSegDataset 的 [`load_data_list`](https://github.com/open-mmlab/mmsegmentation/blob/163277bfe0fa8fefb63ee5137917fafada1b301c/mmseg/datasets/basesegdataset.py#L231) 中.
-主要有两种获取图片和标签的路径方法, 如果当数据集目录按以下目录结构组织, [`load_data_list`](https://github.com/open-mmlab/mmsegmentation/blob/163277bfe0fa8fefb63ee5137917fafada1b301c/mmseg/datasets/basesegdataset.py#L231)) 会根据数据路径和后缀来解析.
+数据信息加载的内容是样本数据的图片路径和标签路径, 具体实现在 MMSegmentation 的 BaseSegDataset 的 [`load_data_list`](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/datasets/basesegdataset.py#L231) 中.
+主要有两种获取图片和标签的路径方法, 当数据集目录按以下目录结构组织时, [`load_data_list`](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/datasets/basesegdataset.py#L231) 会根据数据路径和后缀来解析.
```
├── data
@@ -322,7 +322,7 @@ print(dataset.metainfo)
'reduce_zero_label': False}
```
-可以看到, 数据集元信息的类别和默认 Cityscapes 不同. 并且, 定义了标签重映射的字段 `label_map` 用来修改每个分割掩膜上的像素的类别索引, 分割标签类别会根据 `label_map`, 将类别重映射, [具体实现](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/datasets/basesegdataset.py#L151):
+可以看到, 数据集元信息的类别和默认 Cityscapes 不同. 并且, 定义了标签重映射的字段 `label_map` 用来修改每个分割掩膜上的像素的类别索引, 分割标签类别会根据 `label_map`, 将类别重映射, [具体实现](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/datasets/basesegdataset.py#L151):
```python
gt_semantic_seg_copy = gt_semantic_seg.copy()
diff --git a/docs/zh_cn/advanced_guides/engine.md b/docs/zh_cn/advanced_guides/engine.md
index a5746fcec89..79b4c8d2296 100644
--- a/docs/zh_cn/advanced_guides/engine.md
+++ b/docs/zh_cn/advanced_guides/engine.md
@@ -61,21 +61,21 @@ OpenMMLab 将模型训练和测试过程抽象为 `Runner`, 插入钩子可以
- 默认钩子 (default hooks)
-它们实现了训练时所必需的功能, 在配置文件中用 `default_hooks` 定义传给 `Runner`, `Runner` 通过 [`register_default_hooks`](https://github.com/open-mmlab/mmengine/blob/090104df21acd05a8aadae5a0d743a7da3314f6f/mmengine/runner/runner.py#L1780) 方法注册.
+它们实现了训练时所必需的功能, 在配置文件中用 `default_hooks` 定义传给 `Runner`, `Runner` 通过 [`register_default_hooks`](https://github.com/open-mmlab/mmengine/blob/main/mmengine/runner/runner.py#L1780) 方法注册.
钩子有对应的优先级, 优先级越高, 越早被执行器调用. 如果优先级一样, 被调用的顺序和钩子注册的顺序一致.
不建议用户修改默认钩子的优先级, 可以参考 [mmengine hooks 文档](https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/hook.md) 了解钩子优先级的定义.
下面是 MMSegmentation 中所用到的默认钩子:
-| 钩子 | 功能 | 优先级 |
-| :-----------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------: | :---------------: |
-| [IterTimerHook](https://github.com/open-mmlab/mmengine/blob/main/mmengine/hooks/iter_timer_hook.py) | 记录 iteration 花费的时间. | NORMAL (50) |
-| [LoggerHook](https://github.com/open-mmlab/mmengine/blob/main/mmengine/hooks/logger_hook.py) | 从 `Runner` 里不同的组件中收集日志记录, 并将其输出到终端, JSON 文件, tensorboard, wandb 等下游. | BELOW_NORMAL (60) |
-| [ParamSchedulerHook](https://github.com/open-mmlab/mmengine/blob/main/mmengine/hooks/param_scheduler_hook.py) | 更新优化器里面的一些超参数, 例如学习率的动量. | LOW (70) |
-| [CheckpointHook](https://github.com/open-mmlab/mmengine/blob/main/mmengine/hooks/checkpoint_hook.py) | 规律性地保存 checkpoint 文件. | VERY_LOW (90) |
-| [DistSamplerSeedHook](https://github.com/open-mmlab/mmengine/blob/main/mmengine/hooks/sampler_seed_hook.py) | 确保分布式采样器 shuffle 是打开的. | NORMAL (50) |
-| [SegVisualizationHook](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/visualization/local_visualizer.py) | 可视化验证和测试过程里的预测结果. | NORMAL (50) |
+| 钩子 | 功能 | 优先级 |
+| :--------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------: | :---------------: |
+| [IterTimerHook](https://github.com/open-mmlab/mmengine/blob/main/mmengine/hooks/iter_timer_hook.py) | 记录 iteration 花费的时间. | NORMAL (50) |
+| [LoggerHook](https://github.com/open-mmlab/mmengine/blob/main/mmengine/hooks/logger_hook.py) | 从 `Runner` 里不同的组件中收集日志记录, 并将其输出到终端, JSON 文件, tensorboard, wandb 等下游. | BELOW_NORMAL (60) |
+| [ParamSchedulerHook](https://github.com/open-mmlab/mmengine/blob/main/mmengine/hooks/param_scheduler_hook.py) | 更新优化器里面的一些超参数, 例如学习率的动量. | LOW (70) |
+| [CheckpointHook](https://github.com/open-mmlab/mmengine/blob/main/mmengine/hooks/checkpoint_hook.py) | 规律性地保存 checkpoint 文件. | VERY_LOW (90) |
+| [DistSamplerSeedHook](https://github.com/open-mmlab/mmengine/blob/main/mmengine/hooks/sampler_seed_hook.py) | 确保分布式采样器 shuffle 是打开的. | NORMAL (50) |
+| [SegVisualizationHook](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/visualization/local_visualizer.py) | 可视化验证和测试过程里的预测结果. | NORMAL (50) |
-MMSegmentation 会在 [`defualt_hooks`](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/_base_/schedules/schedule_160k.py#L19-L25) 里面注册一些训练所必需功能的钩子::
+MMSegmentation 会在 [`default_hooks`](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/_base_/schedules/schedule_160k.py#L19-L25) 里面注册一些训练所必需功能的钩子:
```python
default_hooks = dict(
@@ -94,6 +94,7 @@ default_hooks = dict(
以 `default_hooks` 里面的 `logger` 和 `checkpoint` 为例, 我们来介绍如何修改 `default_hooks` 中默认的钩子.
(1) 模型保存配置
+
`default_hooks` 使用 `checkpoint` 字段来初始化[模型保存钩子 (CheckpointHook)](https://github.com/open-mmlab/mmengine/blob/main/mmengine/hooks/checkpoint_hook.py#L19).
```python
@@ -104,6 +105,7 @@ checkpoint = dict(type='CheckpointHook', interval=1)
更多相关参数的细节可以参考[这里](https://mmengine.readthedocs.io/zh_CN/latest/api/generated/mmengine.hooks.CheckpointHook.html#checkpointhook).
(2) 日志配置
+
`日志钩子 (LoggerHook)` 被用来收集 `执行器 (Runner)` 里面不同组件的日志信息然后写入终端, JSON 文件, tensorboard 和 wandb 等地方.
```python
@@ -126,7 +128,7 @@ visualizer = dict(
- 自定义钩子 (custom hooks)
-自定义钩子在配置通过 `custom_hooks` 定义, `Runner` 通过 [`register_custom_hooks`](https://github.com/open-mmlab/mmengine/blob/090104df21acd05a8aadae5a0d743a7da3314f6f/mmengine/runner/runner.py#L1852) 方法注册.
+自定义钩子在配置文件中通过 `custom_hooks` 定义, `Runner` 通过 [`register_custom_hooks`](https://github.com/open-mmlab/mmengine/blob/main/mmengine/runner/runner.py#L1820) 方法注册.
自定义钩子优先级需要在配置文件里设置, 如果没有设置, 则会被默认设置为 `NORMAL`. 下面是部分 MMEngine 中实现的自定义钩子:
| 钩子 | 用法 |
@@ -145,7 +147,7 @@ custom_hooks = [
### SegVisualizationHook
-MMSegmentation 实现了 [`SegVisualizationHook`](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/engine/hooks/visualization_hook.py#L17), 用来在验证和测试时可视化预测结果.
+MMSegmentation 实现了 [`SegVisualizationHook`](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/engine/hooks/visualization_hook.py#L17), 用来在验证和测试时可视化预测结果.
`SegVisualizationHook` 重写了基类 `Hook` 中的 `_after_iter` 方法, 在验证或测试时, 根据指定的迭代次数间隔调用 `visualizer` 的 `add_datasample` 方法绘制语义分割结果, 具体实现如下:
```python
@@ -181,7 +183,7 @@ class SegVisualizationHook(Hook):
```
-关于可视化更多的细节可以查看[这里](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/user_guides/visualization.md).
+关于可视化更多的细节可以查看[这里](../user_guides/visualization.md).
## 优化器
@@ -234,7 +236,7 @@ optim_wrapper = dict(type='AmpOptimWrapper', optimizer=optimizer)
在模型训练中, 如果想在优化器里为不同参数分别设置优化策略, 例如设置不同的学习率、权重衰减等超参数, 可以通过设置配置文件里 `optim_wrapper` 中的 `paramwise_cfg` 来实现.
-下面的配置文件以 [ViT `optim_wrapper`](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/vit/vit_vit-b16-ln_mln_upernet_8xb2-160k_ade20k-512x512.py#L15-L27) 为例介绍 `paramwise_cfg` 参数使用.
+下面的配置文件以 [ViT `optim_wrapper`](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/vit/vit_vit-b16-ln_mln_upernet_8xb2-160k_ade20k-512x512.py#L15-L27) 为例介绍 `paramwise_cfg` 参数使用.
训练时将 `pos_embed`, `mask_token`, `norm` 模块的 weight decay 参数的系数设置成 0.
即: 在训练时, 这些模块的 weight decay 将被变为 `weight_decay * decay_mult`=0.
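
根据上述描述, 相应的 `paramwise_cfg` 写法大致如下 (示意配置, 其中学习率等具体数值为假设, 请以仓库中的实际配置文件为准):

```python
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=0.00006, weight_decay=0.01),
    paramwise_cfg=dict(
        custom_keys={
            # 这些模块的 weight decay 系数被设置为 0,
            # 即实际的 weight decay = weight_decay * decay_mult = 0
            'pos_embed': dict(decay_mult=0.),
            'mask_token': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))
```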
@@ -257,9 +259,9 @@ optim_wrapper = dict(
### 优化器封装构造器
-默认的优化器封装构造器 [`DefaultOptimWrapperConstructor`](https://github.com/open-mmlab/mmengine/blob/376251961da47ea8254ab808ae5c51e1430f18dc/mmengine/optim/optimizer/default_constructor.py#L19) 根据输入的 `optim_wrapper` 和 `optim_wrapper` 中定义的 `paramwise_cfg` 来构建训练中使用的优化器. 当 [`DefaultOptimWrapperConstructor`](https://github.com/open-mmlab/mmengine/blob/376251961da47ea8254ab808ae5c51e1430f18dc/mmengine/optim/optimizer/default_constructor.py#L19) 功能不能满足需求时, 可以自定义优化器封装构造器来实现超参数的配置.
+默认的优化器封装构造器 [`DefaultOptimWrapperConstructor`](https://github.com/open-mmlab/mmengine/blob/main/mmengine/optim/optimizer/default_constructor.py#L19) 根据输入的 `optim_wrapper` 和 `optim_wrapper` 中定义的 `paramwise_cfg` 来构建训练中使用的优化器. 当 [`DefaultOptimWrapperConstructor`](https://github.com/open-mmlab/mmengine/blob/main/mmengine/optim/optimizer/default_constructor.py#L19) 功能不能满足需求时, 可以自定义优化器封装构造器来实现超参数的配置.
-MMSegmentation 中的实现了 [`LearningRateDecayOptimizerConstructor`](https://github.com/open-mmlab/mmsegmentation/blob/b21df463d47447f33c28d9a4f46136ad64d34a40/mmseg/engine/optimizers/layer_decay_optimizer_constructor.py#L104), 可以对以 ConvNeXt, BEiT 和 MAE 为骨干网络的模型训练时, 骨干网络的模型参数的学习率按照定义的衰减比例(`decay_rate`)逐层递减, 在配置文件中的配置如下:
+MMSegmentation 中实现了 [`LearningRateDecayOptimizerConstructor`](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/engine/optimizers/layer_decay_optimizer_constructor.py#L104), 在训练以 ConvNeXt, BEiT 和 MAE 为骨干网络的模型时, 可以让骨干网络参数的学习率按照定义的衰减比例 (`decay_rate`) 逐层递减, 在配置文件中的配置如下:
```python
optim_wrapper = dict(
diff --git a/docs/zh_cn/advanced_guides/models.md b/docs/zh_cn/advanced_guides/models.md
index 408a57863c4..6eb22517a46 100644
--- a/docs/zh_cn/advanced_guides/models.md
+++ b/docs/zh_cn/advanced_guides/models.md
@@ -30,17 +30,17 @@
## 基本接口
-MMSegmentation 封装 `BaseModel` 并实现了 [BaseSegmenter](https://github.com/open-mmlab/mmsegmentation/blob/1.x/mmseg/models/segmentors/base.py#L15) 类,主要提供 `forward`、`train_step`、`val_step` 和 `test_step` 接口。接下来将详细介绍这些接口。
+MMSegmentation 封装 `BaseModel` 并实现了 [BaseSegmentor](https://github.com/open-mmlab/mmsegmentation/blob/1.x/mmseg/models/segmentors/base.py#L15) 类,主要提供 `forward`、`train_step`、`val_step` 和 `test_step` 接口。接下来将详细介绍这些接口。
### forward
-
+
编码器解码器数据流
-
+
级联编码器解码器数据流
@@ -115,7 +115,7 @@ MMSegmentation 封装 `BaseModel` 并实现了 [BaseSegmenter](https://github.co
-Dict\[str, `torch.Tensor`\]:用于记录日志的张量的`字典`。
-
+
train_step 数据流
@@ -132,7 +132,7 @@ MMSegmentation 封装 `BaseModel` 并实现了 [BaseSegmenter](https://github.co
- `list` - 给定数据的预测结果。
-
+
test_step/val_step 数据流
diff --git a/docs/zh_cn/advanced_guides/training_tricks.md b/docs/zh_cn/advanced_guides/training_tricks.md
index a33c0ea9cfd..e5b8e4dae1e 100644
--- a/docs/zh_cn/advanced_guides/training_tricks.md
+++ b/docs/zh_cn/advanced_guides/training_tricks.md
@@ -1,4 +1,4 @@
-# 训练技巧(待更新)
+# 训练技巧
MMSegmentation 支持如下训练技巧:
@@ -9,17 +9,17 @@ MMSegmentation 支持如下训练技巧:
在 MMSegmentation 里面,您也可以在配置文件里添加如下行来让解码头组件的学习率是主干组件的10倍。
```python
-optimizer=dict(
+optim_wrapper=dict(
paramwise_cfg = dict(
custom_keys={
'head': dict(lr_mult=10.)}))
```
-通过这种修改,任何被分组到 `'head'` 的参数的学习率都将乘以10。您也可以参照 [MMCV 文档](https://mmcv.readthedocs.io/en/latest/api.html#mmcv.runner.DefaultOptimizerConstructor) 获取更详细的信息。
+通过这种修改,任何被分组到 `'head'` 的参数的学习率都将乘以10。您也可以参照 [MMEngine 文档](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/optim_wrapper.html#id6) 获取更详细的信息。
## 在线难样本挖掘 (Online Hard Example Mining, OHEM)
-对于训练时采样,我们在 [这里](https://github.com/open-mmlab/mmsegmentation/tree/master/mmseg/core/seg/sampler) 做了像素采样器。
+MMSegmentation 中实现了像素采样器,训练时可以对特定像素进行采样,例如 OHEM(Online Hard Example Mining,在线难样本挖掘)策略,可以缓解样本不平衡问题。
如下例子是使用 PSPNet 训练并采用 OHEM 策略的配置:
```python
@@ -58,38 +58,17 @@ model=dict(
```python
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
- decode_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
- dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
- auxiliary_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce',loss_weight=1.0),
- dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
- )
+ decode_head=dict(loss_decode=[
+ dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
+ dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
+ ]),
+ auxiliary_head=dict(loss_decode=[
+ dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
+ dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)
+ ]),
+)
```
通过这种方式,可以确定训练过程中损失函数的权重 `loss_weight` 和其在训练日志里的名字 `loss_name`。
-注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在反传的图里。
-
-## 在损失函数中忽略特定的 label 类别
-
-默认设置 `avg_non_ignore=False`, 即每个像素都用来计算损失函数。尽管其中的一些像素属于需要被忽略的类别。
-
-对于训练时损失函数的计算,我们目前支持使用 `avg_non_ignore` 和 `ignore_index` 来忽略 label 特定的类别。 这样损失函数将只在非忽略类别像素中求平均值,会获得更好的表现。这里是[相关 PR](https://github.com/open-mmlab/mmsegmentation/pull/1409)。以 `unet` 使用 `Cityscapes` 数据集训练为例,
-在计算损失函数时,忽略 label 为0的背景,并且仅在不被忽略的像素上计算均值。配置文件写为:
-
-```python
-_base_ = './fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py'
-model = dict(
- decode_head=dict(
- ignore_index=0,
- loss_decode=dict(
- type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True),
- auxiliary_head=dict(
- ignore_index=0,
- loss_decode=dict(
- type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
- ))
-```
-
-通过这种方式,确定训练过程中损失函数的权重 `loss_weight` 和在训练日志里的名字 `loss_name`。
-
-注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在反传的图里。
+注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在计算图里。
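
`loss_` 前缀之所以重要,是因为训练时框架会把损失字典中键名包含 `loss` 的项累加为总损失再进行反向传播,其逻辑可以大致示意如下(示意代码,非 MMEngine 源码):

```python
def parse_losses(losses):
    # 键名中包含 'loss' 的项参与总损失计算,
    # 其余项(例如精度指标 acc_seg)只用于日志记录
    total_loss = sum(value for key, value in losses.items() if 'loss' in key)
    return total_loss

# 假设的一次迭代输出: loss_ce 与 loss_dice 会被累加, acc_seg 不参与
losses = {'loss_ce': 1.0, 'loss_dice': 0.5, 'acc_seg': 0.9}
print(parse_losses(losses))  # 1.5
```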
diff --git a/docs/zh_cn/get_started.md b/docs/zh_cn/get_started.md
index 38e93e9cb4e..ca375f3d9d6 100644
--- a/docs/zh_cn/get_started.md
+++ b/docs/zh_cn/get_started.md
@@ -4,7 +4,7 @@
本教程中,我们将会演示如何使用 PyTorch 准备环境。
-MMSegmentation 可以在 Linux, Windows 和 macOS 系统上运行,并且需要安装 Python 3.6+, CUDA 9.2+ 和 PyTorch 1.5+
+MMSegmentation 可以在 Linux, Windows 和 macOS 系统上运行,并且需要安装 Python 3.7+, CUDA 10.2+ 和 PyTorch 1.8+
**注意:**
如果您已经安装了 PyTorch, 可以跳过该部分,直接到[下一小节](#安装)。否则,您可以按照以下步骤操作。
@@ -43,7 +43,7 @@ conda install pytorch torchvision cpuonly -c pytorch
```shell
pip install -U openmim
mim install mmengine
-mim install "mmcv>=2.0.0rc1"
+mim install "mmcv>=2.0.0"
```
**步骤 1.** 安装 MMSegmentation
@@ -51,7 +51,7 @@ mim install "mmcv>=2.0.0rc1"
情况 a: 如果您想立刻开发和运行 mmsegmentation,您可通过源码安装:
```shell
-git clone -b dev-1.x https://github.com/open-mmlab/mmsegmentation.git
+git clone -b main https://github.com/open-mmlab/mmsegmentation.git
cd mmsegmentation
pip install -v -e .
# '-v' 表示详细模式,更多的输出
@@ -62,7 +62,7 @@ pip install -v -e .
情况 b: 如果您把 mmsegmentation 作为依赖库或者第三方库,可以通过 pip 安装:
```shell
-pip install "mmsegmentation>=1.0.0rc0"
+pip install "mmsegmentation>=1.0.0"
```
### 验证是否安装成功
@@ -87,8 +87,7 @@ python demo/image_demo.py demo/demo.png configs/pspnet/pspnet_r50-d8_4xb2-40k_ci
您将在当前文件夹中看到一个新图像 `result.jpg`,其中所有目标都覆盖了分割 mask
-选项 (b). 如果您通过 pip 安装 mmsegmentation, 打开您的 python
-解释器,复制粘贴以下代码:
+选项 (b). 如果您通过 pip 安装 mmsegmentation, 打开您的 python 解释器,复制粘贴以下代码:
```python
from mmseg.apis import inference_model, init_model, show_result_pyplot
@@ -111,8 +110,8 @@ show_result_pyplot(model, img, result, show=True, out_file='result.jpg', opacity
# 在一段视频上测试并可视化分割结果
video = mmcv.VideoReader('video.mp4')
for frame in video:
- result = inference_segmentor(model, frame)
- show_result_pyplot(model, result, wait_time=1)
+ result = inference_model(model, frame)
+ show_result_pyplot(model, frame, result, wait_time=1)
```
您可以修改上面的代码来测试单个图像或视频,这两个选项都可以验证安装是否成功。
@@ -137,15 +136,15 @@ MMCV 包含 C++ 和 CUDA 扩展,因此与 PyTorch 的依赖方式比较复杂
为了使用 pip 而不是 MIM 安装 MMCV, 请参考 [MMCV 安装指南](https://mmcv.readthedocs.io/en/latest/get_started/installation.html). 这需要手动指定一个基于 PyTorch 版本及其 CUDA 版本的 find-url.
-例如,以下命令可为 PyTorch 1.10.x and CUDA 11.3 安装 mmcv==2.0.0rc1
+例如,以下命令可为 PyTorch 1.10.x and CUDA 11.3 安装 mmcv==2.0.0
```shell
-pip install mmcv==2.0.0rc1 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10/index.html
+pip install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10/index.html
```
#### 在仅有 CPU 的平台安装
-MMSegmentation 可以在仅有 CPU 的版本上运行。在 CPU 模式,您可以训练(需要 MMCV-Lite 版本 >= 2.0.0rc0),测试和推理模型。
+MMSegmentation 可以在仅有 CPU 的版本上运行。在 CPU 模式,您可以训练(需要 MMCV 版本 >= 2.0.0),测试和推理模型。
#### 在 Google Colab 上安装
@@ -156,7 +155,7 @@ MMSegmentation 可以在仅有 CPU 的版本上运行。在 CPU 模式,您可
```shell
!pip3 install openmim
!mim install mmengine
-!mim install "mmcv>=2.0.0rc1"
+!mim install "mmcv>=2.0.0"
```
**Step 2.** 通过源码安装 MMSegmentation
@@ -164,7 +163,7 @@ MMSegmentation 可以在仅有 CPU 的版本上运行。在 CPU 模式,您可
```shell
!git clone https://github.com/open-mmlab/mmsegmentation.git
%cd mmsegmentation
-!git checkout dev-1.x
+!git checkout main
!pip install -e .
```
@@ -173,7 +172,7 @@ MMSegmentation 可以在仅有 CPU 的版本上运行。在 CPU 模式,您可
```python
import mmseg
print(mmseg.__version__)
-# 示例输出: 1.0.0rc0
+# 示例输出: 1.0.0
```
**注意:**
@@ -195,6 +194,16 @@ docker build -t mmsegmentation docker/
docker run --gpus all --shm-size=8g -it -v {DATA_DIR}:/mmsegmentation/data mmsegmentation
```
+### 可选依赖
+
+#### 安装 GDAL
+
+[GDAL](https://gdal.org/) 是一个用于栅格和矢量地理空间数据格式的转换库。安装 GDAL 后,可以读取复杂格式以及非常大的遥感图像。
+
+```shell
+conda install GDAL
+```
+
## 问题解答
如果您在安装过程中遇到了其他问题,请第一时间查阅 [FAQ](notes/faq.md) 文件。如果没有找到答案,您也可以在 GitHub 上提出 [issue](https://github.com/open-mmlab/mmsegmentation/issues/new/choose)
diff --git a/docs/zh_cn/imgs/qq_group_qrcode.jpg b/docs/zh_cn/imgs/qq_group_qrcode.jpg
deleted file mode 100644
index 417347449fe..00000000000
Binary files a/docs/zh_cn/imgs/qq_group_qrcode.jpg and /dev/null differ
diff --git a/docs/zh_cn/imgs/seggroup_qrcode.jpg b/docs/zh_cn/imgs/seggroup_qrcode.jpg
deleted file mode 100644
index 9684582ee1c..00000000000
Binary files a/docs/zh_cn/imgs/seggroup_qrcode.jpg and /dev/null differ
diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst
index e66c178689d..ce5e49977dc 100644
--- a/docs/zh_cn/index.rst
+++ b/docs/zh_cn/index.rst
@@ -23,7 +23,7 @@
:maxdepth: 1
:caption: 迁移指引
- migration.md
+ migration/index.rst
.. toctree::
:caption: 接口文档(英文)
diff --git a/docs/zh_cn/migration/interface.md b/docs/zh_cn/migration/interface.md
index cd16d2cbc68..42f91bf50ac 100644
--- a/docs/zh_cn/migration/interface.md
+++ b/docs/zh_cn/migration/interface.md
@@ -2,7 +2,7 @@
## 引言
-本指南介绍了 MMSegmentation 0.x 和 MMSegmentation1.x 在行为和 API 方面的基本区别,以及这些如何都与您的迁移过程相关。
+本指南介绍了 MMSegmentation 0.x 和 MMSegmentation 1.x 在行为和 API 方面的基本区别,以及这些区别与迁移过程的关系。
## 新的依赖
@@ -12,11 +12,11 @@ MMSegmentation 1.x 依赖于一些新的软件包,您可以准备一个新的
1. [MMEngine](https://github.com/open-mmlab/mmengine):MMEngine 是 OpenMMLab 2.0 架构的核心,我们将许多与计算机视觉无关的内容从 MMCV 拆分到 MMEngine 中。
-2. [MMCV](https://github.com/open-mmlab/mmcv):OpenMMLab 的计算机视觉包。这不是一个新的依赖,但您需要将其升级到 **2.0.0rc1** 以上的版本。
+2. [MMCV](https://github.com/open-mmlab/mmcv):OpenMMLab 的计算机视觉包。这不是一个新的依赖,但您需要将其升级到 **2.0.0** 或以上的版本。
-3. [MMClassification](https://github.com/open-mmlab/mmclassification)(可选):OpenMMLab 的图像分类工具箱和基准。这不是一个新的依赖,但您需要将其升级到 **1.0.0rc0** 以上的版本。
+3. [MMClassification](https://github.com/open-mmlab/mmclassification)(可选):OpenMMLab 的图像分类工具箱和基准。这不是一个新的依赖,但您需要将其升级到 **1.0.0rc6** 版本。
-4. [MMDetection](https://github.com/open-mmlab/mmdetection)(可选): OpenMMLab 的目标检测工具箱和基准。这不是一个新的依赖,但您需要将其升级到 **3.0.0rc0** 以上的版本。
+4. [MMDetection](https://github.com/open-mmlab/mmdetection)(可选): OpenMMLab 的目标检测工具箱和基准。这不是一个新的依赖,但您需要将其升级到 **3.0.0** 或以上的版本。
## 启动训练
@@ -46,7 +46,7 @@ OpenMMLab 2.0 的主要改进是发布了 MMEngine,它为启动训练任务的
--resume='auto' |
-培训练期间是否不评估检查点 |
+训练期间是否不评估检查点 |
--no-validate |
--cfg-options val_cfg=None val_dataloader=None val_evaluator=None |
@@ -102,11 +102,11 @@ OpenMMLab 2.0 的主要改进是发布了 MMEngine,它为启动训练任务的
- `mean`(Sequence,可选):R、G、B 通道的像素平均值。默认为 None。
-- `std`(Sequence,可选):R、G、B通道的像素标准差。默认为 None。
+- `std`(Sequence,可选):R、G、B 通道的像素标准差。默认为 None。
- `size`(Sequence,可选):固定的填充大小。
-- `size_divisor`(int,可选):填充大小的除法因子。
+- `size_divisor`(int,可选):图像填充后的尺寸可以被该值整除。
- `seg_pad_val`(float,可选):分割图的填充值。默认值:255。
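
上述 `mean`/`std` 归一化与 `size_divisor` 填充的行为可以用如下 numpy 代码大致示意(示意实现,非 MMSegmentation 源码,函数名为假设):

```python
import numpy as np

def preprocess(img, mean, std, size_divisor, pad_val=0):
    """对单张 (H, W, 3) 图像做归一化, 并填充到可被 size_divisor 整除的尺寸."""
    img = (img.astype(np.float32) - np.array(mean)) / np.array(std)
    h, w = img.shape[:2]
    pad_h = int(np.ceil(h / size_divisor)) * size_divisor
    pad_w = int(np.ceil(w / size_divisor)) * size_divisor
    padded = np.full((pad_h, pad_w, 3), pad_val, dtype=np.float32)
    padded[:h, :w] = img
    return padded

out = preprocess(np.zeros((30, 45, 3), dtype=np.uint8),
                 mean=[123.675, 116.28, 103.53],
                 std=[58.395, 57.12, 57.375],
                 size_divisor=32)
print(out.shape)  # (32, 64, 3)
```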
@@ -154,14 +154,14 @@ train_dataloader = dict(
batch_size=4,
num_workers=4,
dataset=dict(...),
- sampler=dict(type='DefaultSampler', shuffle=True) # necessary
+ sampler=dict(type='DefaultSampler', shuffle=True) # 必须
)
val_dataloader = dict(
batch_size=4,
num_workers=4,
dataset=dict(...),
- sampler=dict(type='DefaultSampler', shuffle=False) # necessary
+ sampler=dict(type='DefaultSampler', shuffle=False) # 必须
)
test_dataloader = val_dataloader
@@ -417,10 +417,10 @@ runner = dict(type='IterBasedRunner', max_iters=20000)
```python
-# The `val_interval` is the original `evaluation.interval`.
+# `val_interval` 是旧版本的 `evaluation.interval`。
train_cfg = dict(type='IterBasedTrainLoop', max_iters=20000, val_interval=2000)
-val_cfg = dict(type='ValLoop') # Use the default validation loop.
-test_cfg = dict(type='TestLoop') # Use the default test loop.
+val_cfg = dict(type='ValLoop') # 使用默认的验证循环。
+test_cfg = dict(type='TestLoop') # 使用默认的测试循环。
```
|
@@ -438,22 +438,22 @@ test_cfg = dict(type='TestLoop') # Use the default test loop.
```python
default_hooks = dict(
- # record the time of every iterations.
+ # 记录每次迭代的时间。
timer=dict(type='IterTimerHook'),
- # print log every 50 iterations.
+ # 每50次迭代打印一次日志。
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
- # enable the parameter scheduler.
+ # 启用参数调度程序。
param_scheduler=dict(type='ParamSchedulerHook'),
- # save checkpoint every 2000 iterations.
+ # 每2000次迭代保存一次检查点。
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
- # set sampler seed in distributed environment.
+ # 在分布式环境中设置采样器种子。
sampler_seed=dict(type='DistSamplerSeedHook'),
- # validation results visualization.
+ # 验证结果可视化。
visualization=dict(type='SegVisualizationHook'))
```
@@ -505,13 +505,13 @@ visualizer = dict(
```python
env_cfg = dict(
- # whether to enable cudnn benchmark
+ # 是否启用 cudnn_benchmark
cudnn_benchmark=False,
- # set multi process parameters
+ # 设置多进程参数
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- # set distributed parameters
+ # 设置分布式参数
dist_cfg=dict(backend='nccl'),
)
```
diff --git a/docs/zh_cn/migration/package.md b/docs/zh_cn/migration/package.md
index d8d2245bed0..19e5f18c9c4 100644
--- a/docs/zh_cn/migration/package.md
+++ b/docs/zh_cn/migration/package.md
@@ -1,6 +1,6 @@
-#包结构更改
+# 包结构更改
-本节包含您对 MMSeg 0.x 和 1.x 之间的变化感到好奇的内容。
+本节介绍了 MMSeg 0.x 和 1.x 之间您可能关心的变化。
@@ -49,7 +49,7 @@
## `mmseg.ops`
-`ops` 包包含 `encoding` 和 `wrappers`,它们被移到了 `mmseg.models.utils` 中。
+`ops` 包中的 `encoding` 和 `wrappers` 被移到了 `mmseg.models.utils` 中。
## 增加的包
@@ -110,4 +110,4 @@ OpenMMLab 2.0 将 `BaseDataset` 定义为数据集的函数和接口,MMSegment
### `mmseg.models`
-`models` 没有太大变化,只是从以前的 `mmseg.ops` 中添加了 `encoding` 和 `wrappers`
+`models` 没有太大变化,只是从以前的 `mmseg.ops` 添加了 `encoding` 和 `wrappers`
diff --git a/docs/zh_cn/notes/faq.md b/docs/zh_cn/notes/faq.md
index 09fde025fde..bf3b3417802 100644
--- a/docs/zh_cn/notes/faq.md
+++ b/docs/zh_cn/notes/faq.md
@@ -1,8 +1,122 @@
-# 常见问题解答(FAQ)(待更新)
+# 常见问题解答(FAQ)
-我们在这里列出了使用时的一些常见问题及其相应的解决方案。 如果您发现有一些问题被遗漏,请随时提 PR 丰富这个列表。 如果您无法在此获得帮助,请使用 [issue模板](https://github.com/open-mmlab/mmsegmentation/blob/master/.github/ISSUE_TEMPLATE/error-report.md/)创建问题,但是请在模板中填写所有必填信息,这有助于我们更快定位问题。
+我们在这里列出了使用时的一些常见问题及其相应的解决方案。 如果您发现有一些问题被遗漏,请随时提 PR 丰富这个列表。 如果您无法在此获得帮助,请使用 [issue 模板](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/.github/ISSUE_TEMPLATE/error-report.md/)创建问题,但是请在模板中填写所有必填信息,这有助于我们更快定位问题。
+
+## 安装
+
+兼容的 MMSegmentation 和 MMCV 版本如下。请安装正确版本的 MMCV 以避免安装问题。
+
+| MMSegmentation version | MMCV version | MMEngine version | MMClassification (optional) version | MMDetection (optional) version |
+| :--------------------: | :----------------------------: | :---------------: | :---------------------------------: | :----------------------------: |
+| dev-1.x branch | mmcv >= 2.0.0 | MMEngine >= 0.7.4 | mmpretrain>=1.0.0rc7 | mmdet >= 3.0.0 |
+| main branch | mmcv >= 2.0.0 | MMEngine >= 0.7.4 | mmpretrain>=1.0.0rc7 | mmdet >= 3.0.0 |
+| 1.1.2 | mmcv >= 2.0.0 | MMEngine >= 0.7.4 | mmpretrain>=1.0.0rc7 | mmdet >= 3.0.0 |
+| 1.1.1 | mmcv >= 2.0.0 | MMEngine >= 0.7.4 | mmpretrain>=1.0.0rc7 | mmdet >= 3.0.0 |
+| 1.1.0 | mmcv >= 2.0.0 | MMEngine >= 0.7.4 | mmpretrain>=1.0.0rc7 | mmdet >= 3.0.0 |
+| 1.0.0 | mmcv >= 2.0.0rc4 | MMEngine >= 0.7.1 | mmcls==1.0.0rc6 | mmdet >= 3.0.0 |
+| 1.0.0rc6 | mmcv >= 2.0.0rc4 | MMEngine >= 0.5.0 | mmcls>=1.0.0rc0 | mmdet >= 3.0.0rc6 |
+| 1.0.0rc5 | mmcv >= 2.0.0rc4 | MMEngine >= 0.2.0 | mmcls>=1.0.0rc0 | mmdet>=3.0.0rc6 |
+| 1.0.0rc4 | mmcv == 2.0.0rc3 | MMEngine >= 0.1.0 | mmcls>=1.0.0rc0 | mmdet>=3.0.0rc4, \<=3.0.0rc5 |
+| 1.0.0rc3 | mmcv == 2.0.0rc3 | MMEngine >= 0.1.0 | mmcls>=1.0.0rc0 | mmdet>=3.0.0rc4, \<=3.0.0rc5 |
+| 1.0.0rc2 | mmcv == 2.0.0rc3 | MMEngine >= 0.1.0 | mmcls>=1.0.0rc0 | mmdet>=3.0.0rc4, \<=3.0.0rc5 |
+| 1.0.0rc1 | mmcv >= 2.0.0rc1, \<=2.0.0rc3 | MMEngine >= 0.1.0 | mmcls>=1.0.0rc0 | Not required |
+| 1.0.0rc0 | mmcv >= 2.0.0rc1, \<=2.0.0rc3 | MMEngine >= 0.1.0 | mmcls>=1.0.0rc0 | Not required |
+
+如果您已经安装了版本不合适的 mmcv,请先运行`pip uninstall mmcv`卸载已安装的 mmcv,如您先前安装的为 mmcv-full(存在于 OpenMMLab 1.x),请运行`pip uninstall mmcv-full`进行卸载。
+
+- 如出现 "No module named 'mmcv'"
+ 1. 使用`pip uninstall mmcv`卸载环境中现有的 mmcv
+ 2. 按照[安装说明](../get_started.md)安装对应的 mmcv
## 如何获知模型训练时需要的显卡数量
-- 看模型的config文件的命名。可以参考[学习配置文件](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/tutorials/config.md)中的`配置文件命名风格`部分。比如,对于名字为`segformer_mit-b0_8x1_1024x1024_160k_cityscapes.py`的config文件,`8x1`代表训练其对应的模型需要的卡数为8,每张卡中的batch size为1。
-- 看模型的log文件。点开该模型的log文件,并在其中搜索`nGPU`,在`nGPU`后的数字个数即训练时所需的卡数。比如,在log文件中搜索`nGPU`得到`nGPU 0,1,2,3,4,5,6,7`的记录,则说明训练该模型需要使用八张卡。
+- 看模型的 config 文件命名。可以参考[了解配置文件](../user_guides/1_config.md)中的`配置文件命名风格`部分。比如,对于名字为`segformer_mit-b0_8xb1-160k_cityscapes-1024x1024.py`的 config 文件,`8xb1`代表训练其对应的模型需要的卡数为 8,每张卡中的 batch size 为 1。
+- 看模型的 log 文件。点开该模型的 log 文件,并在其中搜索`nGPU`,在`nGPU`后的数字个数即训练时所需的卡数。比如,在 log 文件中搜索`nGPU`得到`nGPU 0,1,2,3,4,5,6,7`的记录,则说明训练该模型需要使用八张卡。
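+
+以上命名规则也可以用一个小脚本来解析(示意代码):
+
+```python
+import re
+
+def parse_gpu_batch(config_name):
+    # 形如 '8xb1' 的字段: 8 张卡, 每张卡 batch size 为 1
+    match = re.search(r'(\d+)xb(\d+)', config_name)
+    return int(match.group(1)), int(match.group(2))
+
+gpus, batch_size = parse_gpu_batch(
+    'segformer_mit-b0_8xb1-160k_cityscapes-1024x1024.py')
+print(gpus, batch_size)  # 8 1
+```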
+
+## auxiliary head 是什么
+
+简单来说,这是一个提高准确率的深度监督技术。在训练阶段,`decode_head` 用于输出语义分割的结果,`auxiliary_head` 只是增加了一个辅助损失,其产生的分割结果对模型最终的预测结果没有影响,仅在训练中起作用。您可以阅读这篇[论文](https://arxiv.org/pdf/1612.01105.pdf)了解更多信息。
+
+## 运行测试脚本时如何输出绘制分割掩膜的图像
+
+在测试脚本中,我们提供了`--out`参数来控制是否输出保存预测的分割掩膜图像。您可以运行以下命令输出测试结果:
+
+```shell
+python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --out ${OUTPUT_DIR}
+```
+
+更多用例细节可查阅[文档](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/user_guides/4_train_test.md#%E6%B5%8B%E8%AF%95%E5%B9%B6%E4%BF%9D%E5%AD%98%E5%88%86%E5%89%B2%E7%BB%93%E6%9E%9C),[PR #2712](https://github.com/open-mmlab/mmsegmentation/pull/2712) 以及[迁移文档](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/migration/interface.md#%E6%B5%8B%E8%AF%95%E5%90%AF%E5%8A%A8)了解相关说明。
+
+## 如何处理二值分割任务?
+
+MMSegmentation 使用 `num_classes` 和 `out_channels` 来控制模型最后一层 `self.conv_seg` 的输出。更多细节可以参考 [这里](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/decode_heads/decode_head.py)。
+
+`num_classes` 应该和数据集本身的类别个数一致。当进行二值分割时,数据集只有前景和背景两类,所以 `num_classes` 为 2。`out_channels` 控制模型最后一层输出的通道数,通常和 `num_classes` 相等,但在二值分割时,有以下两种处理方法,分别是:
+
+- 设置 `out_channels=2`,在训练时以 Cross Entropy Loss 作为损失函数,在推理时使用 `F.softmax()` 归一化 logits 值,然后通过 `argmax()` 得到每个像素的预测结果。
+
+- 设置 `out_channels=1`,在训练时以 Binary Cross Entropy Loss 作为损失函数,在推理时使用 `F.sigmoid()` 和 `threshold` 得到预测结果,`threshold` 默认为 0.3。
+
+对于实现上述两种计算二值分割的方法,需要在 `decode_head` 和 `auxiliary_head` 的配置里修改。下面是对样例 [pspnet_unet_s5-d16.py](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/_base_/models/pspnet_unet_s5-d16.py) 做出的对应修改。
+
+- (1) `num_classes=2`, `out_channels=2` 并在 `CrossEntropyLoss` 里面设置 `use_sigmoid=False`。
+
+```python
+decode_head=dict(
+ type='PSPHead',
+ in_channels=64,
+ in_index=4,
+ num_classes=2,
+ out_channels=2,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=128,
+ in_index=3,
+ num_classes=2,
+ out_channels=2,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+```
+
+- (2) `num_classes=2`, `out_channels=1` 并在 `CrossEntropyLoss` 里面设置 `use_sigmoid=True`.
+
+```python
+decode_head=dict(
+ type='PSPHead',
+ in_channels=64,
+ in_index=4,
+ num_classes=2,
+ out_channels=1,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
+auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=128,
+ in_index=3,
+ num_classes=2,
+ out_channels=1,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
+```
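+
+上述两种后处理方式在推理时的区别,可以用 numpy 大致示意如下(示意代码,非 mmseg 源码,函数名为假设):
+
+```python
+import numpy as np
+
+def pred_from_two_channels(seg_logits):
+    """out_channels=2: 先 softmax 归一化, 再 argmax 得到每个像素的类别."""
+    probs = np.exp(seg_logits) / np.exp(seg_logits).sum(axis=0, keepdims=True)
+    return probs.argmax(axis=0)
+
+def pred_from_one_channel(seg_logits, threshold=0.3):
+    """out_channels=1: 先 sigmoid 归一化, 再按 threshold(默认 0.3)二值化."""
+    probs = 1.0 / (1.0 + np.exp(-seg_logits))
+    return (probs[0] > threshold).astype(np.int64)
+
+logits_2ch = np.array([[[2.0, -1.0]], [[0.5, 1.0]]])  # 形状 (2, 1, 2)
+logits_1ch = np.array([[[0.2, -3.0]]])                # 形状 (1, 1, 2)
+print(pred_from_two_channels(logits_2ch))  # [[0 1]]
+print(pred_from_one_channel(logits_1ch))   # [[1 0]]
+```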
+
+## `reduce_zero_label` 的作用
+
+数据集中 `reduce_zero_label` 参数类型为布尔类型,默认为 False,它的功能是为了忽略数据集 label 0。具体做法是将 label 0 改为 255,其余 label 相应编号减 1,同时 decode head 里将 255 设为 ignore index,即不参与 loss 计算。
+以下是 `reduce_zero_label` 具体实现逻辑:
+
+```python
+if self.reduce_zero_label:
+ # avoid using underflow conversion
+ gt_semantic_seg[gt_semantic_seg == 0] = 255
+ gt_semantic_seg = gt_semantic_seg - 1
+ gt_semantic_seg[gt_semantic_seg == 254] = 255
+```
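下面用一个小例子直观展示上述映射(假设原始 label 取值为 0~5):

```python
import numpy as np

# 假设原始标注中 0 为待忽略类别,其余为 1~5
gt = np.array([[0, 1], [2, 5]], dtype=np.uint8)

# 与上面的实现逻辑一致:0 -> 255,其余编号减 1
gt[gt == 0] = 255
gt = gt - 1
gt[gt == 254] = 255
print(gt.tolist())  # [[255, 0], [1, 4]],其中 255 不参与 loss 计算
```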
+
+关于您的数据集是否需要使用 reduce_zero_label,有以下两类情况:
+
+- 例如在 [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/2_dataset_prepare.md#isprs-potsdam) 数据集上,有 0-不透水面、1-建筑、2-低矮植被、3-树、4-汽车、5-杂乱,共六类。但该数据集提供了两种 RGB 标签:一种在图像边缘处有黑色像素,另一种没有黑色边缘。对于有黑色边缘的标签,[dataset_converters.py](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/tools/dataset_converters/potsdam.py) 会将黑色边缘转换为 label 0,其余标签分别为 1-不透水面、2-建筑、3-低矮植被、4-树、5-汽车、6-杂乱,此时就应该在数据集 [potsdam.py](https://github.com/open-mmlab/mmsegmentation/blob/ff95416c3b5ce8d62b9289f743531398efce534f/mmseg/datasets/potsdam.py#L23) 中设置 `reduce_zero_label=True`。如果使用的是没有黑色边缘的标签,那么 mask label 中只有 0-5,此时就应该设置 `reduce_zero_label=False`。需要结合您的实际情况来选择。
+- 例如在第 0 类为 background 的数据集上,如果您最终需要将背景和其余类别分开,则不需要使用 `reduce_zero_label`,此时应在数据集配置中设置 `reduce_zero_label=False`。
+
+**注意:** 使用 `reduce_zero_label` 请确认数据集原始类别个数,如果只有两类,需要关闭 `reduce_zero_label` 即设置 `reduce_zero_label=False`。
diff --git a/docs/zh_cn/overview.md b/docs/zh_cn/overview.md
index 7dce105a813..ed147956d08 100644
--- a/docs/zh_cn/overview.md
+++ b/docs/zh_cn/overview.md
@@ -42,7 +42,7 @@ MMSeg 主要包含了 apis, structures, datasets, models, engine, evaluation 和
以下是详细步骤,将带您一步步学习如何使用 MMSegmentation :
-1. 有关安装说明,请参阅 [开始你的第一步](getting_started.md)。
+1. 有关安装说明,请参阅 [开始你的第一步](get_started.md)。
2. 对于初学者来说,MMSegmentation 是开始语义分割之旅的最好选择,因为这里实现了许多 SOTA 模型以及经典的模型 [model](model_zoo.md) 。另外,将各类组件和高级 API 結合使用,可以更便捷的执行分割任务。关于 MMSegmentation 的基本用法,请参考下面的教程:
@@ -62,8 +62,8 @@ MMSeg 主要包含了 apis, structures, datasets, models, engine, evaluation 和
4. MMSegmentation 也为用户自定义和一些前沿的研究提供了教程,请参考下面的教程来建立你自己的分割项目:
- [添加新的模型](advanced_guides/add_models.md)
- - [添加新的数据集](advanced_guides/add_dataset.md)
- - [添加新的 transform](advanced_guides/add_transform.md)
+ - [添加新的数据集](advanced_guides/add_datasets.md)
+ - [添加新的 transform](advanced_guides/add_transforms.md)
- [自定义 runtime](advanced_guides/customize_runtime.md)
5. 如果您更熟悉 MMSegmentation v0.x , 以下是 MMSegmentation v0.x 迁移到 v1.x 的文档
diff --git a/docs/zh_cn/user_guides/2_dataset_prepare.md b/docs/zh_cn/user_guides/2_dataset_prepare.md
index c9c3606977d..5532624bef4 100644
--- a/docs/zh_cn/user_guides/2_dataset_prepare.md
+++ b/docs/zh_cn/user_guides/2_dataset_prepare.md
@@ -1,3 +1,750 @@
-## 准备数据集(待更新)
+# 教程2:准备数据集
-中文版文档支持中,请先阅读[英文版本](../../en/user_guides/2_dataset_prepare.md)
+我们建议将数据集根目录符号链接到 `$MMSEGMENTATION/data`。
+如果您的目录结构不同,您可能需要更改配置文件中相应的路径。
+对于中国境内的用户,我们也推荐通过开源数据平台 [OpenDataLab](https://opendatalab.com/) 下载 dsdl 标准数据,以获得更好的下载和使用体验。这里有一个下载 dsdl 数据集并进行训练的案例 [DSDLReadme](../../../configs/dsdl/README.md),欢迎尝试。
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│ ├── cityscapes
+│ │ ├── leftImg8bit
+│ │ │ ├── train
+│ │ │ ├── val
+│ │ ├── gtFine
+│ │ │ ├── train
+│ │ │ ├── val
+│ ├── VOCdevkit
+│ │ ├── VOC2012
+│ │ │ ├── JPEGImages
+│ │ │ ├── SegmentationClass
+│ │ │ ├── ImageSets
+│ │ │ │ ├── Segmentation
+│ │ ├── VOC2010
+│ │ │ ├── JPEGImages
+│ │ │ ├── SegmentationClassContext
+│ │ │ ├── ImageSets
+│ │ │ │ ├── SegmentationContext
+│ │ │ │ │ ├── train.txt
+│ │ │ │ │ ├── val.txt
+│ │ │ ├── trainval_merged.json
+│ │ ├── VOCaug
+│ │ │ ├── dataset
+│ │ │ │ ├── cls
+│ ├── ade
+│ │ ├── ADEChallengeData2016
+│ │ │ ├── annotations
+│ │ │ │ ├── training
+│ │ │ │ ├── validation
+│ │ │ ├── images
+│ │ │ │ ├── training
+│ │ │ │ ├── validation
+│ ├── coco_stuff10k
+│ │ ├── images
+│ │ │ ├── train2014
+│ │ │ ├── test2014
+│ │ ├── annotations
+│ │ │ ├── train2014
+│ │ │ ├── test2014
+│ │ ├── imagesLists
+│ │ │ ├── train.txt
+│ │ │ ├── test.txt
+│ │ │ ├── all.txt
+│ ├── coco_stuff164k
+│ │ ├── images
+│ │ │ ├── train2017
+│ │ │ ├── val2017
+│ │ ├── annotations
+│ │ │ ├── train2017
+│ │ │ ├── val2017
+│ ├── CHASE_DB1
+│ │ ├── images
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ ├── annotations
+│ │ │ ├── training
+│ │ │ ├── validation
+│ ├── DRIVE
+│ │ ├── images
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ ├── annotations
+│ │ │ ├── training
+│ │ │ ├── validation
+│ ├── HRF
+│ │ ├── images
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ ├── annotations
+│ │ │ ├── training
+│ │ │ ├── validation
+│ ├── STARE
+│ │ ├── images
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ ├── annotations
+│ │ │ ├── training
+│ │ │ ├── validation
+| ├── dark_zurich
+| │ ├── gps
+| │ │ ├── val
+| │ │ └── val_ref
+| │ ├── gt
+| │ │ └── val
+| │ ├── LICENSE.txt
+| │ ├── lists_file_names
+| │ │ ├── val_filenames.txt
+| │ │ └── val_ref_filenames.txt
+| │ ├── README.md
+| │ └── rgb_anon
+| │ | ├── val
+| │ | └── val_ref
+| ├── NighttimeDrivingTest
+| | ├── gtCoarse_daytime_trainvaltest
+| | │ └── test
+| | │ └── night
+| | └── leftImg8bit
+| | | └── test
+| | | └── night
+│ ├── loveDA
+│ │ ├── img_dir
+│ │ │ ├── train
+│ │ │ ├── val
+│ │ │ ├── test
+│ │ ├── ann_dir
+│ │ │ ├── train
+│ │ │ ├── val
+│ ├── potsdam
+│ │ ├── img_dir
+│ │ │ ├── train
+│ │ │ ├── val
+│ │ ├── ann_dir
+│ │ │ ├── train
+│ │ │ ├── val
+│ ├── vaihingen
+│ │ ├── img_dir
+│ │ │ ├── train
+│ │ │ ├── val
+│ │ ├── ann_dir
+│ │ │ ├── train
+│ │ │ ├── val
+│ ├── iSAID
+│ │ ├── img_dir
+│ │ │ ├── train
+│ │ │ ├── val
+│ │ │ ├── test
+│ │ ├── ann_dir
+│ │ │ ├── train
+│ │ │ ├── val
+│ ├── synapse
+│ │ ├── img_dir
+│ │ │ ├── train
+│ │ │ ├── val
+│ │ ├── ann_dir
+│ │ │ ├── train
+│ │ │ ├── val
+│ ├── REFUGE
+│ │ ├── images
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── annotations
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ ├── mapillary
+│ │ ├── training
+│ │ │ ├── images
+│ │ │ ├── v1.2
+| │ │ │ ├── instances
+| │ │ │ ├── labels
+| │ │ │ └── panoptic
+│ │ │ ├── v2.0
+| │ │ │ ├── instances
+| │ │ │ ├── labels
+| │ │ │ ├── panoptic
+| │ │ │ └── polygons
+│ │ ├── validation
+│ │ │ ├── images
+| │ │ ├── v1.2
+| │ │ │ ├── instances
+| │ │ │ ├── labels
+| │ │ │ └── panoptic
+│ │ │ ├── v2.0
+| │ │ │ ├── instances
+| │ │ │ ├── labels
+| │ │ │ ├── panoptic
+| │ │ │ └── polygons
+│ ├── bdd100k
+│ │ ├── images
+│ │ │ └── 10k
+| │ │ │ ├── test
+| │ │ │ ├── train
+| │ │ │ └── val
+│ │ └── labels
+│ │ │ └── sem_seg
+| │ │ │ ├── colormaps
+| │ │ │ │ ├──train
+| │ │ │ │ └──val
+| │ │ │ ├── masks
+| │ │ │ │ ├──train
+| │ │ │ │ └──val
+| │ │ │ ├── polygons
+| │ │ │ │ ├──sem_seg_train.json
+| │ │ │ │ └──sem_seg_val.json
+| │ │ │ └── rles
+| │ │ │ │ ├──sem_seg_train.json
+| │ │ │ │ └──sem_seg_val.json
+│ ├── nyu
+│ │ ├── images
+│ │ │ ├── train
+│ │ │ ├── test
+│ │ ├── annotations
+│ │ │ ├── train
+│ │ │ ├── test
+```
+
+## 用 MIM 下载数据集
+
+通过使用 [OpenXLab](https://openxlab.org.cn/datasets),您可以直接下载开源数据集。通过平台的搜索功能,您可以快速轻松地找到所需的数据集。使用平台上格式化好的数据集,您可以高效地跨数据集执行任务。
+
+如果您使用 MIM 下载,请确保版本大于 v0.3.8。您可以使用以下命令进行更新、安装、登录和数据集下载:
+
+```shell
+# upgrade your MIM
+pip install -U openmim
+
+# install OpenXLab CLI tools
+pip install -U openxlab
+# log in OpenXLab
+openxlab login
+
+# download ADE20K by MIM
+mim download mmsegmentation --dataset ade20k
+```
+
+## Cityscapes
+
+在 Cityscapes [官方网站](https://www.cityscapes-dataset.com/)可以下载 Cityscapes 数据集,按照官网要求注册并登录后,数据可以在[这里](https://www.cityscapes-dataset.com/downloads/)找到。
+
+按照惯例,`**labelTrainIds.png` 用于 cityscapes 训练。
+我们提供了一个基于 [cityscapesscripts](https://github.com/mcordts/cityscapesScripts) 的[脚本](https://github.com/open-mmlab/mmsegmentation/blob/1.x/tools/dataset_converters/cityscapes.py)用于生成 `**labelTrainIds.png`。
+
+```shell
+# --nproc 表示 8 个转换进程,也可以省略。
+python tools/dataset_converters/cityscapes.py data/cityscapes --nproc 8
+```
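该脚本的核心操作相当于按查找表把标注中的 labelIds 映射为 trainIds。下面是一个 NumPy 示意(仅列出少数几个类别,完整映射表见 cityscapesscripts,与实际脚本实现无关):

```python
import numpy as np

# 仅示意部分类别的 labelId -> trainId 映射:unlabeled/road/sidewalk/building/car
# 255 表示训练时忽略
LABEL2TRAIN = {0: 255, 7: 0, 8: 1, 11: 2, 26: 13}

def label_to_train_ids(label_map):
    out = np.full_like(label_map, 255)
    for label_id, train_id in LABEL2TRAIN.items():
        out[label_map == label_id] = train_id
    return out

print(label_to_train_ids(np.array([[7, 8], [0, 26]], dtype=np.uint8)).tolist())
# [[0, 1], [255, 13]]
```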
+
+## Pascal VOC
+
+Pascal VOC 2012 可从[此处](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar)下载。
+此外,Pascal VOC 数据集的最新工作通常利用额外的增强数据,可以在[这里](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz)找到。
+
+如果您想使用增强的 VOC 数据集,请运行以下命令将增强数据的标注转换为正确的格式。
+
+```shell
+# --nproc 表示 8 个转换进程,也可以省略。
+python tools/dataset_converters/voc_aug.py data/VOCdevkit data/VOCdevkit/VOCaug --nproc 8
+```
+
+请参考[拼接数据集文档](../advanced_guides/add_datasets.md#拼接数据集)及 [voc_aug 配置示例](../../../configs/_base_/datasets/pascal_voc12_aug.py)以详细了解如何将它们拼接并合并训练。
+
+## ADE20K
+
+ADE20K 的训练和验证集可以从这个[链接](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip)下载。
+如果需要下载测试数据集,可以在[官网](http://sceneparsing.csail.mit.edu/)注册后,从[此处](http://data.csail.mit.edu/places/ADEchallenge/release_test.zip)下载测试集。
+
+## Pascal Context
+
+Pascal Context 的训练和验证集可以从[此处](http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar)下载。注册后,您也可以从[此处](http://host.robots.ox.ac.uk:8080/eval/downloads/VOC2010test.tar)下载测试集。
+
+为了从原始数据集中划分训练集和验证集,您可以从[此处](https://codalabuser.blob.core.windows.net/public/trainval_merged.json)下载 trainval_merged.json 文件。
+
+请先安装 [Detail](https://github.com/zhanghang1989/detail-api) 工具然后运行以下命令将标注转换为正确的格式。
+
+```shell
+python tools/dataset_converters/pascal_context.py data/VOCdevkit data/VOCdevkit/VOC2010/trainval_merged.json
+```
+
+## COCO Stuff 10k
+
+数据可以通过 wget 在[这里](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/cocostuff-10k-v1.1.zip)下载。
+
+对于 COCO Stuff 10k 数据集,请运行以下命令下载并转换数据集。
+
+```shell
+# 下载
+mkdir coco_stuff10k && cd coco_stuff10k
+wget http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/cocostuff-10k-v1.1.zip
+
+# 解压
+unzip cocostuff-10k-v1.1.zip
+
+# --nproc 表示 8 个转换进程,也可以省略。
+python tools/dataset_converters/coco_stuff10k.py /path/to/coco_stuff10k --nproc 8
+```
+
+按照惯例,`/path/to/coco_stuff10k/annotations/*2014/*_labelTrainIds.png` 中的 mask 标注用于 COCO Stuff 10k 的训练和测试。
+
+## COCO Stuff 164k
+
+对于 COCO Stuff 164k 数据集,请运行以下命令下载并转换增强的数据集。
+
+```shell
+# 下载
+mkdir coco_stuff164k && cd coco_stuff164k
+wget http://images.cocodataset.org/zips/train2017.zip
+wget http://images.cocodataset.org/zips/val2017.zip
+wget http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip
+
+# 解压
+unzip train2017.zip -d images/
+unzip val2017.zip -d images/
+unzip stuffthingmaps_trainval2017.zip -d annotations/
+
+# --nproc 表示 8 个转换进程,也可以省略。
+python tools/dataset_converters/coco_stuff164k.py /path/to/coco_stuff164k --nproc 8
+```
+
+按照惯例,`/path/to/coco_stuff164k/annotations/*2017/*_labelTrainIds.png` 中的 mask 标注用于 COCO Stuff 164k 的训练和测试。
+
+此数据集的详细信息可在[此处](https://github.com/nightrome/cocostuff#downloads)找到。
+
+## CHASE DB1
+
+CHASE DB1 的训练和验证集可以从[此处](https://staffnet.kingston.ac.uk/~ku15565/CHASE_DB1/assets/CHASEDB1.zip)下载。
+
+请运行以下命令,准备 CHASE DB1 数据集:
+
+```shell
+python tools/dataset_converters/chase_db1.py /path/to/CHASEDB1.zip
+```
+
+该脚本将自动调整数据集目录结构,使其满足 MMSegmentation 数据集加载要求。
+
+## DRIVE
+
+按照[官网](https://drive.grand-challenge.org/)要求,注册并登录后,便可以下载 DRIVE 的训练和验证数据集。
+
+要将 DRIVE 数据集转换为 MMSegmentation 的格式,请运行以下命令:
+
+```shell
+python tools/dataset_converters/drive.py /path/to/training.zip /path/to/test.zip
+```
+
+该脚本将自动调整数据集目录结构,使其满足 MMSegmentation 数据集加载要求。
+
+## HRF
+
+请下载 [healthy.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/healthy.zip)、[glaucoma.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/glaucoma.zip)、[diabetic_retinopathy.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/diabetic_retinopathy.zip)、[healthy_manualsegm.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/healthy_manualsegm.zip)、[glaucoma_manualsegm.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/glaucoma_manualsegm.zip) 和 [diabetic_retinopathy_manualsegm.zip](https://www5.cs.fau.de/fileadmin/research/datasets/fundus-images/diabetic_retinopathy_manualsegm.zip),无需解压,可以直接运行以下命令,准备 HRF 数据集:
+
+```shell
+python tools/dataset_converters/hrf.py /path/to/healthy.zip /path/to/healthy_manualsegm.zip /path/to/glaucoma.zip /path/to/glaucoma_manualsegm.zip /path/to/diabetic_retinopathy.zip /path/to/diabetic_retinopathy_manualsegm.zip
+```
+
+该脚本将自动调整数据集目录结构,使其满足 MMSegmentation 数据集加载要求。
+
+## STARE
+
+请下载 [stare images.tar](http://cecas.clemson.edu/~ahoover/stare/probing/stare-images.tar)、[labels-ah.tar](http://cecas.clemson.edu/~ahoover/stare/probing/labels-ah.tar) 和 [labels-vk.tar](http://cecas.clemson.edu/~ahoover/stare/probing/labels-vk.tar),无需解压,可以直接运行以下命令,准备 STARE 数据集:
+
+```shell
+python tools/dataset_converters/stare.py /path/to/stare-images.tar /path/to/labels-ah.tar /path/to/labels-vk.tar
+```
+
+该脚本将自动调整数据集目录结构,使其满足 MMSegmentation 数据集加载要求。
+
+## Dark Zurich
+
+由于我们只支持在此数据集上的模型测试,因此您只需要下载并解压[验证数据集](https://data.vision.ee.ethz.ch/csakarid/shared/GCMA_UIoU/Dark_Zurich_val_anon.zip)。
+
+## Nighttime Driving
+
+由于我们只支持在此数据集上的模型测试,因此您只需要下载并解压[验证数据集](http://data.vision.ee.ethz.ch/daid/NighttimeDriving/NighttimeDrivingTest.zip)。
+
+## LoveDA
+
+LoveDA 数据集可以从[此处](https://drive.google.com/drive/folders/1ibYV0qwn4yuuh068Rnc-w4tPi0U0c-ti?usp=sharing)下载。
+
+或者可以从 [zenodo](https://zenodo.org/record/5706578#.YZvN7SYRXdF) 下载。下载后,无需解压,直接运行以下命令:
+
+```shell
+# 下载 Train.zip
+wget https://zenodo.org/record/5706578/files/Train.zip
+# 下载 Val.zip
+wget https://zenodo.org/record/5706578/files/Val.zip
+# 下载 Test.zip
+wget https://zenodo.org/record/5706578/files/Test.zip
+```
+
+对于 LoveDA 数据集,请运行以下命令调整数据集目录。
+
+```shell
+python tools/dataset_converters/loveda.py /path/to/loveDA
+```
+
+可将模型对 LoveDA 测试集的预测结果上传至数据集[测试服务器](https://codalab.lisn.upsaclay.fr/competitions/421),查看评测结果。
+
+有关 LoveDA 的更多详细信息,可查看[此处](https://github.com/Junjue-Wang/LoveDA)。
+
+## ISPRS Potsdam
+
+[Potsdam](https://www.isprs.org/education/benchmarks/UrbanSemLab/2d-sem-label-potsdam.aspx) 城市语义分割数据集用于 2D 语义分割竞赛 —— Potsdam。
+
+数据集可以在竞赛[主页](https://www.isprs.org/education/benchmarks/UrbanSemLab/default.aspx)上请求获得。
+这里也提供了 [BaiduNetdisk](https://pan.baidu.com/s/1K-cLVZnd1X7d8c26FQ-nGg?pwd=mseg)(提取码:mseg)、[Google Drive](https://drive.google.com/drive/folders/1w3EJuyUGet6_qmLwGAWZ9vw5ogeG0zLz?usp=sharing) 以及 [OpenDataLab](https://opendatalab.com/ISPRS_Potsdam/download) 的下载链接。
+实验中需要下载 '2_Ortho_RGB.zip' 和 '5_Labels_all_noBoundary.zip'。
+
+对于 Potsdam 数据集,请运行以下命令调整数据集目录。
+
+```shell
+python tools/dataset_converters/potsdam.py /path/to/potsdam
+```
+
+在我们的默认设置中,将生成 3456 张图像用于训练和 2016 张图像用于验证。
+
+## ISPRS Vaihingen
+
+[Vaihingen](https://www.isprs.org/education/benchmarks/UrbanSemLab/2d-sem-label-vaihingen.aspx) 城市语义分割数据集用于 2D 语义分割竞赛 —— Vaihingen。
+
+数据集可以在竞赛[主页](https://www.isprs.org/education/benchmarks/UrbanSemLab/default.aspx)上请求获得。
+这里也提供了 [BaiduNetdisk](https://pan.baidu.com/s/109D3WLrLafsuYtLeerLiiA?pwd=mseg)(提取码:mseg)和 [Google Drive](https://drive.google.com/drive/folders/1w3NhvLVA2myVZqOn2pbiDXngNC7NTP_t?usp=sharing) 的下载链接。
+实验中需要下载 'ISPRS_semantic_labeling_Vaihingen.zip' 和 'ISPRS_semantic_labeling_Vaihingen_ground_truth_eroded_COMPLETE.zip'。
+
+对于 Vaihingen 数据集,请运行以下命令调整数据集目录。
+
+```shell
+python tools/dataset_converters/vaihingen.py /path/to/vaihingen
+```
+
+在我们的默认设置(`clip_size`=512, `stride_size`=256)中,将生成 344 张图像用于训练和 398 张图像用于验证。
+
+## iSAID
+
+iSAID 数据集可从 [DOTA-v1.0](https://captain-whu.github.io/DOTA/dataset.html) 下载训练/验证/测试集的图像数据,并从 [iSAID](https://captain-whu.github.io/iSAID/dataset.html) 下载训练/验证集的标注数据。
+
+该数据集是航空图像实例分割和语义分割任务的大规模数据集。
+
+下载 iSAID 数据集后,您可能需要按照以下结构进行数据集准备。
+
+```none
+├── data
+│ ├── iSAID
+│ │ ├── train
+│ │ │ ├── images
+│ │ │ │ ├── part1.zip
+│ │ │ │ ├── part2.zip
+│ │ │ │ ├── part3.zip
+│ │ │ ├── Semantic_masks
+│ │ │ │ ├── images.zip
+│ │ ├── val
+│ │ │ ├── images
+│ │ │ │ ├── part1.zip
+│ │ │ ├── Semantic_masks
+│ │ │ │ ├── images.zip
+│ │ ├── test
+│ │ │ ├── images
+│ │ │ │ ├── part1.zip
+│ │ │ │ ├── part2.zip
+```
+
+然后,运行以下命令转换 iSAID 数据集:
+
+```shell
+python tools/dataset_converters/isaid.py /path/to/iSAID
+```
+
+在我们的默认设置(`patch_width`=896, `patch_height`=896, `overlap_area`=384)中,将生成 33978 张图像用于训练和 11644 张图像用于验证。
+
+## LIP(Look Into Person) dataset
+
+该数据集可以从[此页面](https://lip.sysuhcp.com/overview.php)下载。
+
+请运行以下命令来解压数据集。
+
+```shell
+unzip LIP.zip
+cd LIP
+unzip TrainVal_images.zip
+unzip TrainVal_parsing_annotations.zip
+cd TrainVal_parsing_annotations
+unzip TrainVal_parsing_annotations.zip
+mv train_segmentations ../
+mv val_segmentations ../
+cd ..
+```
+
+LIP 数据集的内容包括:
+
+```none
+├── data
+│ ├── LIP
+│ │ ├── train_images
+│ │ │ ├── 1000_1234574.jpg
+│ │ │ ├── ...
+│ │ ├── train_segmentations
+│ │ │ ├── 1000_1234574.png
+│ │ │ ├── ...
+│ │ ├── val_images
+│ │ │ ├── 100034_483681.jpg
+│ │ │ ├── ...
+│ │ ├── val_segmentations
+│ │ │ ├── 100034_483681.png
+│ │ │ ├── ...
+```
+
+## Synapse dataset
+
+此数据集可以从[此页面](https://www.synapse.org/#!Synapse:syn3193805/wiki/)下载。
+
+遵循 [TransUNet](https://arxiv.org/abs/2102.04306) 的数据准备设定,将原始训练集(30 次扫描)拆分为新的训练集(18 次扫描)和验证集(12 次扫描)。请运行以下命令来准备数据集。
+
+```shell
+unzip RawData.zip
+cd ./RawData/Training
+```
+
+然后创建 `train.txt` 和 `val.txt` 以拆分数据集。
+
+根据 TransUnet,以下是数据集的划分。
+
+train.txt
+
+```none
+img0005.nii.gz
+img0006.nii.gz
+img0007.nii.gz
+img0009.nii.gz
+img0010.nii.gz
+img0021.nii.gz
+img0023.nii.gz
+img0024.nii.gz
+img0026.nii.gz
+img0027.nii.gz
+img0028.nii.gz
+img0030.nii.gz
+img0031.nii.gz
+img0033.nii.gz
+img0034.nii.gz
+img0037.nii.gz
+img0039.nii.gz
+img0040.nii.gz
+```
+
+val.txt
+
+```none
+img0008.nii.gz
+img0022.nii.gz
+img0038.nii.gz
+img0036.nii.gz
+img0032.nii.gz
+img0002.nii.gz
+img0029.nii.gz
+img0003.nii.gz
+img0001.nii.gz
+img0004.nii.gz
+img0025.nii.gz
+img0035.nii.gz
+```
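上述两个划分文件也可以用一小段 Python 脚本直接生成(仅为示意,假设在 `./RawData/Training` 目录下执行):

```python
# 按 TransUNet 的划分方式生成 train.txt 和 val.txt
train_ids = [5, 6, 7, 9, 10, 21, 23, 24, 26, 27, 28, 30, 31, 33, 34, 37, 39, 40]
val_ids = [8, 22, 38, 36, 32, 2, 29, 3, 1, 4, 25, 35]

with open('train.txt', 'w') as f:
    f.writelines(f'img{i:04d}.nii.gz\n' for i in train_ids)
with open('val.txt', 'w') as f:
    f.writelines(f'img{i:04d}.nii.gz\n' for i in val_ids)
```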
+
+synapse 数据集的内容包括:
+
+```none
+├── Training
+│ ├── img
+│ │ ├── img0001.nii.gz
+│ │ ├── img0002.nii.gz
+│ │ ├── ...
+│ ├── label
+│ │ ├── label0001.nii.gz
+│ │ ├── label0002.nii.gz
+│ │ ├── ...
+│ ├── train.txt
+│ ├── val.txt
+```
+
+然后,使用此命令转换 synapse 数据集。
+
+```shell
+python tools/dataset_converters/synapse.py --dataset-path /path/to/synapse
+```
+
+注意,MMSegmentation 的默认评估指标(例如 mean dice value)是在 2D 切片图像上计算的,这与 [TransUNet](https://arxiv.org/abs/2102.04306) 等一些论文中的 3D 扫描结果是不同的。
+
+## REFUGE
+
+在 [REFUGE Challenge](https://refuge.grand-challenge.org) 官网上注册并下载 [REFUGE 数据集](https://refuge.grand-challenge.org/REFUGE2Download)。
+
+然后,解压 `REFUGE2.zip`,原始数据集的内容包括:
+
+```none
+├── REFUGE2
+│ ├── REFUGE2
+│ │ ├── Annotation-Training400.zip
+│ │ ├── REFUGE-Test400.zip
+│ │ ├── REFUGE-Test-GT.zip
+│ │ ├── REFUGE-Training400.zip
+│ │ ├── REFUGE-Validation400.zip
+│ │ ├── REFUGE-Validation400-GT.zip
+│ ├── __MACOSX
+```
+
+请运行以下命令转换 REFUGE 数据集:
+
+```shell
+python tools/dataset_converters/refuge.py --raw_data_root=/path/to/refuge/REFUGE2/REFUGE2
+```
+
+脚本会将目录结构转换如下:
+
+```none
+│ ├── REFUGE
+│ │ ├── images
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+│ │ ├── annotations
+│ │ │ ├── training
+│ │ │ ├── validation
+│ │ │ ├── test
+```
+
+包含 400 张用于训练的图像、400 张用于验证的图像和 400 张用于测试的图像,这与 REFUGE 2018 数据集相同。
+
+## Mapillary Vistas Datasets
+
+- Mapillary Vistas [官方网站](https://www.mapillary.com/dataset/vistas) 可以下载 Mapillary Vistas 数据集,按照官网要求注册并登录后,数据可以在[这里](https://www.mapillary.com/dataset/vistas)找到。
+
+- Mapillary Vistas 数据集使用带调色板的 8-bit 格式存储标签,无需进行转换操作。
+
+- 假设您已将数据集 zip 文件放在 `mmsegmentation/data/mapillary` 中
+
+- 请运行以下命令来解压数据集。
+
+ ```bash
+ cd data/mapillary
+ unzip An-ZjB1Zm61yAZG0ozTymz8I8NqI4x0MrYrh26dq7kPgfu8vf9ImrdaOAVOFYbJ2pNAgUnVGBmbue9lTgdBOb5BbKXIpFs0fpYWqACbrQDChAA2fdX0zS9PcHu7fY8c-FOvyBVxPNYNFQuM.zip
+ ```
+
+- 解压后,您将获得类似于此结构的 Mapillary Vistas 数据集。语义分割 mask 标签在 `labels` 文件夹中。
+
+ ```none
+ mmsegmentation
+ ├── mmseg
+ ├── tools
+ ├── configs
+ ├── data
+ │ ├── mapillary
+ │ │ ├── training
+ │ │ │ ├── images
+ │ │ │ ├── v1.2
+ | │ │ │ ├── instances
+ | │ │ │ ├── labels
+ | │ │ │ └── panoptic
+ │ │ │ ├── v2.0
+ | │ │ │ ├── instances
+ | │ │ │ ├── labels
+ | │ │ │ ├── panoptic
+ | │ │ │ └── polygons
+ │ │ ├── validation
+ │ │ │ ├── images
+ | │ │ ├── v1.2
+ | │ │ │ ├── instances
+ | │ │ │ ├── labels
+ | │ │ │ └── panoptic
+ │ │ │ ├── v2.0
+ | │ │ │ ├── instances
+ | │ │ │ ├── labels
+ | │ │ │ ├── panoptic
+ | │ │ │ └── polygons
+ ```
+
+- 您可以在配置中使用 `MapillaryDataset_v1` 和 `MapillaryDataset_v2` 设置数据集版本。
+ Mapillary Vistas 数据集的配置文件可在 [V1.2](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/_base_/datasets/mapillary_v1.py) 和 [V2.0](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/_base_/datasets/mapillary_v2.py) 处查看。
+
+## LEVIR-CD
+
+[LEVIR-CD](https://justchenhao.github.io/LEVIR/) 是一个大规模遥感建筑变化检测数据集。
+
+数据集可以在[主页](https://justchenhao.github.io/LEVIR/)上请求获得。
+
+数据集的补充版本可以在[主页](https://github.com/S2Looking/Dataset)上请求获得。
+
+请下载数据集的补充版本,然后解压 `LEVIR-CD+.zip`,数据集的内容包括:
+
+```none
+│ ├── LEVIR-CD+
+│ │ ├── train
+│ │ │ ├── A
+│ │ │ ├── B
+│ │ │ ├── label
+│ │ ├── test
+│ │ │ ├── A
+│ │ │ ├── B
+│ │ │ ├── label
+```
+
+对于 LEVIR-CD 数据集,请运行以下命令无重叠裁剪影像:
+
+```shell
+python tools/dataset_converters/levircd.py --dataset-path /path/to/LEVIR-CD+ --out_dir /path/to/LEVIR-CD
+```
+
+裁剪后的影像大小为 256x256,与原论文保持一致。
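无重叠裁剪的原理可以用下面的 NumPy 小例子示意(仅为说明思路,函数名为笔者自拟,并非 levircd.py 的实际实现):

```python
import numpy as np

def crop_nonoverlap(img, size=256):
    """将影像按 size x size 无重叠裁剪,不足一块的边缘部分被丢弃(仅为示意)。"""
    h, w = img.shape[:2]
    return [img[r:r + size, c:c + size]
            for r in range(0, h - size + 1, size)
            for c in range(0, w - size + 1, size)]

# LEVIR-CD 原始影像为 1024x1024,可切出 4x4=16 块 256x256 图像
patches = crop_nonoverlap(np.zeros((1024, 1024, 3), dtype=np.uint8))
print(len(patches))  # 16
```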
+
+## BDD100K
+
+- 可以从[官方网站](https://bdd-data.berkeley.edu/)下载 BDD100K 数据集(语义分割任务主要使用其中的 10K 数据集),按照官网要求注册并登录后,数据可以在[这里](https://bdd-data.berkeley.edu/portal.html#download)找到。
+
+- 图像数据对应的名称是 `10K Images`,语义分割标注对应的名称是 `Segmentation`。
+
+- 下载后,可以使用以下代码进行解压
+
+ ```bash
+ unzip ~/bdd100k_images_10k.zip -d ~/mmsegmentation/data/
+ unzip ~/bdd100k_sem_seg_labels_trainval.zip -d ~/mmsegmentation/data/
+ ```
+
+就可以得到以下文件结构了:
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│ ├── bdd100k
+│ │ ├── images
+│ │ │ └── 10k
+| │ │ │ ├── test
+| │ │ │ ├── train
+| │ │ │ └── val
+│ │ └── labels
+│ │ │ └── sem_seg
+| │ │ │ ├── colormaps
+| │ │ │ │ ├──train
+| │ │ │ │ └──val
+| │ │ │ ├── masks
+| │ │ │ │ ├──train
+| │ │ │ │ └──val
+| │ │ │ ├── polygons
+| │ │ │ │ ├──sem_seg_train.json
+| │ │ │ │ └──sem_seg_val.json
+| │ │ │ └── rles
+| │ │ │ │ ├──sem_seg_train.json
+| │ │ │ │ └──sem_seg_val.json
+```
+
+## NYU
+
+- 您可以从 [这个链接](https://drive.google.com/file/d/1wC-io-14RCIL4XTUrQLk6lBqU2AexLVp/view?usp=share_link) 下载 NYU 数据集
+
+- 下载完成后,您可以使用 [tools/dataset_converters/nyu.py](/tools/dataset_converters/nyu.py) 脚本来解压和组织数据到所需的格式
+
+ ```bash
+ python tools/dataset_converters/nyu.py nyu.zip
+ ```
diff --git a/docs/zh_cn/user_guides/3_inference.md b/docs/zh_cn/user_guides/3_inference.md
index d2fe60076f8..0afcb4b05d6 100644
--- a/docs/zh_cn/user_guides/3_inference.md
+++ b/docs/zh_cn/user_guides/3_inference.md
@@ -1,3 +1,244 @@
-## 使用预训练模型推理(待更新)
+# 教程3:使用预训练模型推理
-中文版文档支持中,请先阅读[英文版本](../../en/user_guides/3_inference.md)
+MMSegmentation 在 [Model Zoo](../model_zoo.md) 中为语义分割提供了预训练模型,并支持多个标准数据集,包括 Cityscapes、ADE20K 等。
+本说明将展示如何使用现有模型对给定图像进行推理。
+关于如何在标准数据集上测试现有模型,请参阅本[指南](./4_train_test.md)。
+
+MMSegmentation 为用户提供了数个接口,以便轻松使用预训练的模型进行推理。
+
+- [教程3:使用预训练模型推理](#教程3使用预训练模型推理)
+ - [推理器](#推理器)
+ - [基本使用](#基本使用)
+ - [初始化](#初始化)
+ - [可视化预测结果](#可视化预测结果)
+ - [模型列表](#模型列表)
+ - [推理 API](#推理-api)
+ - [mmseg.apis.init_model](#mmsegapisinit_model)
+ - [mmseg.apis.inference_model](#mmsegapisinference_model)
+ - [mmseg.apis.show_result_pyplot](#mmsegapisshow_result_pyplot)
+
+## 推理器
+
+在 MMSegmentation 中,我们提供了最**方便的**方式 `MMSegInferencer` 来使用模型。您只需 3 行代码就可以获得图像的分割掩膜。
+
+### 基本使用
+
+以下示例展示了如何使用 `MMSegInferencer` 对单个图像执行推理。
+
+```python
+>>> from mmseg.apis import MMSegInferencer
+>>> # 将模型加载到内存中
+>>> inferencer = MMSegInferencer(model='deeplabv3plus_r18-d8_4xb2-80k_cityscapes-512x1024')
+>>> # 推理
+>>> inferencer('demo/demo.png', show=True)
+```
+
+可视化结果应如下所示:
+
+
+
+
+
+此外,您可以使用 `MMSegInferencer` 来处理一个包含多张图片的 `list`:
+
+```python
+# 输入一个图片 list
+>>> images = [image1, image2, ...] # image1 可以是文件路径或 np.ndarray
+>>> inferencer(images, show=True, wait_time=0.5) # wait_time 是延迟时间,0 表示无限
+
+# 或输入图像目录
+>>> images = $IMAGESDIR
+>>> inferencer(images, show=True, wait_time=0.5)
+
+# 保存可视化渲染彩色分割图和预测结果
+# out_dir 是保存输出结果的目录,img_out_dir 和 pred_out_dir 为 out_dir 的子目录
+# 以保存可视化渲染彩色分割图和预测结果
+>>> inferencer(images, out_dir='outputs', img_out_dir='vis', pred_out_dir='pred')
+```
+
+推理器有一个可选参数 `return_datasamples`,其默认值为 False,推理器的返回值默认为 `dict` 类型,包括 'visualization' 和 'predictions' 两个 key。
+如果 `return_datasamples=True` 推理器将返回 [`SegDataSample`](../advanced_guides/structures.md) 或其列表。
+
+```python
+result = inferencer('demo/demo.png')
+# 结果是一个包含 'visualization' 和 'predictions' 两个 key 的 `dict`
+# 'visualization' 包含彩色分割图
+print(result['visualization'].shape)
+# (512, 683, 3)
+
+# 'predictions' 包含带有标签索引的分割掩膜
+print(result['predictions'].shape)
+# (512, 683)
+
+result = inferencer('demo/demo.png', return_datasamples=True)
+print(type(result))
+# <class 'mmseg.structures.seg_data_sample.SegDataSample'>
+
+# 输入一个图片 list
+results = inferencer(images)
+# 输出为列表
+print(type(results['visualization']), results['visualization'][0].shape)
+# <class 'list'> (512, 683, 3)
+print(type(results['predictions']), results['predictions'][0].shape)
+# <class 'list'> (512, 683)
+
+results = inferencer(images, return_datasamples=True)
+print(type(results))
+# <class 'list'>
+print(type(results[0]))
+# <class 'mmseg.structures.seg_data_sample.SegDataSample'>
+```
+
+### 初始化
+
+`MMSegInferencer` 必须使用 `model` 初始化,该 `model` 可以是模型名称或一个 `Config`,甚至可以是配置文件的路径。
+模型名称可以在模型的元文件(configs/xxx/metafile.yaml)中找到,比如 maskformer 的一个模型名称是 `maskformer_r50-d32_8xb2-160k_ade20k-512x512`,如果输入模型名称,模型的权重将自动下载。以下是其他输入参数:
+
+- weights(str,可选)- 权重的路径。如果未指定,并且模型是元文件中的模型名称,则权重将从元文件加载。默认为 None。
+- classes(list,可选)- 输入类别用于结果渲染,由于分割模型的预测结构是标签索引的分割图,`classes` 是一个相应的标签索引的类别列表。若 classes 没有定义,可视化工具将默认使用 `cityscapes` 的类别。默认为 None。
+- palette(list,可选)- 输入调色盘用于结果渲染,它是对应分类的配色列表。若 palette 没有定义,可视化工具将默认使用 `cityscapes` 的调色盘。默认为 None。
+- dataset_name(str,可选)- [数据集名称或别名](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/utils/class_names.py#L302-L317),可视化工具将使用数据集的元信息,如类别和配色,但 `classes` 和 `palette` 具有更高的优先级。默认为 None。
+- device(str,可选)- 运行推理的设备。如果为 None,则会自动使用可用的设备。默认为 None。
+- scope(str,可选)- 模型的作用域。默认为 'mmseg'。
+
+### 可视化预测结果
+
+`MMSegInferencer` 有 4 个用于可视化预测的参数,您可以在初始化推理器时使用它们:
+
+- show(bool)- 是否弹出窗口显示图像。默认为 False。
+- wait_time(float)- 显示的间隔。默认值为 0。
+- img_out_dir(str)- `out_dir` 的子目录,用于保存渲染有色分割掩膜,因此如果要保存预测掩膜,则必须定义 `out_dir`。默认为 `vis`。
+- opacity(int,float)- 分割掩膜的透明度。默认值为 0.8。
+
+这些参数的示例请参考[基本使用](#基本使用)
+
+### 模型列表
+
+MMSegmentation 中提供了一个非常方便的方法,可以列出所有模型名称:
+
+```python
+>>> from mmseg.apis import MMSegInferencer
+# models 是一个模型名称列表,它们将自动打印
+>>> models = MMSegInferencer.list_models('mmseg')
+```
+
+## 推理 API
+
+### mmseg.apis.init_model
+
+从配置文件初始化一个分割器。
+
+参数:
+
+- config(str,`Path` 或 `mmengine.Config`)- 配置文件路径或配置对象。
+- checkpoint(str,可选)- 权重路径。如果为 None,则模型将不会加载任何权重。
+- device(str,可选)- CPU/CUDA 设备选项。默认为 'cuda:0'。
+- cfg_options(dict,可选)- 用于覆盖所用配置中的某些设置的选项。
+
+返回值:
+
+- nn.Module:构建好的分割器。
+
+示例:
+
+```python
+from mmseg.apis import init_model
+
+config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
+checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
+
+# 初始化不带权重的模型
+model = init_model(config_path)
+
+# 初始化模型并加载权重
+model = init_model(config_path, checkpoint_path)
+
+# 在 CPU 上的初始化模型并加载权重
+model = init_model(config_path, checkpoint_path, 'cpu')
+```
+
+### mmseg.apis.inference_model
+
+使用分割器推理图像。
+
+参数:
+
+- model(nn.Module)- 加载的分割器
+- imgs(str,np.ndarray 或 list\[str/np.ndarray\])- 图像文件或加载的图像
+
+返回值:
+
+- `SegDataSample` 或 list\[`SegDataSample`\]:如果 imgs 是列表或元组,则返回相同长度的列表类型结果,否则直接返回分割结果。
+
+**注意:** [SegDataSample](https://github.com/open-mmlab/mmsegmentation/blob/1.x/mmseg/structures/seg_data_sample.py) 是 MMSegmentation 的数据结构接口,用作不同组件之间的接口。`SegDataSample` 实现抽象数据元素 `mmengine.structures.BaseDataElement`,请参阅 [MMEngine](https://github.com/open-mmlab/mmengine) 中的数据元素[文档](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/data_element.html)了解更多信息。
+
+`SegDataSample` 中的参数分为几个部分:
+
+- `gt_sem_seg`(`PixelData`)- 语义分割的标注。
+- `pred_sem_seg`(`PixelData`)- 语义分割的预测。
+- `seg_logits`(`PixelData`)- 模型最后一层的输出结果。
+
+**注意:** [PixelData](https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/pixel_data.py) 是像素级标注或预测的数据结构,请参阅 [MMEngine](https://github.com/open-mmlab/mmengine) 中的 PixelData [文档](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/data_element.html)了解更多信息。
+
+示例:
+
+```python
+from mmseg.apis import init_model, inference_model
+
+config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
+checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
+img_path = 'demo/demo.png'
+
+
+model = init_model(config_path, checkpoint_path)
+result = inference_model(model, img_path)
+```
+
+### mmseg.apis.show_result_pyplot
+
+在图像上可视化分割结果。
+
+参数:
+
+- model(nn.Module)- 加载的分割器。
+- img(str 或 np.ndarray)- 图像文件名或加载的图像。
+- result(`SegDataSample`)- SegDataSample 预测结果。
+- opacity(float)- 绘制分割图的不透明度。默认值为 `0.5`,必须在 `(0,1]` 范围内。
+- title(str)- pyplot 图的标题。默认值为 ''。
+- draw_gt(bool)- 是否绘制 GT SegDataSample。默认为 `True`。
+- draw_pred(bool)- 是否绘制预测 SegDataSample。默认为 `True`。
+- wait_time(float)- 显示的间隔,0 是表示“无限”的特殊值。默认为 `0`。
+- show(bool)- 是否展示绘制的图像。默认为 `True`。
+- save_dir(str,可选)- 为所有存储后端保存的文件路径。如果为 `None`,则后端存储将不会保存任何数据。
+- out_file(str,可选)- 输出文件的路径。默认为 `None`。
+
+返回值:
+
+- np.ndarray:通道为 RGB 的绘制图像。
+
+示例:
+
+```python
+from mmseg.apis import init_model, inference_model, show_result_pyplot
+
+config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
+checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
+img_path = 'demo/demo.png'
+
+
+# 从配置文件和权重文件构建模型
+model = init_model(config_path, checkpoint_path, device='cuda:0')
+
+# 推理给定图像
+result = inference_model(model, img_path)
+
+# 展示分割结果
+vis_image = show_result_pyplot(model, img_path, result)
+
+# 保存可视化结果,输出图像将在 `work_dirs/result.png` 路径下找到
+vis_image = show_result_pyplot(model, img_path, result, out_file='work_dirs/result.png')
+
+# 修改展示图像的时间,注意 0 是表示“无限”的特殊值
+vis_image = show_result_pyplot(model, img_path, result, wait_time=5)
+```
+
+**注意:** 如果当前设备没有图形用户界面,建议将 `show` 设置为 `False`,并指定 `out_file` 或 `save_dir` 来保存结果。如果您想在窗口上显示结果,则不需要特殊设置。
diff --git a/docs/zh_cn/user_guides/5_deployment.md b/docs/zh_cn/user_guides/5_deployment.md
new file mode 100644
index 00000000000..b2bec028833
--- /dev/null
+++ b/docs/zh_cn/user_guides/5_deployment.md
@@ -0,0 +1,243 @@
+# 教程5:模型部署
+
+# MMSegmentation 模型部署
+
+- [教程5:模型部署](#教程5模型部署)
+- [MMSegmentation 模型部署](#mmsegmentation-模型部署)
+ - [安装](#安装)
+ - [安装 mmseg](#安装-mmseg)
+ - [安装 mmdeploy](#安装-mmdeploy)
+ - [模型转换](#模型转换)
+ - [模型规范](#模型规范)
+ - [模型推理](#模型推理)
+ - [后端模型推理](#后端模型推理)
+ - [SDK 模型推理](#sdk-模型推理)
+ - [模型支持列表](#模型支持列表)
+ - [注意事项](#注意事项)
+
+______________________________________________________________________
+
+[MMSegmentation](https://github.com/open-mmlab/mmsegmentation/tree/main) 又称 `mmseg`,是一个基于 PyTorch 的开源语义分割工具箱。它是 [OpenMMLab](https://openmmlab.com/) 项目的一部分。
+
+## 安装
+
+### 安装 mmseg
+
+请参考[官网安装指南](https://mmsegmentation.readthedocs.io/en/latest/get_started.html)。
+
+### 安装 mmdeploy
+
+mmdeploy 有以下几种安装方式:
+
+**方式一:** 安装预编译包
+
+请参考[安装概述](https://mmdeploy.readthedocs.io/zh_CN/latest/get_started.html#mmdeploy)
+
+**方式二:** 一键式脚本安装
+
+如果部署平台是 **Ubuntu 18.04 及以上版本**, 请参考[脚本安装说明](../01-how-to-build/build_from_script.md),完成安装过程。
+比如,以下命令可以安装 mmdeploy 以及配套的推理引擎——`ONNX Runtime`。
+
+```shell
+git clone --recursive -b main https://github.com/open-mmlab/mmdeploy.git
+cd mmdeploy
+python3 tools/scripts/build_ubuntu_x64_ort.py $(nproc)
+export PYTHONPATH=$(pwd)/build/lib:$PYTHONPATH
+export LD_LIBRARY_PATH=$(pwd)/../mmdeploy-dep/onnxruntime-linux-x64-1.8.1/lib/:$LD_LIBRARY_PATH
+```
+
+**说明**:
+
+- 把 `$(pwd)/build/lib` 添加到 `PYTHONPATH`,目的是为了加载 mmdeploy SDK python 包 `mmdeploy_runtime`,在章节 [SDK 模型推理](#sdk-模型推理)中讲述其用法。
+- 在[使用 ONNX Runtime 推理后端模型](#后端模型推理)时,需要加载自定义算子库,因此要把 ONNX Runtime 库的路径加入环境变量 `LD_LIBRARY_PATH` 中。
+
+**方式三:** 源码安装
+
+在方式一、二都满足不了的情况下,请参考[源码安装说明](../01-how-to-build/build_from_source.md) 安装 mmdeploy 以及所需推理引擎。
+
+## 模型转换
+
+你可以使用 [tools/deploy.py](https://github.com/open-mmlab/mmdeploy/tree/main/tools/deploy.py) 把 mmseg 模型一键式转换为推理后端模型。
+该工具的详细使用说明请参考[这里](https://github.com/open-mmlab/mmdeploy/tree/main/docs/en/02-how-to-run/convert_model.md#usage)。
+
+以下,我们将演示如何把 `unet` 转换为 onnx 模型。
+
+```shell
+cd mmdeploy
+
+# download unet model from mmseg model zoo
+mim download mmsegmentation --config unet-s5-d16_fcn_4xb4-160k_cityscapes-512x1024 --dest .
+
+# convert mmseg model to onnxruntime model with dynamic shape
+python tools/deploy.py \
+ configs/mmseg/segmentation_onnxruntime_dynamic.py \
+ unet-s5-d16_fcn_4xb4-160k_cityscapes-512x1024.py \
+ fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes_20211210_145204-6860854e.pth \
+ demo/resources/cityscapes.png \
+ --work-dir mmdeploy_models/mmseg/ort \
+ --device cpu \
+ --show \
+ --dump-info
+```
+
+转换的关键之一是使用正确的配置文件。项目中已内置了各后端部署[配置文件](https://github.com/open-mmlab/mmdeploy/tree/main/configs/mmseg)。
+文件的命名模式是:
+
+```
+segmentation_{backend}-{precision}_{static | dynamic}_{shape}.py
+```
+
+其中:
+
+- **{backend}:** 推理后端名称。比如,onnxruntime、tensorrt、pplnn、ncnn、openvino、coreml 等等
+- **{precision}:** 推理精度。比如,fp16、int8。不填表示 fp32
+- **{static | dynamic}:** 动态、静态 shape
+- **{shape}:** 模型输入的 shape 或者 shape 范围
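按照该命名模式,可以用一小段 Python 粗略拼出配置文件名(仅为示意,`deploy_config_name` 是本文假设的辅助函数,实际可用的配置请以仓库 configs/mmseg 目录为准):

```python
def deploy_config_name(backend, mode, shape=None, precision=None):
    """按 segmentation_{backend}-{precision}_{static|dynamic}_{shape}.py
    的命名模式拼出部署配置文件名。precision 为 None 表示 fp32,不写入文件名;
    shape 为 None 表示文件名中不带 shape(如 onnxruntime 的动态配置)。"""
    name = f'segmentation_{backend}'
    if precision:
        name += f'-{precision}'
    name += f'_{mode}'
    if shape:
        name += f'-{shape}'
    return name + '.py'


print(deploy_config_name('onnxruntime', 'dynamic'))
# segmentation_onnxruntime_dynamic.py
print(deploy_config_name('tensorrt', 'dynamic', '512x1024-2048x2048', 'fp16'))
# segmentation_tensorrt-fp16_dynamic-512x1024-2048x2048.py
```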
+
+在上例中,你也可以把 `unet` 转换为其他后端的模型。比如使用 `segmentation_tensorrt-fp16_dynamic-512x1024-2048x2048.py`,把模型转为 tensorrt-fp16 模型。
+
+```{tip}
+转换 tensorrt 模型时,`--device` 需要设置为 "cuda"
+```
+
+## 模型规范
+
+在使用转换后的模型进行推理之前,有必要了解转换结果的结构。它存放在 `--work-dir` 指定的路径下。
+
+上例中的 `mmdeploy_models/mmseg/ort` 结构如下:
+
+```
+mmdeploy_models/mmseg/ort
+├── deploy.json
+├── detail.json
+├── end2end.onnx
+└── pipeline.json
+```
+
+重要的是:
+
+- **end2end.onnx**: 推理引擎文件。可用 ONNX Runtime 推理
+- **\*.json**: mmdeploy SDK 推理所需的 meta 信息
+
+整个文件夹被定义为**mmdeploy SDK model**。换言之,**mmdeploy SDK model**既包括推理引擎,也包括推理 meta 信息。
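转换完成后,可以按上面的目录结构做一次简单自检,确认 SDK model 文件是否齐备(示意代码,`check_sdk_model` 为本文假设的辅助函数):

```python
from pathlib import Path


def check_sdk_model(work_dir):
    """检查 --work-dir 目录是否包含 mmdeploy SDK model 所需文件,
    返回缺失文件名的有序列表(文件清单依据上文的目录结构)。"""
    expected = {'deploy.json', 'detail.json', 'pipeline.json', 'end2end.onnx'}
    present = {p.name for p in Path(work_dir).iterdir()}
    return sorted(expected - present)


# 用法示例:missing = check_sdk_model('mmdeploy_models/mmseg/ort')
```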
+
+## 模型推理
+
+### 后端模型推理
+
+以上述模型转换后的 `end2end.onnx` 为例,你可以使用如下代码进行推理:
+
+```python
+from mmdeploy.apis.utils import build_task_processor
+from mmdeploy.utils import get_input_shape, load_config
+import torch
+
+deploy_cfg = 'configs/mmseg/segmentation_onnxruntime_dynamic.py'
+model_cfg = './unet-s5-d16_fcn_4xb4-160k_cityscapes-512x1024.py'
+device = 'cpu'
+backend_model = ['./mmdeploy_models/mmseg/ort/end2end.onnx']
+image = './demo/resources/cityscapes.png'
+
+# read deploy_cfg and model_cfg
+deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
+
+# build task and backend model
+task_processor = build_task_processor(model_cfg, deploy_cfg, device)
+model = task_processor.build_backend_model(backend_model)
+
+# process input image
+input_shape = get_input_shape(deploy_cfg)
+model_inputs, _ = task_processor.create_input(image, input_shape)
+
+# do model inference
+with torch.no_grad():
+ result = model.test_step(model_inputs)
+
+# visualize results
+task_processor.visualize(
+ image=image,
+ model=model,
+ result=result[0],
+ window_name='visualize',
+ output_file='./output_segmentation.png')
+```
+
+### SDK 模型推理
+
+你也可以参考如下代码,对 SDK model 进行推理:
+
+```python
+from mmdeploy_runtime import Segmentor
+import cv2
+import numpy as np
+
+img = cv2.imread('./demo/resources/cityscapes.png')
+
+# create a segmentor
+segmentor = Segmentor(model_path='./mmdeploy_models/mmseg/ort', device_name='cpu', device_id=0)
+# perform inference
+seg = segmentor(img)
+
+# visualize inference result
+## random a palette with size 256x3
+palette = np.random.randint(0, 256, size=(256, 3))
+color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+for label, color in enumerate(palette):
+ color_seg[seg == label, :] = color
+# convert to BGR
+color_seg = color_seg[..., ::-1]
+img = img * 0.5 + color_seg * 0.5
+img = img.astype(np.uint8)
+cv2.imwrite('output_segmentation.png', img)
+```
+
+除了 Python API,mmdeploy SDK 还提供了 C、C++、C#、Java 等多语言接口。
+你可以参考[样例](https://github.com/open-mmlab/mmdeploy/tree/main/demo)学习其他语言接口的使用方法。
+
+## 模型支持列表
+
+| Model | TorchScript | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVino |
+| :-------------------------------------------------------------------------------------------------------- | :---------: | :---------: | :------: | :--: | :---: | :------: |
+| [FCN](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/fcn) | Y | Y | Y | Y | Y | Y |
+| [PSPNet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/pspnet)[\*](#static_shape) | Y | Y | Y | Y | Y | Y |
+| [DeepLabV3](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/deeplabv3) | Y | Y | Y | Y | Y | Y |
+| [DeepLabV3+](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/deeplabv3plus) | Y | Y | Y | Y | Y | Y |
+| [Fast-SCNN](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/fastscnn)[\*](#static_shape) | Y | Y | Y | N | Y | Y |
+| [UNet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/unet) | Y | Y | Y | Y | Y | Y |
+| [ANN](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/ann)[\*](#static_shape) | Y | Y | Y | N | N | N |
+| [APCNet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/apcnet) | Y | Y | Y | Y | N | N |
+| [BiSeNetV1](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/bisenetv1) | Y | Y | Y | Y | N | Y |
+| [BiSeNetV2](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/bisenetv2) | Y | Y | Y | Y | N | Y |
+| [CGNet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/cgnet) | Y | Y | Y | Y | N | Y |
+| [DMNet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/dmnet) | ? | Y | N | N | N | N |
+| [DNLNet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/dnlnet) | ? | Y | Y | Y | N | Y |
+| [EMANet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/emanet) | Y | Y | Y | N | N | Y |
+| [EncNet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/encnet) | Y | Y | Y | N | N | Y |
+| [ERFNet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/erfnet) | Y | Y | Y | Y | N | Y |
+| [FastFCN](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/fastfcn) | Y | Y | Y | Y | N | Y |
+| [GCNet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/gcnet) | Y | Y | Y | N | N | N |
+| [ICNet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/icnet)[\*](#static_shape) | Y | Y | Y | N | N | Y |
+| [ISANet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/isanet)[\*](#static_shape) | N | Y | Y | N | N | Y |
+| [NonLocal Net](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/nonlocal_net) | ? | Y | Y | Y | N | Y |
+| [OCRNet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/ocrnet) | Y | Y | Y | Y | N | Y |
+| [PointRend](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/point_rend)[\*](#static_shape) | Y | Y | Y | N | N | N |
+| [Semantic FPN](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/sem_fpn) | Y | Y | Y | Y | N | Y |
+| [STDC](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/stdc) | Y | Y | Y | Y | N | Y |
+| [UPerNet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/upernet)[\*](#static_shape) | N | Y | Y | N | N | N |
+| [DANet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/danet) | ? | Y | Y | N | N | Y |
+| [Segmenter](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/segmenter)[\*](#static_shape) | N | Y | Y | Y | N | Y |
+| [SegFormer](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/segformer)[\*](#static_shape) | ? | Y | Y | N | N | Y |
+| [SETR](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/setr) | ? | Y | N | N | N | Y |
+| [CCNet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/ccnet) | ? | N | N | N | N | N |
+| [PSANet](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/psanet) | ? | N | N | N | N | N |
+| [DPT](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/dpt) | ? | N | N | N | N | N |
+
+## 注意事项
+
+- 所有 mmseg 模型仅支持 "whole" 推理模式。
+
+- PSPNet,Fast-SCNN 仅支持静态输入,因为多数推理框架的 [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/0c87f7a0c9099844eff8e90fa3db5b0d0ca02fee/mmseg/models/decode_heads/psp_head.py#L38) 不支持动态输入。
+
+- 对于仅支持静态形状的模型,应使用静态形状的部署配置文件,例如 `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`
+
+- 若希望部署后的模型输出概率特征图(而非经过 argmax 的类别索引图),在部署配置中设置 `codebase_config = dict(with_argmax=False)` 即可。
diff --git a/docs/zh_cn/user_guides/deployment.md b/docs/zh_cn/user_guides/deployment.md
deleted file mode 100644
index f98110c8b5f..00000000000
--- a/docs/zh_cn/user_guides/deployment.md
+++ /dev/null
@@ -1 +0,0 @@
-# 模型部署
diff --git a/mmseg/apis/__init__.py b/mmseg/apis/__init__.py
index d22dc3f0ada..b50a266319c 100644
--- a/mmseg/apis/__init__.py
+++ b/mmseg/apis/__init__.py
@@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model, show_result_pyplot
from .mmseg_inferencer import MMSegInferencer
+from .remote_sense_inferencer import RSImage, RSInferencer
__all__ = [
- 'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer'
+ 'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer',
+ 'RSInferencer', 'RSImage'
]
diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py
index 4aadffc7982..0dd70cd6155 100644
--- a/mmseg/apis/inference.py
+++ b/mmseg/apis/inference.py
@@ -1,14 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
-from collections import defaultdict
from pathlib import Path
-from typing import Optional, Sequence, Union
+from typing import Optional, Union
import mmcv
import numpy as np
import torch
from mmengine import Config
-from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint
from mmengine.utils import mkdir_or_exist
@@ -18,6 +16,7 @@
from mmseg.structures import SegDataSample
from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
from mmseg.visualization import SegLocalVisualizer
+from .utils import ImageType, _preprare_data
def init_model(config: Union[str, Path, Config],
@@ -90,41 +89,6 @@ def init_model(config: Union[str, Path, Config],
return model
-ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
-
-
-def _preprare_data(imgs: ImageType, model: BaseSegmentor):
-
- cfg = model.cfg
- for t in cfg.test_pipeline:
- if t.get('type') == 'LoadAnnotations':
- cfg.test_pipeline.remove(t)
-
- is_batch = True
- if not isinstance(imgs, (list, tuple)):
- imgs = [imgs]
- is_batch = False
-
- if isinstance(imgs[0], np.ndarray):
- cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray'
-
- # TODO: Consider using the singleton pattern to avoid building
- # a pipeline for each inference
- pipeline = Compose(cfg.test_pipeline)
-
- data = defaultdict(list)
- for img in imgs:
- if isinstance(img, np.ndarray):
- data_ = dict(img=img)
- else:
- data_ = dict(img_path=img)
- data_ = pipeline(data_)
- data['inputs'].append(data_['inputs'])
- data['data_samples'].append(data_['data_samples'])
-
- return data, is_batch
-
-
def inference_model(model: BaseSegmentor,
img: ImageType) -> Union[SegDataSample, SampleList]:
"""Inference image(s) with the segmentor.
@@ -158,6 +122,7 @@ def show_result_pyplot(model: BaseSegmentor,
draw_pred: bool = True,
wait_time: float = 0,
show: bool = True,
+ withLabels: Optional[bool] = True,
save_dir=None,
out_file=None):
"""Visualize the segmentation results on the image.
@@ -177,17 +142,21 @@ def show_result_pyplot(model: BaseSegmentor,
that means "forever". Defaults to 0.
show (bool): Whether to display the drawn image.
Default to True.
+        withLabels(bool, optional): Add semantic labels to the
+            visualization result. Defaults to True.
save_dir (str, optional): Save file dir for all storage backends.
If it is None, the backend storage will not save any data.
out_file (str, optional): Path to output file. Default to None.
+
+
Returns:
np.ndarray: the drawn image which channel is RGB.
"""
if hasattr(model, 'module'):
model = model.module
if isinstance(img, str):
- image = mmcv.imread(img)
+ image = mmcv.imread(img, channel_order='rgb')
else:
image = img
if save_dir is not None:
@@ -208,7 +177,8 @@ def show_result_pyplot(model: BaseSegmentor,
draw_pred=draw_pred,
wait_time=wait_time,
out_file=out_file,
- show=show)
+ show=show,
+ withLabels=withLabels)
vis_img = visualizer.get_image()
return vis_img
diff --git a/mmseg/apis/mmseg_inferencer.py b/mmseg/apis/mmseg_inferencer.py
index cb387b10b3f..095639a80fd 100644
--- a/mmseg/apis/mmseg_inferencer.py
+++ b/mmseg/apis/mmseg_inferencer.py
@@ -30,7 +30,7 @@ class MMSegInferencer(BaseInferencer):
Args:
model (str, optional): Path to the config file or the model name
- defined in metafile. Take the `mmseg metafile `_
+ defined in metafile. Take the `mmseg metafile `_
as an example the `model` could be
"fcn_r50-d8_4xb2-40k_cityscapes-512x1024", and the weights of model
will be download automatically. If use config file, like
@@ -48,7 +48,7 @@ class MMSegInferencer(BaseInferencer):
a list of color palette responding to the classes. If palette is
not defined, visualizer will take `cityscapes` palette by default.
Defaults to None.
- dataset_name (str, optional): `Dataset name or alias `_
+ dataset_name (str, optional): `Dataset name or alias `_
visulizer will use the meta information of the dataset i.e. classes
and palette, but the `classes` and `palette` have higher priority.
Defaults to None.
@@ -59,7 +59,9 @@ class MMSegInferencer(BaseInferencer):
preprocess_kwargs: set = set()
forward_kwargs: set = {'mode', 'out_dir'}
- visualize_kwargs: set = {'show', 'wait_time', 'img_out_dir', 'opacity'}
+ visualize_kwargs: set = {
+ 'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis'
+ }
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
def __init__(self,
@@ -82,7 +84,7 @@ def __init__(self,
self.model = revert_sync_batchnorm(self.model)
assert isinstance(self.visualizer, SegLocalVisualizer)
- self.visualizer.set_dataset_meta(palette, classes, dataset_name)
+ self.visualizer.set_dataset_meta(classes, palette, dataset_name)
def _load_weights_to_model(self, model: nn.Module,
checkpoint: Optional[dict],
@@ -137,6 +139,7 @@ def __call__(self,
inputs: InputsType,
return_datasamples: bool = False,
batch_size: int = 1,
+ return_vis: bool = False,
show: bool = False,
wait_time: int = 0,
out_dir: str = '',
@@ -188,11 +191,13 @@ def __call__(self,
wait_time=wait_time,
img_out_dir=img_out_dir,
pred_out_dir=pred_out_dir,
+ return_vis=return_vis,
**kwargs)
def visualize(self,
inputs: list,
preds: List[dict],
+ return_vis: bool = False,
show: bool = False,
wait_time: int = 0,
img_out_dir: str = '',
@@ -213,12 +218,12 @@ def visualize(self,
Returns:
List[np.ndarray]: Visualization results.
"""
- if self.visualizer is None or (not show and img_out_dir == ''):
+ if not show and img_out_dir == '' and not return_vis:
return None
-
- if getattr(self, 'visualizer') is None:
+ if self.visualizer is None:
raise ValueError('Visualization needs the "visualizer" term'
- 'defined in the config, but got None')
+ 'defined in the config, but got None.')
+
self.visualizer.set_dataset_meta(**self.model.dataset_meta)
self.visualizer.alpha = opacity
@@ -250,10 +255,11 @@ def visualize(self,
draw_gt=False,
draw_pred=True,
out_file=out_file)
- results.append(self.visualizer.get_image())
+ if return_vis:
+ results.append(self.visualizer.get_image())
self.num_visualized_imgs += 1
- return results
+ return results if return_vis else None
def postprocess(self,
preds: PredType,
@@ -300,17 +306,28 @@ def postprocess(self,
results_dict['visualization'] = []
for i, pred in enumerate(preds):
- pred_data = pred.pred_sem_seg.numpy().data[0]
- results_dict['predictions'].append(pred_data)
+ pred_data = dict()
+ if 'pred_sem_seg' in pred.keys():
+ pred_data['sem_seg'] = pred.pred_sem_seg.numpy().data[0]
+ elif 'pred_depth_map' in pred.keys():
+ pred_data['depth_map'] = pred.pred_depth_map.numpy().data[0]
+
if visualization is not None:
vis = visualization[i]
results_dict['visualization'].append(vis)
if pred_out_dir != '':
mmengine.mkdir_or_exist(pred_out_dir)
- img_name = str(self.num_pred_imgs).zfill(8) + '_pred.png'
- img_path = osp.join(pred_out_dir, img_name)
- output = Image.fromarray(pred_data.astype(np.uint8))
- output.save(img_path)
+ for key, data in pred_data.items():
+ post_fix = '_pred.png' if key == 'sem_seg' else '_pred.npy'
+ img_name = str(self.num_pred_imgs).zfill(8) + post_fix
+ img_path = osp.join(pred_out_dir, img_name)
+ if key == 'sem_seg':
+ output = Image.fromarray(data.astype(np.uint8))
+ output.save(img_path)
+ else:
+ np.save(img_path, data)
+ pred_data = next(iter(pred_data.values()))
+ results_dict['predictions'].append(pred_data)
self.num_pred_imgs += 1
if len(results_dict['predictions']) == 1:
@@ -338,12 +355,13 @@ def preprocess(self, inputs, batch_size, **kwargs):
"""
pipeline_cfg = cfg.test_dataloader.dataset.pipeline
# Loading annotations is also not applicable
- idx = self._get_transform_idx(pipeline_cfg, 'LoadAnnotations')
- if idx != -1:
- del pipeline_cfg[idx]
+ for transform in ('LoadAnnotations', 'LoadDepthAnnotation'):
+ idx = self._get_transform_idx(pipeline_cfg, transform)
+ if idx != -1:
+ del pipeline_cfg[idx]
+
load_img_idx = self._get_transform_idx(pipeline_cfg,
'LoadImageFromFile')
-
if load_img_idx == -1:
raise ValueError(
'LoadImageFromFile is not found in the test pipeline')
diff --git a/mmseg/apis/remote_sense_inferencer.py b/mmseg/apis/remote_sense_inferencer.py
new file mode 100644
index 00000000000..6726c6ae346
--- /dev/null
+++ b/mmseg/apis/remote_sense_inferencer.py
@@ -0,0 +1,279 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import threading
+from queue import Queue
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+from mmengine import Config
+from mmengine.model import BaseModel
+from mmengine.registry import init_default_scope
+from mmengine.runner import load_checkpoint
+
+try:
+ from osgeo import gdal
+except ImportError:
+ gdal = None
+
+from mmseg.registry import MODELS
+from .utils import _preprare_data
+
+
+class RSImage:
+ """Remote sensing image class.
+
+ Args:
+ img (str or gdal.Dataset): Image file path or gdal.Dataset.
+ """
+
+ def __init__(self, image):
+ self.dataset = gdal.Open(image, gdal.GA_ReadOnly) if isinstance(
+ image, str) else image
+ assert isinstance(self.dataset, gdal.Dataset), \
+            f'{image} is not an image'
+ self.width = self.dataset.RasterXSize
+ self.height = self.dataset.RasterYSize
+ self.channel = self.dataset.RasterCount
+ self.trans = self.dataset.GetGeoTransform()
+ self.proj = self.dataset.GetProjection()
+ self.band_list = []
+ self.band_list.extend(
+ self.dataset.GetRasterBand(c + 1) for c in range(self.channel))
+ self.grids = []
+
+ def read(self, grid: Optional[List] = None) -> np.ndarray:
+ """Read image data. If grid is None, read the whole image.
+
+ Args:
+ grid (Optional[List], optional): Grid to read. Defaults to None.
+ Returns:
+ np.ndarray: Image data.
+ """
+ if grid is None:
+ return np.einsum('ijk->jki', self.dataset.ReadAsArray())
+ assert len(
+ grid) >= 4, 'grid must be a list containing at least 4 elements'
+ data = self.dataset.ReadAsArray(*grid[:4])
+ if data.ndim == 2:
+ data = data[np.newaxis, ...]
+ return np.einsum('ijk->jki', data)
+
+ def write(self, data: Optional[np.ndarray], grid: Optional[List] = None):
+ """Write image data.
+
+ Args:
+            data (Optional[np.ndarray], optional): Data to write.
+                Defaults to None.
+            grid (Optional[List], optional): Grid to write. Defaults to None.
+
+ Raises:
+ ValueError: Either grid or data must be provided.
+ """
+ if grid is not None:
+ assert len(grid) == 8, 'grid must be a list of 8 elements'
+ for band in self.band_list:
+ band.WriteArray(
+ data[grid[5]:grid[5] + grid[7], grid[4]:grid[4] + grid[6]],
+ grid[0] + grid[4], grid[1] + grid[5])
+ elif data is not None:
+ for i in range(self.channel):
+ self.band_list[i].WriteArray(data[..., i])
+ else:
+ raise ValueError('Either grid or data must be provided.')
+
+ def create_seg_map(self, output_path: Optional[str] = None):
+ if output_path is None:
+ output_path = 'output_label.tif'
+ driver = gdal.GetDriverByName('GTiff')
+ seg_map = driver.Create(output_path, self.width, self.height, 1,
+ gdal.GDT_Byte)
+ seg_map.SetGeoTransform(self.trans)
+ seg_map.SetProjection(self.proj)
+ seg_map_img = RSImage(seg_map)
+ seg_map_img.path = output_path
+ return seg_map_img
+
+ def create_grids(self,
+ window_size: Tuple[int, int],
+ stride: Tuple[int, int] = (0, 0)):
+ """Create grids for image inference.
+
+ Args:
+ window_size (Tuple[int, int]): the size of the sliding window.
+ stride (Tuple[int, int], optional): the stride of the sliding
+ window. Defaults to (0, 0).
+
+ Raises:
+ AssertionError: window_size must be a tuple of 2 elements.
+ AssertionError: stride must be a tuple of 2 elements.
+ """
+ assert len(
+ window_size) == 2, 'window_size must be a tuple of 2 elements'
+ assert len(stride) == 2, 'stride must be a tuple of 2 elements'
+ win_w, win_h = window_size
+ stride_x, stride_y = stride
+
+ stride_x = win_w if stride_x == 0 else stride_x
+ stride_y = win_h if stride_y == 0 else stride_y
+
+ x_half_overlap = (win_w - stride_x + 1) // 2
+ y_half_overlap = (win_h - stride_y + 1) // 2
+
+ for y in range(0, self.height, stride_y):
+ y_end = y + win_h >= self.height
+ y_offset = self.height - win_h if y_end else y
+ y_size = win_h
+ y_crop_off = 0 if y_offset == 0 else y_half_overlap
+ y_crop_size = y_size if y_end else win_h - y_crop_off
+
+ for x in range(0, self.width, stride_x):
+ x_end = x + win_w >= self.width
+ x_offset = self.width - win_w if x_end else x
+ x_size = win_w
+ x_crop_off = 0 if x_offset == 0 else x_half_overlap
+ x_crop_size = x_size if x_end else win_w - x_crop_off
+
+ self.grids.append([
+ x_offset, y_offset, x_size, y_size, x_crop_off, y_crop_off,
+ x_crop_size, y_crop_size
+ ])
+
+
+class RSInferencer:
+ """Remote sensing inference class.
+
+ Args:
+ model (BaseModel): The loaded model.
+ batch_size (int, optional): Batch size. Defaults to 1.
+ thread (int, optional): Number of threads. Defaults to 1.
+ """
+
+ def __init__(self, model: BaseModel, batch_size: int = 1, thread: int = 1):
+ self.model = model
+ self.batch_size = batch_size
+ self.END_FLAG = object()
+ self.read_buffer = Queue(self.batch_size)
+ self.write_buffer = Queue(self.batch_size)
+ self.thread = thread
+
+ @classmethod
+ def from_config_path(cls,
+ config_path: str,
+ checkpoint_path: str,
+ batch_size: int = 1,
+ thread: int = 1,
+ device: Optional[str] = 'cpu'):
+ """Initialize a segmentor from config file.
+
+ Args:
+ config_path (str): Config file path.
+ checkpoint_path (str): Checkpoint path.
+ batch_size (int, optional): Batch size. Defaults to 1.
+ """
+ init_default_scope('mmseg')
+ cfg = Config.fromfile(config_path)
+ model = MODELS.build(cfg.model)
+ model.cfg = cfg
+ load_checkpoint(model, checkpoint_path, map_location='cpu')
+ model.to(device)
+ model.eval()
+ return cls(model, batch_size, thread)
+
+ @classmethod
+ def from_model(cls,
+ model: BaseModel,
+ checkpoint_path: Optional[str] = None,
+ batch_size: int = 1,
+ thread: int = 1,
+ device: Optional[str] = 'cpu'):
+ """Initialize a segmentor from model.
+
+ Args:
+ model (BaseModel): The loaded model.
+ checkpoint_path (Optional[str]): Checkpoint path.
+ batch_size (int, optional): Batch size. Defaults to 1.
+ """
+ if checkpoint_path is not None:
+ load_checkpoint(model, checkpoint_path, map_location='cpu')
+ model.to(device)
+ return cls(model, batch_size, thread)
+
+ def read(self,
+ image: RSImage,
+ window_size: Tuple[int, int],
+ strides: Tuple[int, int] = (0, 0)):
+ """Load image data to read buffer.
+
+ Args:
+ image (RSImage): The image to read.
+ window_size (Tuple[int, int]): The size of the sliding window.
+ strides (Tuple[int, int], optional): The stride of the sliding
+ window. Defaults to (0, 0).
+ """
+ image.create_grids(window_size, strides)
+ for grid in image.grids:
+ self.read_buffer.put([grid, image.read(grid=grid)])
+ self.read_buffer.put(self.END_FLAG)
+
+ def inference(self):
+ """Inference image data from read buffer and put the result to write
+ buffer."""
+ while True:
+ item = self.read_buffer.get()
+ if item == self.END_FLAG:
+ self.read_buffer.put(self.END_FLAG)
+ self.write_buffer.put(item)
+ break
+ data, _ = _preprare_data(item[1], self.model)
+ with torch.no_grad():
+ result = self.model.test_step(data)
+ item[1] = result[0].pred_sem_seg.cpu().data.numpy()[0]
+ self.write_buffer.put(item)
+ self.read_buffer.task_done()
+
+ def write(self, image: RSImage, output_path: Optional[str] = None):
+ """Write image data from write buffer.
+
+ Args:
+ image (RSImage): The image to write.
+ output_path (Optional[str], optional): The path to save the
+ segmentation map. Defaults to None.
+ """
+ seg_map = image.create_seg_map(output_path)
+ while True:
+ item = self.write_buffer.get()
+ if item == self.END_FLAG:
+ break
+ seg_map.write(data=item[1], grid=item[0])
+ self.write_buffer.task_done()
+
+ def run(self,
+ image: RSImage,
+ window_size: Tuple[int, int],
+ strides: Tuple[int, int] = (0, 0),
+ output_path: Optional[str] = None):
+ """Run inference with multi-threading.
+
+ Args:
+ image (RSImage): The image to inference.
+ window_size (Tuple[int, int]): The size of the sliding window.
+ strides (Tuple[int, int], optional): The stride of the sliding
+ window. Defaults to (0, 0).
+ output_path (Optional[str], optional): The path to save the
+ segmentation map. Defaults to None.
+ """
+ read_thread = threading.Thread(
+ target=self.read, args=(image, window_size, strides))
+ read_thread.start()
+ inference_threads = []
+ for _ in range(self.thread):
+ inference_thread = threading.Thread(target=self.inference)
+ inference_thread.start()
+ inference_threads.append(inference_thread)
+ write_thread = threading.Thread(
+ target=self.write, args=(image, output_path))
+ write_thread.start()
+ read_thread.join()
+ for inference_thread in inference_threads:
+ inference_thread.join()
+ write_thread.join()
diff --git a/mmseg/apis/utils.py b/mmseg/apis/utils.py
new file mode 100644
index 00000000000..4cf87756602
--- /dev/null
+++ b/mmseg/apis/utils.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import defaultdict
+from typing import Sequence, Union
+
+import numpy as np
+from mmengine.dataset import Compose
+from mmengine.model import BaseModel
+
+ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
+
+
+def _preprare_data(imgs: ImageType, model: BaseModel):
+
+ cfg = model.cfg
+ for t in cfg.test_pipeline:
+ if t.get('type') == 'LoadAnnotations':
+ cfg.test_pipeline.remove(t)
+
+ is_batch = True
+ if not isinstance(imgs, (list, tuple)):
+ imgs = [imgs]
+ is_batch = False
+
+ if isinstance(imgs[0], np.ndarray):
+ cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray'
+
+ # TODO: Consider using the singleton pattern to avoid building
+ # a pipeline for each inference
+ pipeline = Compose(cfg.test_pipeline)
+
+ data = defaultdict(list)
+ for img in imgs:
+ if isinstance(img, np.ndarray):
+ data_ = dict(img=img)
+ else:
+ data_ = dict(img_path=img)
+ data_ = pipeline(data_)
+ data['inputs'].append(data_['inputs'])
+ data['data_samples'].append(data_['data_samples'])
+
+ return data, is_batch
diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py
index a90d53c88e0..a2bdb63d016 100644
--- a/mmseg/datasets/__init__.py
+++ b/mmseg/datasets/__init__.py
@@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
-from .basesegdataset import BaseSegDataset
+from .basesegdataset import BaseCDDataset, BaseSegDataset
+from .bdd100k import BDD100KDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
@@ -9,29 +10,34 @@
from .dataset_wrappers import MultiImageMixDataset
from .decathlon import DecathlonDataset
from .drive import DRIVEDataset
+from .dsdl import DSDLSegDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
+from .levir import LEVIRCDDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .mapillary import MapillaryDataset_v1, MapillaryDataset_v2
from .night_driving import NightDrivingDataset
+from .nyu import NYUDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .refuge import REFUGEDataset
from .stare import STAREDataset
from .synapse import SynapseDataset
# yapf: disable
-from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
+from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
- BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
- LoadBiomedicalAnnotation, LoadBiomedicalData,
- LoadBiomedicalImageFromFile, LoadImageFromNDArray,
- PackSegInputs, PhotoMetricDistortion, RandomCrop,
- RandomCutOut, RandomMosaic, RandomRotate,
- RandomRotFlip, Rerange, ResizeShortestEdge,
- ResizeToMultiple, RGB2Gray, SegRescale)
+ BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
+ LoadAnnotations, LoadBiomedicalAnnotation,
+ LoadBiomedicalData, LoadBiomedicalImageFromFile,
+ LoadImageFromNDArray, LoadMultipleRSImageFromFile,
+ LoadSingleRSImageFromFile, PackSegInputs,
+ PhotoMetricDistortion, RandomCrop, RandomCutOut,
+ RandomMosaic, RandomRotate, RandomRotFlip, Rerange,
+ ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
+ SegRescale)
from .voc import PascalVOCDataset
# yapf: enable
@@ -51,5 +57,8 @@
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1',
- 'MapillaryDataset_v2'
+ 'MapillaryDataset_v2', 'Albu', 'LEVIRCDDataset',
+ 'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile',
+ 'ConcatCDInput', 'BaseCDDataset', 'DSDLSegDataset', 'BDD100KDataset',
+ 'NYUDataset'
]
diff --git a/mmseg/datasets/basesegdataset.py b/mmseg/datasets/basesegdataset.py
index ddf476bae94..9c4668c1f56 100644
--- a/mmseg/datasets/basesegdataset.py
+++ b/mmseg/datasets/basesegdataset.py
@@ -235,7 +235,9 @@ def load_data_list(self) -> List[dict]:
data_list = []
img_dir = self.data_prefix.get('img_path', None)
ann_dir = self.data_prefix.get('seg_map_path', None)
- if osp.isfile(self.ann_file):
+ if not osp.isdir(self.ann_file) and self.ann_file:
+ assert osp.isfile(self.ann_file), \
+ f'Failed to load `ann_file` {self.ann_file}'
lines = mmengine.list_from_file(
self.ann_file, backend_args=self.backend_args)
for line in lines:
@@ -250,6 +252,7 @@ def load_data_list(self) -> List[dict]:
data_info['seg_fields'] = []
data_list.append(data_info)
else:
+ _suffix_len = len(self.img_suffix)
for img in fileio.list_dir_or_file(
dir_path=img_dir,
list_dir=False,
@@ -258,7 +261,288 @@ def load_data_list(self) -> List[dict]:
backend_args=self.backend_args):
data_info = dict(img_path=osp.join(img_dir, img))
if ann_dir is not None:
- seg_map = img.replace(self.img_suffix, self.seg_map_suffix)
+ seg_map = img[:-_suffix_len] + self.seg_map_suffix
+ data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
+ data_info['label_map'] = self.label_map
+ data_info['reduce_zero_label'] = self.reduce_zero_label
+ data_info['seg_fields'] = []
+ data_list.append(data_info)
+ data_list = sorted(data_list, key=lambda x: x['img_path'])
+ return data_list
+
+
+@DATASETS.register_module()
+class BaseCDDataset(BaseDataset):
+ """Custom dataset for change detection. An example of file structure is as
+ followed.
+
+ .. code-block:: none
+
+ ├── data
+ │ ├── my_dataset
+ │ │ ├── img_dir
+ │ │ │ ├── train
+ │ │ │ │ ├── xxx{img_suffix}
+ │ │ │ │ ├── yyy{img_suffix}
+ │ │ │ │ ├── zzz{img_suffix}
+ │ │ │ ├── val
+ │ │ ├── img_dir2
+ │ │ │ ├── train
+ │ │ │ │ ├── xxx{img_suffix}
+ │ │ │ │ ├── yyy{img_suffix}
+ │ │ │ │ ├── zzz{img_suffix}
+ │ │ │ ├── val
+ │ │ ├── ann_dir
+ │ │ │ ├── train
+ │ │ │ │ ├── xxx{seg_map_suffix}
+ │ │ │ │ ├── yyy{seg_map_suffix}
+ │ │ │ │ ├── zzz{seg_map_suffix}
+ │ │ │ ├── val
+
+ The image names in img_dir and img_dir2 should be consistent.
+    The img/gt_semantic_seg pair of BaseSegDataset should be the same
+    except for the suffix. A valid img/gt_semantic_seg filename pair is
+ ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included
+ in the suffix). If split is given, then ``xxx`` is specified in txt file.
+    Otherwise, all files in ``img_dir/`` and ``ann_dir`` will be loaded.
+ Please refer to ``docs/en/tutorials/new_dataset.md`` for more details.
+
+
+ Args:
+ ann_file (str): Annotation file path. Defaults to ''.
+ metainfo (dict, optional): Meta information for dataset, such as
+ specify classes to load. Defaults to None.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Defaults to None.
+ data_prefix (dict, optional): Prefix for training data. Defaults to
+ dict(img_path=None, img_path2=None, seg_map_path=None).
+ img_suffix (str): Suffix of images. Default: '.jpg'
+        img_suffix2 (str): Suffix of the second images. Default: '.jpg'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ filter_cfg (dict, optional): Config for filter data. Defaults to None.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Defaults to None which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects, when enabled, data loader workers can use
+ shared RAM from master process instead of making a copy. Defaults
+ to True.
+ pipeline (list, optional): Processing pipeline. Defaults to [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Defaults to False.
+ lazy_init (bool, optional): Whether to load annotation during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, which is not necessary to
+            load annotation file. ``Basedataset`` can skip loading
+            annotations to save time by setting ``lazy_init=True``.
+            Defaults to False.
+        max_refetch (int, optional): If ``Basedataset.prepare_data`` gets a
+ None img. The maximum extra number of cycles to get a valid
+ image. Defaults to 1000.
+ ignore_index (int): The label index to be ignored. Default: 255
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default to False.
+ backend_args (dict, Optional): Arguments to instantiate a file backend.
+            See https://mmengine.readthedocs.io/en/latest/api/fileio.html
+ for details. Defaults to None.
+ Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
+ """
+ METAINFO: dict = dict()
+
+ def __init__(self,
+ ann_file: str = '',
+ img_suffix='.jpg',
+ img_suffix2='.jpg',
+ seg_map_suffix='.png',
+ metainfo: Optional[dict] = None,
+ data_root: Optional[str] = None,
+ data_prefix: dict = dict(
+ img_path='', img_path2='', seg_map_path=''),
+ filter_cfg: Optional[dict] = None,
+ indices: Optional[Union[int, Sequence[int]]] = None,
+ serialize_data: bool = True,
+ pipeline: List[Union[dict, Callable]] = [],
+ test_mode: bool = False,
+ lazy_init: bool = False,
+ max_refetch: int = 1000,
+ ignore_index: int = 255,
+ reduce_zero_label: bool = False,
+ backend_args: Optional[dict] = None) -> None:
+
+ self.img_suffix = img_suffix
+ self.img_suffix2 = img_suffix2
+ self.seg_map_suffix = seg_map_suffix
+ self.ignore_index = ignore_index
+ self.reduce_zero_label = reduce_zero_label
+ self.backend_args = backend_args.copy() if backend_args else None
+
+ self.data_root = data_root
+ self.data_prefix = copy.copy(data_prefix)
+ self.ann_file = ann_file
+ self.filter_cfg = copy.deepcopy(filter_cfg)
+ self._indices = indices
+ self.serialize_data = serialize_data
+ self.test_mode = test_mode
+ self.max_refetch = max_refetch
+ self.data_list: List[dict] = []
+ self.data_bytes: np.ndarray
+
+ # Set meta information.
+ self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))
+
+ # Get label map for custom classes
+ new_classes = self._metainfo.get('classes', None)
+ self.label_map = self.get_label_map(new_classes)
+ self._metainfo.update(
+ dict(
+ label_map=self.label_map,
+ reduce_zero_label=self.reduce_zero_label))
+
+ # Update palette based on label map or generate palette
+ # if it is not defined
+ updated_palette = self._update_palette()
+ self._metainfo.update(dict(palette=updated_palette))
+
+ # Join paths.
+ if self.data_root is not None:
+ self._join_prefix()
+
+ # Build pipeline.
+ self.pipeline = Compose(pipeline)
+ # Full initialize the dataset.
+ if not lazy_init:
+ self.full_init()
+
+ if test_mode:
+ assert self._metainfo.get('classes') is not None, \
+ 'dataset metainfo `classes` should be specified when testing'
+
+ @classmethod
+ def get_label_map(cls,
+ new_classes: Optional[Sequence] = None
+ ) -> Union[Dict, None]:
+        """Require the label mapping.
+
+        The ``label_map`` is a dictionary whose keys are the old label ids
+        and whose values are the new label ids. It is used for changing
+        pixel labels in load_annotations. ``label_map`` is not None if and
+        only if the old classes in cls.METAINFO differ from the new classes
+        in self._metainfo and neither of them is None.
+
+ Args:
+ new_classes (list, tuple, optional): The new classes name from
+ metainfo. Default to None.
+
+
+ Returns:
+ dict, optional: The mapping from old classes in cls.METAINFO to
+ new classes in self._metainfo
+ """
+ old_classes = cls.METAINFO.get('classes', None)
+ if (new_classes is not None and old_classes is not None
+ and list(new_classes) != list(old_classes)):
+
+ label_map = {}
+ if not set(new_classes).issubset(cls.METAINFO['classes']):
+ raise ValueError(
+ f'new classes {new_classes} is not a '
+ f'subset of classes {old_classes} in METAINFO.')
+ for i, c in enumerate(old_classes):
+ if c not in new_classes:
+ label_map[i] = 255
+ else:
+ label_map[i] = new_classes.index(c)
+ return label_map
+ else:
+ return None
+
+ def _update_palette(self) -> list:
+ """Update palette after loading metainfo.
+
+        If the length of the palette equals the number of classes, just
+        return the palette. If the palette is not defined, generate a
+        random palette. If classes are updated by the user, return the
+        corresponding subset of the palette.
+
+ Returns:
+ Sequence: Palette for current dataset.
+ """
+ palette = self._metainfo.get('palette', [])
+ classes = self._metainfo.get('classes', [])
+        # palette matches classes
+ if len(palette) == len(classes):
+ return palette
+
+ if len(palette) == 0:
+ # Get random state before set seed, and restore
+ # random state later.
+ # It will prevent loss of randomness, as the palette
+ # may be different in each iteration if not specified.
+ # See: https://github.com/open-mmlab/mmdetection/issues/5844
+ state = np.random.get_state()
+ np.random.seed(42)
+ # random palette
+ new_palette = np.random.randint(
+ 0, 255, size=(len(classes), 3)).tolist()
+ np.random.set_state(state)
+ elif len(palette) >= len(classes) and self.label_map is not None:
+ new_palette = []
+ # return subset of palette
+ for old_id, new_id in sorted(
+ self.label_map.items(), key=lambda x: x[1]):
+ if new_id != 255:
+ new_palette.append(palette[old_id])
+ new_palette = type(palette)(new_palette)
+ else:
+ raise ValueError('palette does not match classes '
+ f'as metainfo is {self._metainfo}.')
+ return new_palette
+
+ def load_data_list(self) -> List[dict]:
+ """Load annotation from directory or annotation file.
+
+ Returns:
+ list[dict]: All data info of dataset.
+ """
+ data_list = []
+ img_dir = self.data_prefix.get('img_path', None)
+ img_dir2 = self.data_prefix.get('img_path2', None)
+ ann_dir = self.data_prefix.get('seg_map_path', None)
+ if osp.isfile(self.ann_file):
+ lines = mmengine.list_from_file(
+ self.ann_file, backend_args=self.backend_args)
+ for line in lines:
+ img_name = line.strip()
+ if '.' in osp.basename(img_name):
+ img_name, img_ext = osp.splitext(img_name)
+ self.img_suffix = img_ext
+ self.img_suffix2 = img_ext
+ data_info = dict(
+ img_path=osp.join(img_dir, img_name + self.img_suffix),
+ img_path2=osp.join(img_dir2, img_name + self.img_suffix2))
+
+ if ann_dir is not None:
+ seg_map = img_name + self.seg_map_suffix
+ data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
+ data_info['label_map'] = self.label_map
+ data_info['reduce_zero_label'] = self.reduce_zero_label
+ data_info['seg_fields'] = []
+ data_list.append(data_info)
+ else:
+ for img in fileio.list_dir_or_file(
+ dir_path=img_dir,
+ list_dir=False,
+ suffix=self.img_suffix,
+ recursive=True,
+ backend_args=self.backend_args):
+ if '.' in osp.basename(img):
+ img, img_ext = osp.splitext(img)
+ self.img_suffix = img_ext
+ self.img_suffix2 = img_ext
+ data_info = dict(
+ img_path=osp.join(img_dir, img + self.img_suffix),
+ img_path2=osp.join(img_dir2, img + self.img_suffix2))
+ if ann_dir is not None:
+ seg_map = img + self.seg_map_suffix
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
data_info['label_map'] = self.label_map
data_info['reduce_zero_label'] = self.reduce_zero_label
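The old-to-new class remapping that ``get_label_map`` above computes can be sketched standalone. This is a minimal illustration, not part of the diff; the function name and class names are made up:

```python
# Sketch of the old->new label mapping used by ``get_label_map`` above:
# classes kept by the user get a remapped id, dropped classes map to the
# ignore index 255. Names here are illustrative only.
IGNORE_INDEX = 255


def build_label_map(old_classes, new_classes):
    """Map old label ids to new ids; dropped classes map to 255."""
    if not set(new_classes).issubset(old_classes):
        raise ValueError('new classes must be a subset of old classes')
    return {
        i: (new_classes.index(c) if c in new_classes else IGNORE_INDEX)
        for i, c in enumerate(old_classes)
    }


# Keeping only 'road' and 'car' out of three classes:
mapping = build_label_map(('road', 'car', 'sky'), ('road', 'car'))
# mapping == {0: 0, 1: 1, 2: 255}
```

Pixels of the dropped class ('sky') are then treated as ignored during loss computation.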
diff --git a/mmseg/datasets/bdd100k.py b/mmseg/datasets/bdd100k.py
new file mode 100644
index 00000000000..8ae70b5cb29
--- /dev/null
+++ b/mmseg/datasets/bdd100k.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmseg.datasets.basesegdataset import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class BDD100KDataset(BaseSegDataset):
+ METAINFO = dict(
+ classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+ 'traffic light', 'traffic sign', 'vegetation', 'terrain',
+ 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
+ 'motorcycle', 'bicycle'),
+ palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+ [190, 153, 153], [153, 153, 153], [250, 170,
+ 30], [220, 220, 0],
+ [107, 142, 35], [152, 251, 152], [70, 130, 180],
+ [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
+ [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]])
+
+ def __init__(self,
+ img_suffix='.jpg',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
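A minimal config fragment using the new ``BDD100KDataset`` might look like the following. The ``data_prefix`` paths and batch size are assumptions about a local layout, not taken from this diff:

```python
# Hypothetical config sketch for BDD100KDataset; the data_prefix paths
# are assumptions, not part of this diff.
dataset_type = 'BDD100KDataset'
data_root = 'data/bdd100k/'

train_dataloader = dict(
    batch_size=2,
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='images/train', seg_map_path='labels/train'),
        pipeline=[]))
```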
diff --git a/mmseg/datasets/chase_db1.py b/mmseg/datasets/chase_db1.py
index 5cc1fc56773..626ddf75e9a 100644
--- a/mmseg/datasets/chase_db1.py
+++ b/mmseg/datasets/chase_db1.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import mmengine.fileio as fileio
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@@ -27,4 +28,5 @@ def __init__(self,
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)
- assert self.file_client.exists(self.data_prefix['img_path'])
+ assert fileio.exists(
+ self.data_prefix['img_path'], backend_args=self.backend_args)
diff --git a/mmseg/datasets/drive.py b/mmseg/datasets/drive.py
index c42e18e711a..76c0160a6b6 100644
--- a/mmseg/datasets/drive.py
+++ b/mmseg/datasets/drive.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import mmengine.fileio as fileio
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@@ -27,4 +28,5 @@ def __init__(self,
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)
- assert self.file_client.exists(self.data_prefix['img_path'])
+ assert fileio.exists(
+ self.data_prefix['img_path'], backend_args=self.backend_args)
diff --git a/mmseg/datasets/dsdl.py b/mmseg/datasets/dsdl.py
new file mode 100644
index 00000000000..bf7e4e61b5f
--- /dev/null
+++ b/mmseg/datasets/dsdl.py
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from typing import Dict, List, Optional, Sequence, Union
+
+from mmseg.registry import DATASETS
+from .basesegdataset import BaseSegDataset
+
+try:
+ from dsdl.dataset import DSDLDataset
+except ImportError:
+ DSDLDataset = None
+
+
+@DATASETS.register_module()
+class DSDLSegDataset(BaseSegDataset):
+ """Dataset for dsdl segmentation.
+
+ Args:
+        specific_key_path(dict): Path of a specific key which cannot
+            be loaded by its field name.
+        pre_transform(dict): Pre-transform functions applied before loading.
+        used_labels(sequence): List of classes actually used in training,
+            which must be a subset of the class domain.
+ """
+
+ METAINFO = {}
+
+ def __init__(self,
+ specific_key_path: Dict = {},
+ pre_transform: Dict = {},
+ used_labels: Optional[Sequence] = None,
+ **kwargs) -> None:
+
+ if DSDLDataset is None:
+ raise RuntimeError(
+ 'Package dsdl is not installed. Please run "pip install dsdl".'
+ )
+ self.used_labels = used_labels
+
+ loc_config = dict(type='LocalFileReader', working_dir='')
+ if kwargs.get('data_root'):
+ kwargs['ann_file'] = os.path.join(kwargs['data_root'],
+ kwargs['ann_file'])
+ required_fields = ['Image', 'LabelMap']
+
+ self.dsdldataset = DSDLDataset(
+ dsdl_yaml=kwargs['ann_file'],
+ location_config=loc_config,
+ required_fields=required_fields,
+ specific_key_path=specific_key_path,
+ transform=pre_transform,
+ )
+ BaseSegDataset.__init__(self, **kwargs)
+
+ def load_data_list(self) -> List[Dict]:
+ """Load data info from a dsdl yaml file named as ``self.ann_file``
+
+ Returns:
+ List[dict]: A list of data list.
+ """
+
+ if self.used_labels:
+ self._metainfo['classes'] = tuple(self.used_labels)
+ self.label_map = self.get_label_map(self.used_labels)
+ else:
+ self._metainfo['classes'] = tuple(['background'] +
+ self.dsdldataset.class_names)
+ data_list = []
+
+ for i, data in enumerate(self.dsdldataset):
+ datainfo = dict(
+ img_path=os.path.join(self.data_prefix['img_path'],
+ data['Image'][0].location),
+ seg_map_path=os.path.join(self.data_prefix['seg_map_path'],
+ data['LabelMap'][0].location),
+ label_map=self.label_map,
+ reduce_zero_label=self.reduce_zero_label,
+ seg_fields=[],
+ )
+ data_list.append(datainfo)
+
+ return data_list
+
+ def get_label_map(self,
+ new_classes: Optional[Sequence] = None
+ ) -> Union[Dict, None]:
+ """Require label mapping.
+
+        The ``label_map`` is a dictionary whose keys are the old label ids
+        and whose values are the new label ids. It is used for changing
+        pixel labels in load_annotations. ``label_map`` is not None if and
+        only if the old classes in class_dom differ from the new classes in
+        args and neither of them is None.
+ Args:
+ new_classes (list, tuple, optional): The new classes name from
+ metainfo. Default to None.
+ Returns:
+ dict, optional: The mapping from old classes to new classes.
+ """
+ old_classes = ['background'] + self.dsdldataset.class_names
+ if (new_classes is not None and old_classes is not None
+ and list(new_classes) != list(old_classes)):
+
+ label_map = {}
+ if not set(new_classes).issubset(old_classes):
+ raise ValueError(
+ f'new classes {new_classes} is not a '
+ f'subset of classes {old_classes} in class_dom.')
+ for i, c in enumerate(old_classes):
+ if c not in new_classes:
+ label_map[i] = 255
+ else:
+ label_map[i] = new_classes.index(c)
+ return label_map
+ else:
+ return None
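The palette-subsetting branch of ``_update_palette`` earlier in this diff keeps only the colors of classes that survive the remapping, ordered by their new ids. A standalone sketch (function name illustrative):

```python
# Sketch of the palette-subsetting branch of ``_update_palette``: iterate
# label_map entries sorted by new id and keep colors of surviving classes.
def subset_palette(palette, label_map, ignore_index=255):
    new_palette = []
    for old_id, new_id in sorted(label_map.items(), key=lambda x: x[1]):
        if new_id != ignore_index:
            new_palette.append(palette[old_id])
    return new_palette


palette = [[128, 64, 128], [0, 0, 142], [70, 130, 180]]
# Keep old classes 0 and 2 (new ids 0 and 1), drop old class 1:
print(subset_palette(palette, {0: 0, 1: 255, 2: 1}))
# [[128, 64, 128], [70, 130, 180]]
```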
diff --git a/mmseg/datasets/hrf.py b/mmseg/datasets/hrf.py
index 0df6ccc49c2..fd669cce264 100644
--- a/mmseg/datasets/hrf.py
+++ b/mmseg/datasets/hrf.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import mmengine.fileio as fileio
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@@ -27,4 +28,5 @@ def __init__(self,
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)
- assert self.file_client.exists(self.data_prefix['img_path'])
+ assert fileio.exists(
+ self.data_prefix['img_path'], backend_args=self.backend_args)
diff --git a/mmseg/datasets/levir.py b/mmseg/datasets/levir.py
new file mode 100644
index 00000000000..f467481bad7
--- /dev/null
+++ b/mmseg/datasets/levir.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmseg.registry import DATASETS
+from .basesegdataset import BaseCDDataset
+
+
+@DATASETS.register_module()
+class LEVIRCDDataset(BaseCDDataset):
+    """LEVIR-CD dataset.
+
+    In change detection annotation for LEVIR-CD, 0 stands for background,
+    so ``reduce_zero_label`` is fixed to False. The ``img_suffix``,
+    ``img_suffix2`` and ``seg_map_suffix`` all default to '.png'.
+    """
+
+ METAINFO = dict(
+ classes=('background', 'changed'),
+ palette=[[0, 0, 0], [255, 255, 255]])
+
+ def __init__(self,
+ img_suffix='.png',
+ img_suffix2='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ img_suffix2=img_suffix2,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/mmseg/datasets/nyu.py b/mmseg/datasets/nyu.py
new file mode 100644
index 00000000000..fcfda46647d
--- /dev/null
+++ b/mmseg/datasets/nyu.py
@@ -0,0 +1,123 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import List
+
+import mmengine.fileio as fileio
+
+from mmseg.registry import DATASETS
+from .basesegdataset import BaseSegDataset
+
+
+@DATASETS.register_module()
+class NYUDataset(BaseSegDataset):
+    """NYU depth estimation dataset. An example of file structure is as
+    follows.
+
+ .. code-block:: none
+
+ ├── data
+ │ ├── nyu
+ │ │ ├── images
+ │ │ │ ├── train
+ │ │ │ │ ├── scene_xxx.jpg
+ │ │ │ │ ├── ...
+ │ │ │ ├── test
+ │ │ ├── annotations
+ │ │ │ ├── train
+ │ │ │ │ ├── scene_xxx.png
+ │ │ │ │ ├── ...
+ │ │ │ ├── test
+
+ Args:
+ ann_file (str): Annotation file path. Defaults to ''.
+ metainfo (dict, optional): Meta information for dataset, such as
+            specified classes to load. Defaults to None.
+ data_root (str, optional): The root directory for ``data_prefix`` and
+ ``ann_file``. Defaults to None.
+ data_prefix (dict, optional): Prefix for training data. Defaults to
+ dict(img_path='images', depth_map_path='annotations').
+ img_suffix (str): Suffix of images. Default: '.jpg'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ filter_cfg (dict, optional): Config for filter data. Defaults to None.
+ indices (int or Sequence[int], optional): Support using first few
+ data in annotation file to facilitate training/testing on a smaller
+ dataset. Defaults to None which means using all ``data_infos``.
+ serialize_data (bool, optional): Whether to hold memory using
+ serialized objects, when enabled, data loader workers can use
+ shared RAM from master process instead of making a copy. Defaults
+ to True.
+ pipeline (list, optional): Processing pipeline. Defaults to [].
+ test_mode (bool, optional): ``test_mode=True`` means in test phase.
+ Defaults to False.
+ lazy_init (bool, optional): Whether to load annotation during
+ instantiation. In some cases, such as visualization, only the meta
+ information of the dataset is needed, which is not necessary to
+            load annotation file. ``Basedataset`` can skip loading
+            annotations to save time by setting ``lazy_init=True``.
+            Defaults to False.
+        max_refetch (int, optional): If ``Basedataset.prepare_data`` gets a
+ None img. The maximum extra number of cycles to get a valid
+ image. Defaults to 1000.
+ ignore_index (int): The label index to be ignored. Default: 255
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default to False.
+ backend_args (dict, Optional): Arguments to instantiate a file backend.
+            See https://mmengine.readthedocs.io/en/latest/api/fileio.html
+ for details. Defaults to None.
+ Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
+ """
+ METAINFO = dict(
+ classes=('printer_room', 'bathroom', 'living_room', 'study',
+ 'conference_room', 'study_room', 'kitchen', 'home_office',
+ 'bedroom', 'dinette', 'playroom', 'indoor_balcony',
+ 'laundry_room', 'basement', 'excercise_room', 'foyer',
+ 'home_storage', 'cafe', 'furniture_store', 'office_kitchen',
+ 'student_lounge', 'dining_room', 'reception_room',
+ 'computer_lab', 'classroom', 'office', 'bookstore'))
+
+ def __init__(self,
+ data_prefix=dict(
+ img_path='images', depth_map_path='annotations'),
+ img_suffix='.jpg',
+ depth_map_suffix='.png',
+ **kwargs) -> None:
+ super().__init__(
+ data_prefix=data_prefix,
+ img_suffix=img_suffix,
+ seg_map_suffix=depth_map_suffix,
+ **kwargs)
+
+ def _get_category_id_from_filename(self, image_fname: str) -> int:
+ """Retrieve the category ID from the given image filename."""
+ image_fname = osp.basename(image_fname)
+ position = image_fname.find(next(filter(str.isdigit, image_fname)), 0)
+        category_name = image_fname[:position - 1]
+        if category_name not in self._metainfo['classes']:
+            return -1
+        else:
+            return self._metainfo['classes'].index(category_name)
+
+ def load_data_list(self) -> List[dict]:
+ """Load annotation from directory or annotation file.
+
+ Returns:
+ list[dict]: All data info of dataset.
+ """
+ data_list = []
+ img_dir = self.data_prefix.get('img_path', None)
+ ann_dir = self.data_prefix.get('depth_map_path', None)
+
+ _suffix_len = len(self.img_suffix)
+ for img in fileio.list_dir_or_file(
+ dir_path=img_dir,
+ list_dir=False,
+ suffix=self.img_suffix,
+ recursive=True,
+ backend_args=self.backend_args):
+ data_info = dict(img_path=osp.join(img_dir, img))
+ if ann_dir is not None:
+ depth_map = img[:-_suffix_len] + self.seg_map_suffix
+ data_info['depth_map_path'] = osp.join(ann_dir, depth_map)
+ data_info['seg_fields'] = []
+ data_info['category_id'] = self._get_category_id_from_filename(img)
+ data_list.append(data_info)
+ data_list = sorted(data_list, key=lambda x: x['img_path'])
+ return data_list
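The filename-to-category logic in ``_get_category_id_from_filename`` above strips the trailing ``_<number><ext>`` to recover the scene name. A hedged standalone sketch (the class tuple here is a toy subset):

```python
import os.path as osp


def category_from_filename(image_fname, classes):
    """Sketch of ``_get_category_id_from_filename`` above: cut the
    filename just before its first digit (dropping the separator) to
    recover the scene name, then look it up in the class list."""
    image_fname = osp.basename(image_fname)
    position = image_fname.find(next(filter(str.isdigit, image_fname)))
    category_name = image_fname[:position - 1]
    return classes.index(category_name) if category_name in classes else -1


classes = ('bathroom', 'kitchen', 'office')
print(category_from_filename('train/bathroom_0001.jpg', classes))  # 0
print(category_from_filename('train/attic_0002.jpg', classes))     # -1
```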
diff --git a/mmseg/datasets/pascal_context.py b/mmseg/datasets/pascal_context.py
index a6b2fba7b42..82d00a9b308 100644
--- a/mmseg/datasets/pascal_context.py
+++ b/mmseg/datasets/pascal_context.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
+import mmengine.fileio as fileio
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@@ -46,18 +46,18 @@ class PascalContextDataset(BaseSegDataset):
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]])
def __init__(self,
- ann_file: str,
+ ann_file='',
img_suffix='.jpg',
seg_map_suffix='.png',
+ reduce_zero_label=False,
**kwargs) -> None:
super().__init__(
img_suffix=img_suffix,
seg_map_suffix=seg_map_suffix,
ann_file=ann_file,
- reduce_zero_label=False,
+ reduce_zero_label=reduce_zero_label,
**kwargs)
- assert self.file_client.exists(
- self.data_prefix['img_path']) and osp.isfile(self.ann_file)
+ assert fileio.exists(self.data_prefix['img_path'], self.backend_args)
@DATASETS.register_module()
@@ -66,8 +66,10 @@ class PascalContextDataset59(BaseSegDataset):
In segmentation map annotation for PascalContext, 0 stands for background,
which is included in 60 categories. ``reduce_zero_label`` is fixed to
- False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
+ True. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
fixed to '.png'.
+    Note: If the background is 255 and the ids of categories are from 0 to 58,
+ ``reduce_zero_label`` needs to be set to False.
Args:
ann_file (str): Annotation file path.
@@ -100,7 +102,7 @@ class PascalContextDataset59(BaseSegDataset):
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]])
def __init__(self,
- ann_file: str,
+ ann_file='',
img_suffix='.jpg',
seg_map_suffix='.png',
reduce_zero_label=True,
@@ -111,5 +113,4 @@ def __init__(self,
ann_file=ann_file,
reduce_zero_label=reduce_zero_label,
**kwargs)
- assert self.file_client.exists(
- self.data_prefix['img_path']) and osp.isfile(self.ann_file)
+ assert fileio.exists(self.data_prefix['img_path'], self.backend_args)
diff --git a/mmseg/datasets/stare.py b/mmseg/datasets/stare.py
index 2bfce234494..1b997bb785f 100644
--- a/mmseg/datasets/stare.py
+++ b/mmseg/datasets/stare.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import mmengine.fileio as fileio
+
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@@ -26,4 +28,5 @@ def __init__(self,
seg_map_suffix=seg_map_suffix,
reduce_zero_label=reduce_zero_label,
**kwargs)
- assert self.file_client.exists(self.data_prefix['img_path'])
+ assert fileio.exists(
+ self.data_prefix['img_path'], backend_args=self.backend_args)
diff --git a/mmseg/datasets/transforms/__init__.py b/mmseg/datasets/transforms/__init__.py
index 25f4ee4a987..125f0708181 100644
--- a/mmseg/datasets/transforms/__init__.py
+++ b/mmseg/datasets/transforms/__init__.py
@@ -2,14 +2,16 @@
from .formatting import PackSegInputs
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
LoadBiomedicalData, LoadBiomedicalImageFromFile,
- LoadImageFromNDArray)
+ LoadDepthAnnotation, LoadImageFromNDArray,
+ LoadMultipleRSImageFromFile, LoadSingleRSImageFromFile)
# yapf: disable
-from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
+from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
- BioMedicalRandomGamma, GenerateEdge,
+ BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
PhotoMetricDistortion, RandomCrop, RandomCutOut,
- RandomMosaic, RandomRotate, RandomRotFlip, Rerange,
+ RandomDepthMix, RandomFlip, RandomMosaic,
+ RandomRotate, RandomRotFlip, Rerange, Resize,
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale)
@@ -22,5 +24,7 @@
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad',
- 'RandomRotFlip'
+ 'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput',
+ 'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation', 'RandomDepthMix',
+ 'RandomFlip', 'Resize'
]
diff --git a/mmseg/datasets/transforms/formatting.py b/mmseg/datasets/transforms/formatting.py
index 89fd8837913..bd250551e98 100644
--- a/mmseg/datasets/transforms/formatting.py
+++ b/mmseg/datasets/transforms/formatting.py
@@ -92,6 +92,11 @@ def transform(self, results: dict) -> dict:
...].astype(np.int64)))
data_sample.set_data(dict(gt_edge_map=PixelData(**gt_edge_data)))
+ if 'gt_depth_map' in results:
+ gt_depth_data = dict(
+ data=to_tensor(results['gt_depth_map'][None, ...]))
+ data_sample.set_data(dict(gt_depth_map=PixelData(**gt_depth_data)))
+
img_meta = {}
for key in self.meta_keys:
if key in results:
diff --git a/mmseg/datasets/transforms/loading.py b/mmseg/datasets/transforms/loading.py
index d2e93b1abb9..438b5527f08 100644
--- a/mmseg/datasets/transforms/loading.py
+++ b/mmseg/datasets/transforms/loading.py
@@ -12,6 +12,11 @@
from mmseg.registry import TRANSFORMS
from mmseg.utils import datafrombytes
+try:
+ from osgeo import gdal
+except ImportError:
+ gdal = None
+
@TRANSFORMS.register_module()
class LoadAnnotations(MMCV_LoadAnnotations):
@@ -493,3 +498,207 @@ def transform(self, single_input: Union[str, np.ndarray, dict]) -> dict:
if 'img' in inputs:
return self.from_ndarray(inputs)
return self.from_file(inputs)
+
+
+@TRANSFORMS.register_module()
+class LoadSingleRSImageFromFile(BaseTransform):
+    """Load a remote sensing image from file.
+
+ Required Keys:
+
+ - img_path
+
+ Modified Keys:
+
+ - img
+ - img_shape
+ - ori_shape
+
+ Args:
+ to_float32 (bool): Whether to convert the loaded image to a float32
+ numpy array. If set to False, the loaded image is a float64 array.
+ Defaults to True.
+ """
+
+ def __init__(self, to_float32: bool = True):
+ self.to_float32 = to_float32
+
+ if gdal is None:
+ raise RuntimeError('gdal is not installed')
+
+ def transform(self, results: Dict) -> Dict:
+ """Functions to load image.
+
+ Args:
+ results (dict): Result dict from :obj:``mmcv.BaseDataset``.
+
+ Returns:
+ dict: The dict contains loaded image and meta information.
+ """
+
+ filename = results['img_path']
+ ds = gdal.Open(filename)
+ if ds is None:
+            raise Exception(f'Unable to open file: {filename}')
+ img = np.einsum('ijk->jki', ds.ReadAsArray())
+
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ results['img'] = img
+ results['img_shape'] = img.shape[:2]
+ results['ori_shape'] = img.shape[:2]
+ return results
+
+ def __repr__(self):
+ repr_str = (f'{self.__class__.__name__}('
+ f'to_float32={self.to_float32})')
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class LoadMultipleRSImageFromFile(BaseTransform):
+    """Load two remote sensing images from file.
+
+ Required Keys:
+
+ - img_path
+ - img_path2
+
+ Modified Keys:
+
+ - img
+ - img2
+ - img_shape
+ - ori_shape
+
+ Args:
+ to_float32 (bool): Whether to convert the loaded image to a float32
+ numpy array. If set to False, the loaded image is a float64 array.
+ Defaults to True.
+ """
+
+ def __init__(self, to_float32: bool = True):
+ if gdal is None:
+ raise RuntimeError('gdal is not installed')
+ self.to_float32 = to_float32
+
+ def transform(self, results: Dict) -> Dict:
+ """Functions to load image.
+
+ Args:
+ results (dict): Result dict from :obj:``mmcv.BaseDataset``.
+
+ Returns:
+ dict: The dict contains loaded image and meta information.
+ """
+
+ filename = results['img_path']
+ filename2 = results['img_path2']
+
+ ds = gdal.Open(filename)
+ ds2 = gdal.Open(filename2)
+
+ if ds is None:
+            raise Exception(f'Unable to open file: {filename}')
+ if ds2 is None:
+ raise Exception(f'Unable to open file: {filename2}')
+
+ img = np.einsum('ijk->jki', ds.ReadAsArray())
+ img2 = np.einsum('ijk->jki', ds2.ReadAsArray())
+
+ if self.to_float32:
+ img = img.astype(np.float32)
+ img2 = img2.astype(np.float32)
+
+ if img.shape != img2.shape:
+ raise Exception(f'Image shapes do not match:'
+ f' {img.shape} vs {img2.shape}')
+
+ results['img'] = img
+ results['img2'] = img2
+ results['img_shape'] = img.shape[:2]
+ results['ori_shape'] = img.shape[:2]
+ return results
+
+ def __repr__(self):
+ repr_str = (f'{self.__class__.__name__}('
+ f'to_float32={self.to_float32})')
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class LoadDepthAnnotation(BaseTransform):
+ """Load ``depth_map`` annotation provided by depth estimation dataset.
+
+ The annotation format is as the following:
+
+ .. code-block:: python
+
+ {
+ 'gt_depth_map': np.ndarray [Y, X]
+ }
+
+ Required Keys:
+
+    - depth_map_path
+
+ Added Keys:
+
+ - gt_depth_map (np.ndarray): Depth map with shape (Y, X) by
+ default, and data type is float32 if set to_float32 = True.
+ - depth_rescale_factor (float): The rescale factor of depth map, which
+ can be used to recover the original value of depth map.
+
+ Args:
+ decode_backend (str): The data decoding backend type. Options are
+ 'numpy', 'nifti', and 'cv2'. Defaults to 'cv2'.
+ to_float32 (bool): Whether to convert the loaded depth map to a float32
+            numpy array. If set to False, the loaded depth map is a uint16
+            array. Defaults to True.
+ depth_rescale_factor (float): Factor to rescale the depth value to
+ limit the range. Defaults to 1.0.
+ backend_args (dict, Optional): Arguments to instantiate a file backend.
+ See :class:`mmengine.fileio` for details.
+ Defaults to None.
+ Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
+ """
+
+ def __init__(self,
+ decode_backend: str = 'cv2',
+ to_float32: bool = True,
+ depth_rescale_factor: float = 1.0,
+ backend_args: Optional[dict] = None) -> None:
+ super().__init__()
+ self.decode_backend = decode_backend
+ self.to_float32 = to_float32
+ self.depth_rescale_factor = depth_rescale_factor
+ self.backend_args = backend_args.copy() if backend_args else None
+
+ def transform(self, results: Dict) -> Dict:
+ """Functions to load depth map.
+
+ Args:
+ results (dict): Result dict from :obj:``mmcv.BaseDataset``.
+
+ Returns:
+ dict: The dict contains loaded depth map.
+ """
+ data_bytes = fileio.get(results['depth_map_path'], self.backend_args)
+ gt_depth_map = datafrombytes(data_bytes, backend=self.decode_backend)
+
+ if self.to_float32:
+ gt_depth_map = gt_depth_map.astype(np.float32)
+
+ gt_depth_map *= self.depth_rescale_factor
+ results['gt_depth_map'] = gt_depth_map
+ results['seg_fields'].append('gt_depth_map')
+ results['depth_rescale_factor'] = self.depth_rescale_factor
+ return results
+
+ def __repr__(self):
+ repr_str = (f'{self.__class__.__name__}('
+ f"decode_backend='{self.decode_backend}', "
+ f'to_float32={self.to_float32}, '
+ f'backend_args={self.backend_args})')
+ return repr_str
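``LoadDepthAnnotation`` above stores ``depth_rescale_factor`` next to the scaled map so the metric depth can be recovered downstream. A simplified sketch with plain lists standing in for numpy arrays (factor and values are illustrative):

```python
# Sketch of the rescale bookkeeping in ``LoadDepthAnnotation``: the factor
# is kept in the results dict so metric depth can be recovered later.
def load_depth(raw_values, depth_rescale_factor):
    gt = [float(v) * depth_rescale_factor for v in raw_values]
    return {'gt_depth_map': gt,
            'depth_rescale_factor': depth_rescale_factor}


def recover_metric_depth(results):
    f = results['depth_rescale_factor']
    return [v / f for v in results['gt_depth_map']]


res = load_depth([1000, 2500], 1e-3)  # e.g. millimetres -> metres
print(res['gt_depth_map'])            # approximately [1.0, 2.5]
print(recover_metric_depth(res))      # approximately [1000.0, 2500.0]
```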
diff --git a/mmseg/datasets/transforms/transforms.py b/mmseg/datasets/transforms/transforms.py
index fb7e2a0e665..082ae5b4401 100644
--- a/mmseg/datasets/transforms/transforms.py
+++ b/mmseg/datasets/transforms/transforms.py
@@ -1,11 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
+import inspect
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union
import cv2
import mmcv
+import mmengine
import numpy as np
+from mmcv.transforms import RandomFlip as MMCV_RandomFlip
+from mmcv.transforms import Resize as MMCV_Resize
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmengine.utils import is_tuple_of
@@ -15,6 +19,15 @@
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
from mmseg.registry import TRANSFORMS
+try:
+ import albumentations
+ from albumentations import Compose
+ ALBU_INSTALLED = True
+except ImportError:
+ albumentations = None
+ Compose = None
+ ALBU_INSTALLED = False
+
@TRANSFORMS.register_module()
class ResizeToMultiple(BaseTransform):
@@ -186,7 +199,7 @@ def transform(self, results: dict) -> dict:
def __repr__(self):
repr_str = self.__class__.__name__
- repr_str += f'(clip_limit={self.clip_limit}, '\
+ repr_str += f'(clip_limit={self.clip_limit}, ' \
f'tile_grid_size={self.tile_grid_size})'
return repr_str
@@ -939,6 +952,152 @@ def __repr__(self):
return repr_str
+@TRANSFORMS.register_module()
+class RandomFlip(MMCV_RandomFlip):
+ """Flip the image & bbox & segmentation map. Added or Updated
+ keys: flip, flip_direction, img, gt_bboxes, gt_seg_map, and gt_depth_map.
+ There are 3 flip modes:
+
+ - ``prob`` is float, ``direction`` is string: the image will be
+ ``direction``ly flipped with probability of ``prob`` .
+ E.g., ``prob=0.5``, ``direction='horizontal'``,
+ then image will be horizontally flipped with probability of 0.5.
+
+ - ``prob`` is float, ``direction`` is list of string: the image will
+ be ``direction[i]``ly flipped with probability of
+ ``prob/len(direction)``.
+ E.g., ``prob=0.5``, ``direction=['horizontal', 'vertical']``,
+ then image will be horizontally flipped with probability of 0.25,
+ vertically with probability of 0.25.
+
+ - ``prob`` is list of float, ``direction`` is list of string:
+ given ``len(prob) == len(direction)``, the image will
+ be ``direction[i]``ly flipped with probability of ``prob[i]``.
+ E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal',
+ 'vertical']``, then image will be horizontally flipped with
+ probability of 0.3, vertically with probability of 0.5.
+
+ Required Keys:
+
+ - img
+ - gt_bboxes (optional)
+ - gt_seg_map (optional)
+ - gt_depth_map (optional)
+
+ Modified Keys:
+
+ - img
+ - gt_bboxes (optional)
+ - gt_seg_map (optional)
+ - gt_depth_map (optional)
+
+ Added Keys:
+
+ - flip
+ - flip_direction
+ - swap_seg_labels (optional)
+
+ Args:
+ prob (float | list[float], optional): The flipping probability.
+ Defaults to None.
+        direction (str | list[str]): The flipping direction. Options are
+            'horizontal' and 'vertical'. If the input is a list, its
+            length must equal that of ``prob``. Each element in ``prob``
+            indicates the flip probability of the corresponding
+            direction. Defaults to 'horizontal'.
+ swap_seg_labels (list, optional): The label pair need to be swapped
+ for ground truth, like 'left arm' and 'right arm' need to be
+ swapped after horizontal flipping. For example, ``[(1, 5)]``,
+ where 1/5 is the label of the left/right arm. Defaults to None.
+ """
+
+ def _flip(self, results: dict) -> None:
+ """Flip images, bounding boxes and semantic segmentation map."""
+ # flip image
+ results['img'] = mmcv.imflip(
+ results['img'], direction=results['flip_direction'])
+
+ img_shape = results['img'].shape[:2]
+
+ # flip bboxes
+ if results.get('gt_bboxes', None) is not None:
+ results['gt_bboxes'] = self._flip_bbox(results['gt_bboxes'],
+ img_shape,
+ results['flip_direction'])
+
+ # flip seg map
+ for key in results.get('seg_fields', []):
+ if results.get(key, None) is not None:
+ results[key] = self._flip_seg_map(
+ results[key], direction=results['flip_direction']).copy()
+ results['swap_seg_labels'] = self.swap_seg_labels
+
+
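The three (``prob``, ``direction``) modes described in the RandomFlip docstring above all reduce to a per-direction probability table. A minimal sketch of that reduction (the helper name is illustrative, not part of mmcv):

```python
from typing import Sequence, Union

def flip_probabilities(prob: Union[float, Sequence[float]],
                       direction: Union[str, Sequence[str]]) -> dict:
    """Expand RandomFlip's (prob, direction) modes into an explicit
    per-direction probability table."""
    if isinstance(direction, str):
        direction = [direction]
    if isinstance(prob, (int, float)):
        # a single float prob is split evenly over all directions
        return {d: prob / len(direction) for d in direction}
    # list of floats: one probability per direction
    assert len(prob) == len(direction)
    return dict(zip(direction, prob))

table = flip_probabilities(0.5, ['horizontal', 'vertical'])
```

With a single direction the first branch degenerates to the plain `prob`, matching mode one in the docstring.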
+@TRANSFORMS.register_module()
+class Resize(MMCV_Resize):
+ """Resize images & seg & depth map.
+
+ This transform resizes the input image according to ``scale`` or
+    ``scale_factor``. The seg map, depth map and other related annotations
+    are then resized with the same scale factor. If ``scale`` and
+    ``scale_factor`` are both set, ``scale`` is used for resizing.
+
+ Required Keys:
+
+ - img
+ - gt_seg_map (optional)
+ - gt_depth_map (optional)
+
+ Modified Keys:
+
+ - img
+ - gt_seg_map
+ - gt_depth_map
+
+ Added Keys:
+
+ - scale
+ - scale_factor
+ - keep_ratio
+
+ Args:
+        scale (int or tuple): Image scales for resizing. Defaults to None.
+ scale_factor (float or tuple[float]): Scale factors for resizing.
+ Defaults to None.
+ keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+ image. Defaults to False.
+ clip_object_border (bool): Whether to clip the objects
+ outside the border of the image. In some dataset like MOT17, the gt
+ bboxes are allowed to cross the border of images. Therefore, we
+ don't need to clip the gt bboxes in these cases. Defaults to True.
+ backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
+            These two backends generate slightly different results. Defaults
+ to 'cv2'.
+ interpolation (str): Interpolation method, accepted values are
+ "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+ backend, "nearest", "bilinear" for 'pillow' backend. Defaults
+ to 'bilinear'.
+ """
+
+ def _resize_seg(self, results: dict) -> None:
+ """Resize semantic segmentation map with ``results['scale']``."""
+ for seg_key in results.get('seg_fields', []):
+ if results.get(seg_key, None) is not None:
+ if self.keep_ratio:
+ gt_seg = mmcv.imrescale(
+ results[seg_key],
+ results['scale'],
+ interpolation='nearest',
+ backend=self.backend)
+ else:
+ gt_seg = mmcv.imresize(
+ results[seg_key],
+ results['scale'],
+ interpolation='nearest',
+ backend=self.backend)
+ results[seg_key] = gt_seg
+
+
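``_resize_seg`` always passes ``interpolation='nearest'`` because any interpolation that averages pixels could produce fractional, meaningless class ids. A self-contained NumPy sketch of nearest-neighbour label resizing illustrating this (the helper is hypothetical, not part of mmcv):

```python
import numpy as np

def nearest_resize(seg: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbour resize for a 2-D label map: every output pixel
    copies an existing label, so no new class ids can appear."""
    h, w = seg.shape
    # map each output index back to a source index (floor mapping)
    rows = (np.arange(out_h) * h / out_h).astype(int)
    cols = (np.arange(out_w) * w / out_w).astype(int)
    return seg[np.ix_(rows, cols)]

seg = np.array([[0, 1], [2, 3]], dtype=np.uint8)
up = nearest_resize(seg, 4, 4)
```

The upsampled map contains only the original four labels, which is exactly the property the transform relies on.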
@TRANSFORMS.register_module()
class RandomMosaic(BaseTransform):
"""Mosaic augmentation. Given 4 images, mosaic transform combines them into
@@ -1151,8 +1310,8 @@ def _mosaic_transform_seg(self, results: dict) -> dict:
x1_c, y1_c, x2_c, y2_c = crop_coord
# crop and paste image
- mosaic_seg[y1_p:y2_p, x1_p:x2_p] = gt_seg_i[y1_c:y2_c,
- x1_c:x2_c]
+ mosaic_seg[y1_p:y2_p, x1_p:x2_p] = \
+ gt_seg_i[y1_c:y2_c, x1_c:x2_c]
results[key] = mosaic_seg
@@ -1760,9 +1919,9 @@ def __repr__(self):
repr_str += f'(prob={self.prob}, '
repr_str += f'prob_per_channel={self.prob_per_channel}, '
repr_str += f'sigma_range={self.sigma_range}, '
- repr_str += 'different_sigma_per_channel='\
+ repr_str += 'different_sigma_per_channel=' \
f'{self.different_sigma_per_channel}, '
- repr_str += 'different_sigma_per_axis='\
+ repr_str += 'different_sigma_per_axis=' \
f'{self.different_sigma_per_axis})'
return repr_str
@@ -2135,3 +2294,221 @@ def __repr__(self):
repr_str += f'(prob={self.prob}, axes={self.axes}, ' \
f'swap_label_pairs={self.swap_label_pairs})'
return repr_str
+
+
+@TRANSFORMS.register_module()
+class Albu(BaseTransform):
+ """Albumentation augmentation. Adds custom transformations from
+    the Albumentations library. Please visit
+    `https://albumentations.readthedocs.io` for more information. An example
+    of ``transforms`` is as follows:
+
+ .. code-block::
+ [
+ dict(
+ type='ShiftScaleRotate',
+ shift_limit=0.0625,
+ scale_limit=0.0,
+ rotate_limit=0,
+ interpolation=1,
+ p=0.5),
+ dict(
+ type='RandomBrightnessContrast',
+ brightness_limit=[0.1, 0.3],
+ contrast_limit=[0.1, 0.3],
+ p=0.2),
+ dict(type='ChannelShuffle', p=0.1),
+ dict(
+ type='OneOf',
+ transforms=[
+ dict(type='Blur', blur_limit=3, p=1.0),
+ dict(type='MedianBlur', blur_limit=3, p=1.0)
+ ],
+ p=0.1),
+ ]
+ Args:
+ transforms (list[dict]): A list of albu transformations
+ keymap (dict): Contains {'input key':'albumentation-style key'}
+ update_pad_shape (bool): Whether to update padding shape according to \
+ the output shape of the last transform
+ """
+
+ def __init__(self,
+ transforms: List[dict],
+ keymap: Optional[dict] = None,
+ update_pad_shape: bool = False):
+ if not ALBU_INSTALLED:
+ raise ImportError(
+ 'albumentations is not installed, '
+                'we suggest installing albumentations with '
+ '"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa
+ )
+
+ # Args will be modified later, copying it will be safer
+ transforms = copy.deepcopy(transforms)
+
+ self.transforms = transforms
+ self.keymap = keymap
+ self.update_pad_shape = update_pad_shape
+
+ self.aug = Compose([self.albu_builder(t) for t in self.transforms])
+
+ if not keymap:
+ self.keymap_to_albu = {'img': 'image', 'gt_seg_map': 'mask'}
+ else:
+ self.keymap_to_albu = copy.deepcopy(keymap)
+ self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
+
+ def albu_builder(self, cfg: dict) -> object:
+ """Build a callable object from a dict containing albu arguments.
+
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+
+ Returns:
+ Callable: A callable object.
+ """
+
+ assert isinstance(cfg, dict) and 'type' in cfg
+ args = cfg.copy()
+
+ obj_type = args.pop('type')
+ if mmengine.is_str(obj_type):
+ if not ALBU_INSTALLED:
+ raise ImportError(
+ 'albumentations is not installed, '
+                    'we suggest installing albumentations with '
+ '"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa
+ )
+ obj_cls = getattr(albumentations, obj_type)
+ elif inspect.isclass(obj_type):
+ obj_cls = obj_type
+ else:
+ raise TypeError(
+ f'type must be a valid type or str, but got {type(obj_type)}')
+
+ if 'transforms' in args:
+ args['transforms'] = [
+ self.albu_builder(t) for t in args['transforms']
+ ]
+
+ return obj_cls(**args)
+
+ @staticmethod
+ def mapper(d: dict, keymap: dict):
+ """Dictionary mapper.
+
+ Renames keys according to keymap provided.
+ Args:
+ d (dict): old dict
+ keymap (dict): {'old_key':'new_key'}
+ Returns:
+ dict: new dict.
+ """
+
+ updated_dict = {}
+        for k, v in d.items():
+            new_k = keymap.get(k, k)
+            updated_dict[new_k] = v
+ return updated_dict
+
+ def transform(self, results):
+ # dict to albumentations format
+ results = self.mapper(results, self.keymap_to_albu)
+
+ # Convert to RGB since Albumentations works with RGB images
+ results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_BGR2RGB)
+
+ results = self.aug(**results)
+
+ # Convert back to BGR
+ results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_RGB2BGR)
+
+ # back to the original format
+ results = self.mapper(results, self.keymap_back)
+
+ # update final shape
+ if self.update_pad_shape:
+ results['pad_shape'] = results['img'].shape
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
+ return repr_str
+
+
+@TRANSFORMS.register_module()
+class ConcatCDInput(BaseTransform):
+ """Concat images for change detection.
+
+ Required Keys:
+
+ - img
+ - img2
+
+ Args:
+ input_keys (tuple): Input image keys for change detection.
+ Default: ('img', 'img2').
+ """
+
+ def __init__(self, input_keys=('img', 'img2')):
+ self.input_keys = input_keys
+
+ def transform(self, results: dict) -> dict:
+ img = []
+ for input_key in self.input_keys:
+ img.append(results.pop(input_key))
+ results['img'] = np.concatenate(img, axis=2)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+        repr_str += f'(input_keys={self.input_keys})'
+ return repr_str
+
+
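The core of ``ConcatCDInput`` is a pop-and-concatenate along the channel axis, so two 3-channel change-detection frames become one 6-channel input. A quick NumPy sketch of the same operation:

```python
import numpy as np

# Two temporal frames of the same scene, as ConcatCDInput expects them.
results = {
    'img': np.zeros((4, 4, 3), dtype=np.uint8),
    'img2': np.ones((4, 4, 3), dtype=np.uint8),
}

# Pop each input key, then stack along the channel axis (axis=2).
imgs = [results.pop(k) for k in ('img', 'img2')]
results['img'] = np.concatenate(imgs, axis=2)
```

After the transform the individual keys are gone and downstream code sees a single multi-channel image.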
+@TRANSFORMS.register_module()
+class RandomDepthMix(BaseTransform):
+ """This class implements the RandomDepthMix transform.
+
+ Args:
+ prob (float): Probability of applying the transformation.
+ Defaults to 0.25.
+ mix_scale_ratio (float): Ratio to scale the mix width.
+ Defaults to 0.75.
+ """
+
+ def __init__(
+ self,
+ prob: float = 0.25,
+ mix_scale_ratio: float = 0.75,
+ ):
+ super().__init__()
+
+ self.prob = prob
+ self.mix_scale_ratio = mix_scale_ratio
+
+ def transform(self, results: dict) -> dict:
+ if random.random() > self.prob:
+ return results
+
+ h, w = results['img_shape'][:2]
+ left = int(w * random.random())
+ width_ratio = self.mix_scale_ratio * random.random()
+ width = int(max(1, (w - left) * width_ratio))
+
+ img = results['img']
+ depth_rescale_factor = results.get('depth_rescale_factor', 1)
+ depth_map = results['gt_depth_map'] / depth_rescale_factor
+
+ if img.ndim == 3:
+ for c in range(img.shape[-1]):
+ img[:, left:left + width, c] = depth_map[:, left:left + width]
+ elif img.ndim == 2:
+ img[:, left:left + width] = depth_map[:, left:left + width]
+ else:
+ raise ValueError(f'Invalid image shape ({img.shape})')
+
+ results['img'] = img
+ return results
diff --git a/mmseg/engine/__init__.py b/mmseg/engine/__init__.py
index ada40570121..98139a0047f 100644
--- a/mmseg/engine/__init__.py
+++ b/mmseg/engine/__init__.py
@@ -1,9 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hooks import SegVisualizationHook
-from .optimizers import (LayerDecayOptimizerConstructor,
+from .optimizers import (ForceDefaultOptimWrapperConstructor,
+ LayerDecayOptimizerConstructor,
LearningRateDecayOptimizerConstructor)
+from .schedulers import PolyLRRatio
__all__ = [
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
- 'SegVisualizationHook'
+ 'SegVisualizationHook', 'PolyLRRatio',
+ 'ForceDefaultOptimWrapperConstructor'
]
diff --git a/mmseg/engine/optimizers/__init__.py b/mmseg/engine/optimizers/__init__.py
index 4fbf4ecfcd4..e4cf58741fe 100644
--- a/mmseg/engine/optimizers/__init__.py
+++ b/mmseg/engine/optimizers/__init__.py
@@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .force_default_constructor import ForceDefaultOptimWrapperConstructor
from .layer_decay_optimizer_constructor import (
LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor)
__all__ = [
- 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor'
+ 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
+ 'ForceDefaultOptimWrapperConstructor'
]
diff --git a/mmseg/engine/optimizers/force_default_constructor.py b/mmseg/engine/optimizers/force_default_constructor.py
new file mode 100644
index 00000000000..12c642ad411
--- /dev/null
+++ b/mmseg/engine/optimizers/force_default_constructor.py
@@ -0,0 +1,255 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+from mmengine.logging import print_log
+from mmengine.optim import DefaultOptimWrapperConstructor
+from mmengine.utils.dl_utils import mmcv_full_available
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
+from torch.nn import GroupNorm, LayerNorm
+
+from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
+
+
+@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
+class ForceDefaultOptimWrapperConstructor(DefaultOptimWrapperConstructor):
+ """Default constructor with forced optimizer settings.
+
+ This constructor extends the default constructor to add an option for
+ forcing default optimizer settings. This is useful for ensuring that
+ certain parameters or layers strictly adhere to pre-defined default
+ settings, regardless of any custom settings specified.
+
+    By default, each parameter shares the same optimizer settings, and we
+ provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
+ It is a dict and may contain various fields like 'custom_keys',
+ 'bias_lr_mult', etc., as well as the additional field
+ `force_default_settings` which allows for enforcing default settings on
+ optimizer parameters.
+
+ - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
+ one of the keys in ``custom_keys`` is a substring of the name of one
+ parameter, then the setting of the parameter will be specified by
+ ``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will
+ be ignored. It should be noted that the aforementioned ``key`` is the
+ longest key that is a substring of the name of the parameter. If there
+ are multiple matched keys with the same length, then the key with lower
+ alphabet order will be chosen.
+ ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
+ and ``decay_mult``. See Example 2 below.
+ - ``bias_lr_mult`` (float): It will be multiplied to the learning
+ rate for all bias parameters (except for those in normalization
+ layers and offset layers of DCN).
+ - ``bias_decay_mult`` (float): It will be multiplied to the weight
+ decay for all bias parameters (except for those in
+ normalization layers, depthwise conv layers, offset layers of DCN).
+ - ``norm_decay_mult`` (float): It will be multiplied to the weight
+ decay for all weight and bias parameters of normalization
+ layers.
+ - ``flat_decay_mult`` (float): It will be multiplied to the weight
+ decay for all one-dimensional parameters
+ - ``dwconv_decay_mult`` (float): It will be multiplied to the weight
+ decay for all weight and bias parameters of depthwise conv
+ layers.
+ - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning
+ rate for parameters of offset layer in the deformable convs
+ of a model.
+ - ``bypass_duplicate`` (bool): If true, the duplicate parameters
+ would not be added into optimizer. Defaults to False.
+ - ``force_default_settings`` (bool): If true, this will override any
+ custom settings defined by ``custom_keys`` and enforce the use of
+ default settings for optimizer parameters like ``bias_lr_mult``.
+ This is particularly useful when you want to ensure that certain layers
+ or parameters adhere strictly to the pre-defined default settings.
+
+ Note:
+
+ 1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
+ override the effect of ``bias_lr_mult`` in the bias of offset layer.
+ So be careful when using both ``bias_lr_mult`` and
+ ``dcn_offset_lr_mult``. If you wish to apply both of them to the offset
+ layer in deformable convs, set ``dcn_offset_lr_mult`` to the original
+ ``dcn_offset_lr_mult`` * ``bias_lr_mult``.
+
+ 2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
+ apply it to all the DCN layers in the model. So be careful when the
+ model contains multiple DCN layers in places other than backbone.
+
+ 3. When the option ``force_default_settings`` is true, it will override
+ any custom settings provided in ``custom_keys``. This ensures that the
+ default settings for the optimizer parameters are used.
+
+ Args:
+ optim_wrapper_cfg (dict): The config dict of the optimizer wrapper.
+
+ Required fields of ``optim_wrapper_cfg`` are
+
+ - ``type``: class name of the OptimizerWrapper
+ - ``optimizer``: The configuration of optimizer.
+
+ Optional fields of ``optim_wrapper_cfg`` are
+
+ - any arguments of the corresponding optimizer wrapper type,
+ e.g., accumulative_counts, clip_grad, etc.
+
+ Required fields of ``optimizer`` are
+
+ - `type`: class name of the optimizer.
+
+ Optional fields of ``optimizer`` are
+
+ - any arguments of the corresponding optimizer type, e.g.,
+ lr, weight_decay, momentum, etc.
+
+ paramwise_cfg (dict, optional): Parameter-wise options.
+
+ Example 1:
+ >>> model = torch.nn.modules.Conv1d(1, 1, 1)
+ >>> optim_wrapper_cfg = dict(
+ >>> dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01,
+ >>> momentum=0.9, weight_decay=0.0001))
+ >>> paramwise_cfg = dict(norm_decay_mult=0.)
+ >>> optim_wrapper_builder = DefaultOptimWrapperConstructor(
+ >>> optim_wrapper_cfg, paramwise_cfg)
+ >>> optim_wrapper = optim_wrapper_builder(model)
+
+ Example 2:
+ >>> # assume model have attribute model.backbone and model.cls_head
+ >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
+ >>> type='SGD', lr=0.01, weight_decay=0.95))
+ >>> paramwise_cfg = dict(custom_keys={
+ >>> 'backbone': dict(lr_mult=0.1, decay_mult=0.9)})
+ >>> optim_wrapper_builder = DefaultOptimWrapperConstructor(
+ >>> optim_wrapper_cfg, paramwise_cfg)
+ >>> optim_wrapper = optim_wrapper_builder(model)
+ >>> # Then the `lr` and `weight_decay` for model.backbone is
+ >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
+ >>> # model.cls_head is (0.01, 0.95).
+ """
+
+ def add_params(self,
+ params: List[dict],
+ module: nn.Module,
+ prefix: str = '',
+ is_dcn_module: Optional[Union[int, float]] = None) -> None:
+ """Add all parameters of module to the params list.
+
+ The parameters of the given module will be added to the list of param
+ groups, with specific rules defined by paramwise_cfg.
+
+ Args:
+ params (list[dict]): A list of param groups, it will be modified
+ in place.
+ module (nn.Module): The module to be added.
+ prefix (str): The prefix of the module
+ is_dcn_module (int|float|None): If the current module is a
+ submodule of DCN, `is_dcn_module` will be passed to
+ control conv_offset layer's learning rate. Defaults to None.
+ """
+ # get param-wise options
+ custom_keys = self.paramwise_cfg.get('custom_keys', {})
+ # first sort with alphabet order and then sort with reversed len of str
+ sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
+
+ bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', None)
+ bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
+ norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
+ dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', None)
+ flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
+ bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
+ dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', None)
+ force_default_settings = self.paramwise_cfg.get(
+ 'force_default_settings', False)
+
+ # special rules for norm layers and depth-wise conv layers
+ is_norm = isinstance(module,
+ (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
+ is_dwconv = (
+ isinstance(module, torch.nn.Conv2d)
+ and module.in_channels == module.groups)
+
+ for name, param in module.named_parameters(recurse=False):
+ param_group = {'params': [param]}
+ if bypass_duplicate and self._is_in(param_group, params):
+ print_log(
+ f'{prefix} is duplicate. It is skipped since '
+ f'bypass_duplicate={bypass_duplicate}',
+ logger='current',
+ level=logging.WARNING)
+ continue
+ if not param.requires_grad:
+ params.append(param_group)
+ continue
+
+ # if the parameter match one of the custom keys, ignore other rules
+ is_custom = False
+ for key in sorted_keys:
+ if key in f'{prefix}.{name}':
+ is_custom = True
+ lr_mult = custom_keys[key].get('lr_mult', 1.)
+ param_group['lr'] = self.base_lr * lr_mult
+ if self.base_wd is not None:
+ decay_mult = custom_keys[key].get('decay_mult', 1.)
+ param_group['weight_decay'] = self.base_wd * decay_mult
+ # add custom settings to param_group
+ for k, v in custom_keys[key].items():
+ param_group[k] = v
+ break
+
+ if not is_custom or force_default_settings:
+ # bias_lr_mult affects all bias parameters
+ # except for norm.bias dcn.conv_offset.bias
+ if name == 'bias' and not (
+ is_norm or is_dcn_module) and bias_lr_mult is not None:
+ param_group['lr'] = self.base_lr * bias_lr_mult
+
+ if (prefix.find('conv_offset') != -1 and is_dcn_module
+ and dcn_offset_lr_mult is not None
+ and isinstance(module, torch.nn.Conv2d)):
+ # deal with both dcn_offset's bias & weight
+ param_group['lr'] = self.base_lr * dcn_offset_lr_mult
+
+ # apply weight decay policies
+ if self.base_wd is not None:
+ # norm decay
+ if is_norm and norm_decay_mult is not None:
+ param_group[
+ 'weight_decay'] = self.base_wd * norm_decay_mult
+ # bias lr and decay
+ elif (name == 'bias' and not is_dcn_module
+ and bias_decay_mult is not None):
+ param_group[
+ 'weight_decay'] = self.base_wd * bias_decay_mult
+ # depth-wise conv
+ elif is_dwconv and dwconv_decay_mult is not None:
+ param_group[
+ 'weight_decay'] = self.base_wd * dwconv_decay_mult
+ # flatten parameters except dcn offset
+ elif (param.ndim == 1 and not is_dcn_module
+ and flat_decay_mult is not None):
+ param_group[
+ 'weight_decay'] = self.base_wd * flat_decay_mult
+ params.append(param_group)
+ for key, value in param_group.items():
+ if key == 'params':
+ continue
+ full_name = f'{prefix}.{name}' if prefix else name
+ print_log(
+ f'paramwise_options -- {full_name}:{key}={value}',
+ logger='current')
+
+ if mmcv_full_available():
+ from mmcv.ops import DeformConv2d, ModulatedDeformConv2d
+ is_dcn_module = isinstance(module,
+ (DeformConv2d, ModulatedDeformConv2d))
+ else:
+ is_dcn_module = False
+ for child_name, child_mod in module.named_children():
+ child_prefix = f'{prefix}.{child_name}' if prefix else child_name
+ self.add_params(
+ params,
+ child_mod,
+ prefix=child_prefix,
+ is_dcn_module=is_dcn_module)
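The ``custom_keys`` matching rule documented above (the longest key that is a substring of the parameter name wins, with ties broken alphabetically) can be isolated into a small helper. A sketch using hypothetical parameter names:

```python
def match_custom_key(param_name: str, custom_keys: dict):
    """Reproduce the key-matching rule from add_params: keys are sorted
    alphabetically, then by descending length, and the first key that is
    a substring of the parameter name wins."""
    sorted_keys = sorted(sorted(custom_keys), key=len, reverse=True)
    for key in sorted_keys:
        if key in param_name:
            return key
    return None

custom_keys = {
    'backbone': dict(lr_mult=0.1),
    'backbone.patch_embed': dict(lr_mult=0.0),
}
key = match_custom_key('backbone.patch_embed.proj.weight', custom_keys)
```

Here the longer key `'backbone.patch_embed'` shadows the shorter `'backbone'` for patch-embed parameters, while every other backbone parameter still falls back to `'backbone'`.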
diff --git a/mmseg/engine/schedulers/__init__.py b/mmseg/engine/schedulers/__init__.py
new file mode 100644
index 00000000000..3cd3f621134
--- /dev/null
+++ b/mmseg/engine/schedulers/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .poly_ratio_scheduler import PolyLRRatio
+
+__all__ = ['PolyLRRatio']
diff --git a/mmseg/engine/schedulers/poly_ratio_scheduler.py b/mmseg/engine/schedulers/poly_ratio_scheduler.py
new file mode 100644
index 00000000000..057203acc9c
--- /dev/null
+++ b/mmseg/engine/schedulers/poly_ratio_scheduler.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+from mmengine.optim.scheduler import PolyLR
+
+from mmseg.registry import PARAM_SCHEDULERS
+
+
+@PARAM_SCHEDULERS.register_module()
+class PolyLRRatio(PolyLR):
+ """Implements polynomial learning rate decay with ratio.
+
+ This scheduler adjusts the learning rate of each parameter group
+ following a polynomial decay equation. The decay can occur in
+ conjunction with external parameter adjustments made outside this
+ scheduler.
+
+ Args:
+ optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
+ eta_min (float): Minimum learning rate at the end of scheduling.
+ Defaults to 0.
+ eta_min_ratio (float, optional): The ratio of the minimum parameter
+ value to the base parameter value. Either `eta_min` or
+ `eta_min_ratio` should be specified. Defaults to None.
+ power (float): The power of the polynomial. Defaults to 1.0.
+ begin (int): Step at which to start updating the parameters.
+ Defaults to 0.
+ end (int): Step at which to stop updating the parameters.
+ Defaults to INF.
+ last_step (int): The index of last step. Used for resume without
+ state dict. Defaults to -1.
+ by_epoch (bool): Whether the scheduled parameters are updated by
+ epochs. Defaults to True.
+ verbose (bool): Whether to print the value for each update.
+ Defaults to False.
+ """
+
+    def __init__(self, eta_min_ratio: Optional[float] = None, *args,
+                 **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.eta_min_ratio = eta_min_ratio
+
+ def _get_value(self):
+ """Compute value using chainable form of the scheduler."""
+
+ if self.last_step == 0:
+ return [
+ group[self.param_name] for group in self.optimizer.param_groups
+ ]
+
+ param_groups_value = []
+ for base_value, param_group in zip(self.base_values,
+ self.optimizer.param_groups):
+ eta_min = self.eta_min if self.eta_min_ratio is None else \
+ base_value * self.eta_min_ratio
+ step_ratio = (1 - 1 /
+ (self.total_iters - self.last_step + 1))**self.power
+ step_value = (param_group[self.param_name] -
+ eta_min) * step_ratio + eta_min
+ param_groups_value.append(step_value)
+
+ return param_groups_value
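Because the update is chainable, the value at each step is obtained by shrinking the previous value toward ``eta_min`` by the step ratio. A plain-Python replay of the update rule (the helper name is illustrative):

```python
def poly_lr_steps(base_lr: float, total_iters: int, power: float = 1.0,
                  eta_min: float = 0.0):
    """Replay PolyLRRatio's chainable update: at step t the current value
    is pulled toward eta_min by (1 - 1/(total_iters - t + 1)) ** power."""
    value = base_lr
    values = []
    for t in range(1, total_iters + 1):
        ratio = (1 - 1 / (total_iters - t + 1)) ** power
        value = (value - eta_min) * ratio + eta_min
        values.append(value)
    return values

lrs = poly_lr_steps(1.0, total_iters=4)
```

With ``power=1`` the chainable form reduces to plain linear decay from the base value down to ``eta_min``.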
diff --git a/mmseg/evaluation/__init__.py b/mmseg/evaluation/__init__.py
index a82008f3ad3..82b3a8d68d3 100644
--- a/mmseg/evaluation/__init__.py
+++ b/mmseg/evaluation/__init__.py
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from .metrics import CityscapesMetric, IoUMetric
+from .metrics import CityscapesMetric, DepthMetric, IoUMetric
-__all__ = ['IoUMetric', 'CityscapesMetric']
+__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric']
diff --git a/mmseg/evaluation/metrics/__init__.py b/mmseg/evaluation/metrics/__init__.py
index 0aa39e480cd..848d4713dc8 100644
--- a/mmseg/evaluation/metrics/__init__.py
+++ b/mmseg/evaluation/metrics/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .citys_metric import CityscapesMetric
+from .depth_metric import DepthMetric
from .iou_metric import IoUMetric
-__all__ = ['IoUMetric', 'CityscapesMetric']
+__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric']
diff --git a/mmseg/evaluation/metrics/depth_metric.py b/mmseg/evaluation/metrics/depth_metric.py
new file mode 100644
index 00000000000..621d4a31c9f
--- /dev/null
+++ b/mmseg/evaluation/metrics/depth_metric.py
@@ -0,0 +1,212 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from collections import OrderedDict, defaultdict
+from typing import Dict, List, Optional, Sequence
+
+import cv2
+import numpy as np
+import torch
+from mmengine.dist import is_main_process
+from mmengine.evaluator import BaseMetric
+from mmengine.logging import MMLogger, print_log
+from mmengine.utils import mkdir_or_exist
+from prettytable import PrettyTable
+from torch import Tensor
+
+from mmseg.registry import METRICS
+
+
+@METRICS.register_module()
+class DepthMetric(BaseMetric):
+ """Depth estimation evaluation metric.
+
+ Args:
+ depth_metrics (List[str], optional): List of metrics to compute. If
+ not specified, defaults to all metrics in self.METRICS.
+ min_depth_eval (float): Minimum depth value for evaluation.
+ Defaults to 0.0.
+ max_depth_eval (float): Maximum depth value for evaluation.
+ Defaults to infinity.
+ crop_type (str, optional): Specifies the type of cropping to be used
+ during evaluation. This option can affect how the evaluation mask
+ is generated. Currently, 'nyu_crop' is supported, but other
+ types can be added in future. Defaults to None if no cropping
+ should be applied.
+ depth_scale_factor (float): Factor to scale the depth values.
+ Defaults to 1.0.
+ collect_device (str): Device name used for collecting results from
+ different ranks during distributed training. Must be 'cpu' or
+ 'gpu'. Defaults to 'cpu'.
+ output_dir (str): The directory for output prediction. Defaults to
+ None.
+        format_only (bool): Only format the results without performing
+            evaluation. It is useful when you want to save the results
+            in a specific format and submit them to the test server.
+            Defaults to False.
+ prefix (str, optional): The prefix that will be added in the metric
+ names to disambiguate homonymous metrics of different evaluators.
+ If prefix is not provided in the argument, self.default_prefix
+ will be used instead. Defaults to None.
+ """
+ METRICS = ('d1', 'd2', 'd3', 'abs_rel', 'sq_rel', 'rmse', 'rmse_log',
+ 'log10', 'silog')
+
+ def __init__(self,
+ depth_metrics: Optional[List[str]] = None,
+ min_depth_eval: float = 0.0,
+ max_depth_eval: float = float('inf'),
+ crop_type: Optional[str] = None,
+ depth_scale_factor: float = 1.0,
+ collect_device: str = 'cpu',
+ output_dir: Optional[str] = None,
+ format_only: bool = False,
+ prefix: Optional[str] = None,
+ **kwargs) -> None:
+ super().__init__(collect_device=collect_device, prefix=prefix)
+
+ if depth_metrics is None:
+ self.metrics = self.METRICS
+        elif isinstance(depth_metrics, (tuple, list)):
+ for metric in depth_metrics:
+ assert metric in self.METRICS, f'the metric {metric} is not ' \
+ f'supported. Please use metrics in {self.METRICS}'
+ self.metrics = depth_metrics
+
+ # Validate crop_type, if provided
+ assert crop_type in [
+ None, 'nyu_crop'
+ ], (f'Invalid value for crop_type: {crop_type}. Supported values are '
+ 'None or \'nyu_crop\'.')
+ self.crop_type = crop_type
+ self.min_depth_eval = min_depth_eval
+ self.max_depth_eval = max_depth_eval
+ self.output_dir = output_dir
+ if self.output_dir and is_main_process():
+ mkdir_or_exist(self.output_dir)
+ self.format_only = format_only
+ self.depth_scale_factor = depth_scale_factor
+
+ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
+ """Process one batch of data and data_samples.
+
+ The processed results should be stored in ``self.results``, which will
+ be used to compute the metrics when all batches have been processed.
+
+ Args:
+ data_batch (dict): A batch of data from the dataloader.
+ data_samples (Sequence[dict]): A batch of outputs from the model.
+ """
+ for data_sample in data_samples:
+ pred_label = data_sample['pred_depth_map']['data'].squeeze()
+ # format_only always for test dataset without ground truth
+ if not self.format_only:
+ gt_depth = data_sample['gt_depth_map']['data'].squeeze().to(
+ pred_label)
+
+ eval_mask = self._get_eval_mask(gt_depth)
+ self.results.append(
+ (gt_depth[eval_mask], pred_label[eval_mask]))
+ # format_result
+ if self.output_dir is not None:
+ basename = osp.splitext(osp.basename(
+ data_sample['img_path']))[0]
+ png_filename = osp.abspath(
+ osp.join(self.output_dir, f'{basename}.png'))
+ output_mask = pred_label.cpu().numpy(
+ ) * self.depth_scale_factor
+
+ cv2.imwrite(png_filename, output_mask.astype(np.uint16),
+ [cv2.IMWRITE_PNG_COMPRESSION, 0])
+
+ def _get_eval_mask(self, gt_depth: Tensor):
+ """Generates an evaluation mask based on ground truth depth and
+ cropping.
+
+ Args:
+ gt_depth (Tensor): Ground truth depth map.
+
+ Returns:
+ Tensor: Boolean mask where evaluation should be performed.
+ """
+ valid_mask = torch.logical_and(gt_depth > self.min_depth_eval,
+ gt_depth < self.max_depth_eval)
+
+ if self.crop_type == 'nyu_crop':
+ # this implementation is adapted from
+ # https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox/blob/main/depth/datasets/nyu.py # noqa
+ crop_mask = torch.zeros_like(valid_mask)
+ crop_mask[45:471, 41:601] = 1
+ else:
+ crop_mask = torch.ones_like(valid_mask)
+
+ eval_mask = torch.logical_and(valid_mask, crop_mask)
+ return eval_mask
+
+ @staticmethod
+ def _calc_all_metrics(gt_depth, pred_depth):
+ """Computes final evaluation metrics based on accumulated results."""
+ assert gt_depth.shape == pred_depth.shape
+
+ thresh = torch.max((gt_depth / pred_depth), (pred_depth / gt_depth))
+ diff = pred_depth - gt_depth
+ diff_log = torch.log(pred_depth) - torch.log(gt_depth)
+
+ d1 = torch.sum(thresh < 1.25).float() / len(thresh)
+ d2 = torch.sum(thresh < 1.25**2).float() / len(thresh)
+ d3 = torch.sum(thresh < 1.25**3).float() / len(thresh)
+
+ abs_rel = torch.mean(torch.abs(diff) / gt_depth)
+ sq_rel = torch.mean(torch.pow(diff, 2) / gt_depth)
+
+ rmse = torch.sqrt(torch.mean(torch.pow(diff, 2)))
+ rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log, 2)))
+
+ log10 = torch.mean(
+ torch.abs(torch.log10(pred_depth) - torch.log10(gt_depth)))
+ silog = torch.sqrt(
+ torch.pow(diff_log, 2).mean() -
+ 0.5 * torch.pow(diff_log.mean(), 2))
+
+ return {
+ 'd1': d1.item(),
+ 'd2': d2.item(),
+ 'd3': d3.item(),
+ 'abs_rel': abs_rel.item(),
+ 'sq_rel': sq_rel.item(),
+ 'rmse': rmse.item(),
+ 'rmse_log': rmse_log.item(),
+ 'log10': log10.item(),
+ 'silog': silog.item()
+ }
+
+ def compute_metrics(self, results: list) -> Dict[str, float]:
+ """Compute the metrics from processed results.
+
+ Args:
+ results (list): The processed results of each batch.
+
+ Returns:
+ Dict[str, float]: The computed metrics. The keys are the names of
+ the metrics, and the values are corresponding results. The keys
+ are identical with self.metrics.
+ """
+ logger: MMLogger = MMLogger.get_current_instance()
+ if self.format_only:
+ logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
+ return OrderedDict()
+
+ metrics = defaultdict(list)
+ for gt_depth, pred_depth in results:
+ for key, value in self._calc_all_metrics(gt_depth,
+ pred_depth).items():
+ metrics[key].append(value)
+ metrics = {k: sum(metrics[k]) / len(metrics[k]) for k in self.metrics}
+
+ table_data = PrettyTable()
+ for key, val in metrics.items():
+ table_data.add_column(key, [round(val, 5)])
+
+ print_log('results:', logger)
+ print_log('\n' + table_data.get_string(), logger=logger)
+
+ return metrics
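As a sanity check on the formulas in `_calc_all_metrics` — not part of this patch — here is a plain-Python sketch of a few of the metrics (stdlib only, no torch; names are illustrative):

```python
import math

def depth_metrics(gt, pred):
    """Plain-Python versions of d1, abs_rel and rmse from above."""
    # d1: fraction of pixels whose ratio to ground truth is within 1.25
    thresh = [max(g / p, p / g) for g, p in zip(gt, pred)]
    d1 = sum(t < 1.25 for t in thresh) / len(thresh)
    # abs_rel: mean absolute error relative to ground-truth depth
    abs_rel = sum(abs(p - g) / g for g, p in zip(gt, pred)) / len(gt)
    # rmse: root of the mean squared error
    rmse = math.sqrt(sum((p - g) ** 2 for g, p in zip(gt, pred)) / len(gt))
    return {'d1': d1, 'abs_rel': abs_rel, 'rmse': rmse}

m = depth_metrics([1.0, 2.0, 4.0], [1.1, 2.0, 5.0])
```

The torch version applies the same formulas to the masked 1-D tensors collected in `process`; `compute_metrics` then averages the per-sample values.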
diff --git a/mmseg/models/__init__.py b/mmseg/models/__init__.py
index 7a520fb2fa4..a98951283c1 100644
--- a/mmseg/models/__init__.py
+++ b/mmseg/models/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .assigners import * # noqa: F401,F403
from .backbones import * # noqa: F401,F403
from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
build_head, build_loss, build_segmentor)
@@ -7,6 +8,7 @@
from .losses import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .segmentors import * # noqa: F401,F403
+from .text_encoder import * # noqa: F401,F403
__all__ = [
'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
diff --git a/mmseg/models/assigners/__init__.py b/mmseg/models/assigners/__init__.py
new file mode 100644
index 00000000000..d49b1b18b9e
--- /dev/null
+++ b/mmseg/models/assigners/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_assigner import BaseAssigner
+from .hungarian_assigner import HungarianAssigner
+from .match_cost import ClassificationCost, CrossEntropyLossCost, DiceCost
+
+__all__ = [
+ 'BaseAssigner',
+ 'HungarianAssigner',
+ 'ClassificationCost',
+ 'CrossEntropyLossCost',
+ 'DiceCost',
+]
diff --git a/mmseg/models/assigners/base_assigner.py b/mmseg/models/assigners/base_assigner.py
new file mode 100644
index 00000000000..97895cdac27
--- /dev/null
+++ b/mmseg/models/assigners/base_assigner.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+from typing import Optional
+
+from mmengine.structures import InstanceData
+
+
+class BaseAssigner(metaclass=ABCMeta):
+ """Base assigner that assigns masks to ground truth class labels."""
+
+ @abstractmethod
+ def assign(self,
+ pred_instances: InstanceData,
+ gt_instances: InstanceData,
+ gt_instances_ignore: Optional[InstanceData] = None,
+ **kwargs):
+ """Assign masks to either a ground truth class label or a negative
+ label."""
diff --git a/mmseg/models/assigners/hungarian_assigner.py b/mmseg/models/assigners/hungarian_assigner.py
new file mode 100644
index 00000000000..28868f0a04e
--- /dev/null
+++ b/mmseg/models/assigners/hungarian_assigner.py
@@ -0,0 +1,86 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Union
+
+import torch
+from mmengine import ConfigDict
+from mmengine.structures import InstanceData
+from scipy.optimize import linear_sum_assignment
+from torch.cuda.amp import autocast
+
+from mmseg.registry import TASK_UTILS
+from .base_assigner import BaseAssigner
+
+
+@TASK_UTILS.register_module()
+class HungarianAssigner(BaseAssigner):
+ """Computes one-to-one matching between prediction masks and ground truth.
+
+    This class uses bipartite matching-based assignment to compute an
+ assignment between the prediction masks and the ground truth. The
+ assignment result is based on the weighted sum of match costs. The
+ Hungarian algorithm is used to calculate the best matching with the
+ minimum cost. The prediction masks that are not matched are classified
+ as background.
+
+ Args:
+ match_costs (ConfigDict|List[ConfigDict]): Match cost configs.
+ """
+
+ def __init__(
+ self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
+ ConfigDict]
+ ) -> None:
+
+ if isinstance(match_costs, dict):
+ match_costs = [match_costs]
+ elif isinstance(match_costs, list):
+ assert len(match_costs) > 0, \
+                'match_costs must not be an empty list.'
+
+ self.match_costs = [
+ TASK_UTILS.build(match_cost) for match_cost in match_costs
+ ]
+
+ def assign(self, pred_instances: InstanceData, gt_instances: InstanceData,
+ **kwargs):
+ """Computes one-to-one matching based on the weighted costs.
+
+        This method assigns each query prediction to a ground truth or
+ background. The assignment first calculates the cost for each
+ category assigned to each query mask, and then uses the
+ Hungarian algorithm to calculate the minimum cost as the best
+ match.
+
+ Args:
+ pred_instances (InstanceData): Instances of model
+ predictions. It includes "masks", with shape
+ (n, h, w) or (n, l), and "cls", with shape (n, num_classes+1)
+ gt_instances (InstanceData): Ground truth of instance
+ annotations. It includes "labels", with shape (k, ),
+ and "masks", with shape (k, h, w) or (k, l).
+
+ Returns:
+            matched_query_inds (Tensor): The indexes of matched queries.
+            matched_label_inds (Tensor): The indexes of matched labels.
+ """
+ # compute weighted cost
+ cost_list = []
+ with autocast(enabled=False):
+ for match_cost in self.match_costs:
+ cost = match_cost(
+ pred_instances=pred_instances, gt_instances=gt_instances)
+ cost_list.append(cost)
+ cost = torch.stack(cost_list).sum(dim=0)
+
+ device = cost.device
+ # do Hungarian matching on CPU using linear_sum_assignment
+ cost = cost.detach().cpu()
+ if linear_sum_assignment is None:
+ raise ImportError('Please run "pip install scipy" '
+ 'to install scipy first.')
+
+        matched_query_inds, matched_label_inds = linear_sum_assignment(cost)
+        matched_query_inds = torch.from_numpy(matched_query_inds).to(device)
+        matched_label_inds = torch.from_numpy(matched_label_inds).to(device)
+
+        return matched_query_inds, matched_label_inds
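To see what `linear_sum_assignment` produces on a concrete cost matrix, here is a brute-force stand-in (stdlib only, illustration; the assigner itself relies on scipy, which solves the same problem in polynomial time):

```python
from itertools import permutations

def min_cost_assignment(cost):
    """Exhaustive one-to-one matching over an (n_preds, n_gts) cost
    matrix; assumes n_gts <= n_preds, as with query-based matching."""
    n_preds, n_gts = len(cost), len(cost[0])
    best, best_pairs = float('inf'), None
    # try every way of assigning each gt column to a distinct prediction row
    for rows in permutations(range(n_preds), n_gts):
        total = sum(cost[r][c] for c, r in enumerate(rows))
        if total < best:
            best = total
            best_pairs = [(r, c) for c, r in enumerate(rows)]
    return best, best_pairs

# 3 query predictions, 2 ground-truth instances
cost = [[4, 1], [2, 3], [5, 0]]
total, pairs = min_cost_assignment(cost)
```

`scipy.optimize.linear_sum_assignment(cost)` returns the same pairing as a pair of row/column index arrays; the unmatched query (row 0 here) is the one the head classifies as background.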
diff --git a/mmseg/models/assigners/match_cost.py b/mmseg/models/assigners/match_cost.py
new file mode 100644
index 00000000000..560df852902
--- /dev/null
+++ b/mmseg/models/assigners/match_cost.py
@@ -0,0 +1,231 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import abstractmethod
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmseg.registry import TASK_UTILS
+
+
+class BaseMatchCost:
+ """Base match cost class.
+
+ Args:
+ weight (Union[float, int]): Cost weight. Defaults to 1.
+ """
+
+ def __init__(self, weight: Union[float, int] = 1.) -> None:
+ self.weight = weight
+
+ @abstractmethod
+ def __call__(self, pred_instances: InstanceData,
+ gt_instances: InstanceData, **kwargs) -> Tensor:
+ """Compute match cost.
+
+ Args:
+ pred_instances (InstanceData): Instances of model predictions.
+ It often includes "labels" and "scores".
+ gt_instances (InstanceData): Ground truth of instance
+ annotations. It usually includes "labels".
+
+ Returns:
+ Tensor: Match Cost matrix of shape (num_preds, num_gts).
+ """
+ pass
+
+
+@TASK_UTILS.register_module()
+class ClassificationCost(BaseMatchCost):
+ """ClsSoftmaxCost.
+
+ Args:
+ weight (Union[float, int]): Cost weight. Defaults to 1.
+
+ Examples:
+ >>> from mmseg.models.assigners import ClassificationCost
+ >>> import torch
+ >>> self = ClassificationCost()
+ >>> cls_pred = torch.rand(4, 3)
+ >>> gt_labels = torch.tensor([0, 1, 2])
+ >>> self(cls_pred, gt_labels)
+ tensor([[-0.3430, -0.3525, -0.3045],
+ [-0.3077, -0.2931, -0.3992],
+ [-0.3664, -0.3455, -0.2881],
+ [-0.3343, -0.2701, -0.3956]])
+ """
+
+ def __init__(self, weight: Union[float, int] = 1) -> None:
+ super().__init__(weight=weight)
+
+ def __call__(self, pred_instances: InstanceData,
+ gt_instances: InstanceData, **kwargs) -> Tensor:
+ """Compute match cost.
+
+ Args:
+ pred_instances (InstanceData): "scores" inside is
+ predicted classification logits, of shape
+ (num_queries, num_class).
+ gt_instances (InstanceData): "labels" inside should have
+ shape (num_gt, ).
+
+ Returns:
+ Tensor: Match Cost matrix of shape (num_preds, num_gts).
+ """
+ assert hasattr(pred_instances, 'scores'), \
+ "pred_instances must contain 'scores'"
+ assert hasattr(gt_instances, 'labels'), \
+ "gt_instances must contain 'labels'"
+ pred_scores = pred_instances.scores
+ gt_labels = gt_instances.labels
+
+ pred_scores = pred_scores.softmax(-1)
+ cls_cost = -pred_scores[:, gt_labels]
+
+ return cls_cost * self.weight
+
+
+@TASK_UTILS.register_module()
+class DiceCost(BaseMatchCost):
+ """Cost of mask assignments based on dice losses.
+
+ Args:
+ pred_act (bool): Whether to apply sigmoid to mask_pred.
+ Defaults to False.
+ eps (float): Defaults to 1e-3.
+ naive_dice (bool): If True, use the naive dice loss
+ in which the power of the number in the denominator is
+ the first power. If False, use the second power that
+ is adopted by K-Net and SOLO. Defaults to True.
+ weight (Union[float, int]): Cost weight. Defaults to 1.
+ """
+
+ def __init__(self,
+ pred_act: bool = False,
+ eps: float = 1e-3,
+ naive_dice: bool = True,
+ weight: Union[float, int] = 1.) -> None:
+ super().__init__(weight=weight)
+ self.pred_act = pred_act
+ self.eps = eps
+ self.naive_dice = naive_dice
+
+ def _binary_mask_dice_loss(self, mask_preds: Tensor,
+ gt_masks: Tensor) -> Tensor:
+ """
+ Args:
+ mask_preds (Tensor): Mask prediction in shape (num_queries, *).
+ gt_masks (Tensor): Ground truth in shape (num_gt, *)
+ store 0 or 1, 0 for negative class and 1 for
+ positive class.
+
+ Returns:
+ Tensor: Dice cost matrix in shape (num_queries, num_gt).
+ """
+ mask_preds = mask_preds.flatten(1)
+ gt_masks = gt_masks.flatten(1).float()
+ numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
+ if self.naive_dice:
+ denominator = mask_preds.sum(-1)[:, None] + \
+ gt_masks.sum(-1)[None, :]
+ else:
+ denominator = mask_preds.pow(2).sum(1)[:, None] + \
+ gt_masks.pow(2).sum(1)[None, :]
+ loss = 1 - (numerator + self.eps) / (denominator + self.eps)
+ return loss
+
+ def __call__(self, pred_instances: InstanceData,
+ gt_instances: InstanceData, **kwargs) -> Tensor:
+ """Compute match cost.
+
+ Args:
+ pred_instances (InstanceData): Predicted instances which
+ must contain "masks".
+ gt_instances (InstanceData): Ground truth which must contain
+ "mask".
+
+ Returns:
+ Tensor: Match Cost matrix of shape (num_preds, num_gts).
+ """
+ assert hasattr(pred_instances, 'masks'), \
+ "pred_instances must contain 'masks'"
+ assert hasattr(gt_instances, 'masks'), \
+ "gt_instances must contain 'masks'"
+ pred_masks = pred_instances.masks
+ gt_masks = gt_instances.masks
+
+ if self.pred_act:
+ pred_masks = pred_masks.sigmoid()
+ dice_cost = self._binary_mask_dice_loss(pred_masks, gt_masks)
+ return dice_cost * self.weight
+
+
+@TASK_UTILS.register_module()
+class CrossEntropyLossCost(BaseMatchCost):
+ """CrossEntropyLossCost.
+
+ Args:
+ use_sigmoid (bool): Whether the prediction uses sigmoid
+            or softmax. Defaults to True.
+ weight (Union[float, int]): Cost weight. Defaults to 1.
+ """
+
+ def __init__(self,
+ use_sigmoid: bool = True,
+ weight: Union[float, int] = 1.) -> None:
+ super().__init__(weight=weight)
+ self.use_sigmoid = use_sigmoid
+
+ def _binary_cross_entropy(self, cls_pred: Tensor,
+ gt_labels: Tensor) -> Tensor:
+ """
+ Args:
+ cls_pred (Tensor): The prediction with shape (num_queries, 1, *) or
+ (num_queries, *).
+ gt_labels (Tensor): The learning label of prediction with
+ shape (num_gt, *).
+
+ Returns:
+ Tensor: Cross entropy cost matrix in shape (num_queries, num_gt).
+ """
+ cls_pred = cls_pred.flatten(1).float()
+ gt_labels = gt_labels.flatten(1).float()
+ n = cls_pred.shape[1]
+ pos = F.binary_cross_entropy_with_logits(
+ cls_pred, torch.ones_like(cls_pred), reduction='none')
+ neg = F.binary_cross_entropy_with_logits(
+ cls_pred, torch.zeros_like(cls_pred), reduction='none')
+ cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \
+ torch.einsum('nc,mc->nm', neg, 1 - gt_labels)
+ cls_cost = cls_cost / n
+
+ return cls_cost
+
+ def __call__(self, pred_instances: InstanceData,
+ gt_instances: InstanceData, **kwargs) -> Tensor:
+ """Compute match cost.
+
+ Args:
+ pred_instances (:obj:`InstanceData`): Predicted instances which
+ must contain ``masks``.
+ gt_instances (:obj:`InstanceData`): Ground truth which must contain
+ ``masks``.
+
+ Returns:
+ Tensor: Match Cost matrix of shape (num_preds, num_gts).
+ """
+ assert hasattr(pred_instances, 'masks'), \
+ "pred_instances must contain 'masks'"
+ assert hasattr(gt_instances, 'masks'), \
+ "gt_instances must contain 'masks'"
+ pred_masks = pred_instances.masks
+ gt_masks = gt_instances.masks
+ if self.use_sigmoid:
+ cls_cost = self._binary_cross_entropy(pred_masks, gt_masks)
+ else:
+ raise NotImplementedError
+
+ return cls_cost * self.weight
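The classification cost above is just the negative softmax probability of each ground-truth class under each query. A stdlib-only sketch of that computation (illustrative, not part of the patch):

```python
import math

def classification_cost(scores, gt_labels):
    """cost[i][j] = -softmax(scores[i])[gt_labels[j]], mirroring
    ClassificationCost.__call__ with weight = 1."""
    cost = []
    for row in scores:
        # softmax over the class logits of one query
        exps = [math.exp(s) for s in row]
        z = sum(exps)
        probs = [e / z for e in exps]
        # negative probability of each gt label under this query
        cost.append([-probs[label] for label in gt_labels])
    return cost

# two queries, two classes; gt labels are class 0 and class 1
cost = classification_cost([[2.0, 0.0], [0.0, 2.0]], [0, 1])
```

Because the cost is a negative probability, confident correct predictions give the most negative (cheapest) entries, which is exactly what the Hungarian matcher minimizes.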
diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py
index e3107306eae..784d3dfdb70 100644
--- a/mmseg/models/backbones/__init__.py
+++ b/mmseg/models/backbones/__init__.py
@@ -3,6 +3,7 @@
from .bisenetv1 import BiSeNetV1
from .bisenetv2 import BiSeNetV2
from .cgnet import CGNet
+from .ddrnet import DDRNet
from .erfnet import ERFNet
from .fast_scnn import FastSCNN
from .hrnet import HRNet
@@ -22,11 +23,13 @@
from .twins import PCPVT, SVT
from .unet import UNet
from .vit import VisionTransformer
+from .vpd import VPD
__all__ = [
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
- 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN'
+ 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN',
+ 'DDRNet', 'VPD'
]
diff --git a/mmseg/models/backbones/ddrnet.py b/mmseg/models/backbones/ddrnet.py
new file mode 100644
index 00000000000..4508aade82b
--- /dev/null
+++ b/mmseg/models/backbones/ddrnet.py
@@ -0,0 +1,222 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule, build_norm_layer
+from mmengine.model import BaseModule
+
+from mmseg.models.utils import DAPPM, BasicBlock, Bottleneck, resize
+from mmseg.registry import MODELS
+from mmseg.utils import OptConfigType
+
+
+@MODELS.register_module()
+class DDRNet(BaseModule):
+ """DDRNet backbone.
+
+ This backbone is the implementation of `Deep Dual-resolution Networks for
+ Real-time and Accurate Semantic Segmentation of Road Scenes
+    <http://arxiv.org/abs/2101.06085>`_.
+ Modified from https://github.com/ydhongHIT/DDRNet.
+
+ Args:
+ in_channels (int): Number of input image channels. Default: 3.
+        channels (int): The base channels of DDRNet. Default: 32.
+ ppm_channels (int): The channels of PPM module. Default: 128.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ norm_cfg (dict): Config dict to build norm layer.
+ Default: dict(type='BN', requires_grad=True).
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU', inplace=True).
+ init_cfg (dict, optional): Initialization config dict.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_channels: int = 3,
+ channels: int = 32,
+ ppm_channels: int = 128,
+ align_corners: bool = False,
+ norm_cfg: OptConfigType = dict(type='BN', requires_grad=True),
+ act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
+ init_cfg: OptConfigType = None):
+ super().__init__(init_cfg)
+
+ self.in_channels = in_channels
+ self.ppm_channels = ppm_channels
+
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.align_corners = align_corners
+
+ # stage 0-2
+ self.stem = self._make_stem_layer(in_channels, channels, num_blocks=2)
+ self.relu = nn.ReLU()
+
+ # low resolution(context) branch
+ self.context_branch_layers = nn.ModuleList()
+ for i in range(3):
+ self.context_branch_layers.append(
+ self._make_layer(
+ block=BasicBlock if i < 2 else Bottleneck,
+ inplanes=channels * 2**(i + 1),
+ planes=channels * 8 if i > 0 else channels * 4,
+ num_blocks=2 if i < 2 else 1,
+ stride=2))
+
+ # bilateral fusion
+ self.compression_1 = ConvModule(
+ channels * 4,
+ channels * 2,
+ kernel_size=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+ self.down_1 = ConvModule(
+ channels * 2,
+ channels * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+
+ self.compression_2 = ConvModule(
+ channels * 8,
+ channels * 2,
+ kernel_size=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+ self.down_2 = nn.Sequential(
+ ConvModule(
+ channels * 2,
+ channels * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ ConvModule(
+ channels * 4,
+ channels * 8,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None))
+
+ # high resolution(spatial) branch
+ self.spatial_branch_layers = nn.ModuleList()
+ for i in range(3):
+ self.spatial_branch_layers.append(
+ self._make_layer(
+ block=BasicBlock if i < 2 else Bottleneck,
+ inplanes=channels * 2,
+ planes=channels * 2,
+ num_blocks=2 if i < 2 else 1,
+ ))
+
+ self.spp = DAPPM(
+ channels * 16, ppm_channels, channels * 4, num_scales=5)
+
+ def _make_stem_layer(self, in_channels, channels, num_blocks):
+ layers = [
+ ConvModule(
+ in_channels,
+ channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ ConvModule(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ ]
+
+ layers.extend([
+ self._make_layer(BasicBlock, channels, channels, num_blocks),
+ nn.ReLU(),
+ self._make_layer(
+ BasicBlock, channels, channels * 2, num_blocks, stride=2),
+ nn.ReLU(),
+ ])
+
+ return nn.Sequential(*layers)
+
+ def _make_layer(self, block, inplanes, planes, num_blocks, stride=1):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
+
+ layers = [
+ block(
+ in_channels=inplanes,
+ channels=planes,
+ stride=stride,
+ downsample=downsample)
+ ]
+ inplanes = planes * block.expansion
+ for i in range(1, num_blocks):
+ layers.append(
+ block(
+ in_channels=inplanes,
+ channels=planes,
+ stride=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ """Forward function."""
+ out_size = (x.shape[-2] // 8, x.shape[-1] // 8)
+
+ # stage 0-2
+ x = self.stem(x)
+
+ # stage3
+ x_c = self.context_branch_layers[0](x)
+ x_s = self.spatial_branch_layers[0](x)
+ comp_c = self.compression_1(self.relu(x_c))
+ x_c += self.down_1(self.relu(x_s))
+ x_s += resize(
+ comp_c,
+ size=out_size,
+ mode='bilinear',
+ align_corners=self.align_corners)
+ if self.training:
+ temp_context = x_s.clone()
+
+ # stage4
+ x_c = self.context_branch_layers[1](self.relu(x_c))
+ x_s = self.spatial_branch_layers[1](self.relu(x_s))
+ comp_c = self.compression_2(self.relu(x_c))
+ x_c += self.down_2(self.relu(x_s))
+ x_s += resize(
+ comp_c,
+ size=out_size,
+ mode='bilinear',
+ align_corners=self.align_corners)
+
+ # stage5
+ x_s = self.spatial_branch_layers[2](self.relu(x_s))
+ x_c = self.context_branch_layers[2](self.relu(x_c))
+ x_c = self.spp(x_c)
+ x_c = resize(
+ x_c,
+ size=out_size,
+ mode='bilinear',
+ align_corners=self.align_corners)
+
+ return (temp_context, x_s + x_c) if self.training else x_s + x_c
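The resolutions the two branches run at follow from the stride-2 convs above; a small arithmetic sketch (strides read off the code, purely illustrative):

```python
def branch_resolutions(h, w):
    """Feature-map sizes implied by the forward pass: the stem
    downsamples by 8, the spatial branch stays at 1/8 resolution,
    and each context stage halves the resolution again before the
    DAPPM output is resized back to 1/8 for the final fusion."""
    stem = (h // 8, w // 8)  # spatial branch stays at this size
    context = [(h // s, w // s) for s in (16, 32, 64)]  # stages 3-5
    return stem, context

stem, context = branch_resolutions(1024, 1024)
```

This is why `out_size` is fixed to 1/8 of the input at the top of `forward`: every cross-branch fusion (`compression_*` + `resize`) brings the context features back to the spatial branch's resolution.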
diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py
index c0ace3c1391..67b28a96e15 100644
--- a/mmseg/models/backbones/swin.py
+++ b/mmseg/models/backbones/swin.py
@@ -716,20 +716,22 @@ def init_weights(self):
]
for table_key in relative_position_bias_table_keys:
table_pretrained = state_dict[table_key]
- table_current = self.state_dict()[table_key]
- L1, nH1 = table_pretrained.size()
- L2, nH2 = table_current.size()
- if nH1 != nH2:
- print_log(f'Error in loading {table_key}, pass')
- elif L1 != L2:
- S1 = int(L1**0.5)
- S2 = int(L2**0.5)
- table_pretrained_resized = F.interpolate(
- table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
- size=(S2, S2),
- mode='bicubic')
- state_dict[table_key] = table_pretrained_resized.view(
- nH2, L2).permute(1, 0).contiguous()
+ if table_key in self.state_dict():
+ table_current = self.state_dict()[table_key]
+ L1, nH1 = table_pretrained.size()
+ L2, nH2 = table_current.size()
+ if nH1 != nH2:
+ print_log(f'Error in loading {table_key}, pass')
+ elif L1 != L2:
+ S1 = int(L1**0.5)
+ S2 = int(L2**0.5)
+ table_pretrained_resized = F.interpolate(
+ table_pretrained.permute(1, 0).reshape(
+ 1, nH1, S1, S1),
+ size=(S2, S2),
+ mode='bicubic')
+ state_dict[table_key] = table_pretrained_resized.view(
+ nH2, L2).permute(1, 0).contiguous()
# load state_dict
self.load_state_dict(state_dict, strict=False)
diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py
index 3c96f654937..dd0f688fcc4 100644
--- a/mmseg/models/backbones/vit.py
+++ b/mmseg/models/backbones/vit.py
@@ -132,12 +132,16 @@ class VisionTransformer(BaseModule):
Args:
img_size (int | tuple): Input image size. Default: 224.
patch_size (int): The patch size. Default: 16.
+ patch_pad (str | int | None): The padding method in patch embedding.
+ Default: 'corner'.
in_channels (int): Number of input channels. Default: 3.
embed_dims (int): embedding dimension. Default: 768.
num_layers (int): depth of transformer. Default: 12.
num_heads (int): number of attention heads. Default: 12.
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
Default: 4.
+ out_origin (bool): Whether to output the original input embedding.
+ Default: False
out_indices (list | tuple | int): Output from which stages.
Default: -1.
qkv_bias (bool): enable bias for qkv if True. Default: True.
@@ -154,8 +158,12 @@ class VisionTransformer(BaseModule):
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
+        patch_bias (bool): Whether to use bias in convolution of PatchEmbed
+            Block. Default: False.
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
Default: False.
+ pre_norm (bool): Whether to add a norm before Transformer Layers.
+ Default: False.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Default: False.
interpolate_mode (str): Select the interpolate mode for position
@@ -167,6 +175,8 @@ class VisionTransformer(BaseModule):
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save
some memory while slowing down the training speed. Default: False.
+ frozen_exclude (List): List of parameters that are not to be frozen.
+ Default: ["all"], "all" means there are no frozen parameters.
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
@@ -175,11 +185,13 @@ class VisionTransformer(BaseModule):
def __init__(self,
img_size=224,
patch_size=16,
+ patch_pad='corner',
in_channels=3,
embed_dims=768,
num_layers=12,
num_heads=12,
mlp_ratio=4,
+ out_origin=False,
out_indices=-1,
qkv_bias=True,
drop_rate=0.,
@@ -190,11 +202,14 @@ def __init__(self,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
patch_norm=False,
+ patch_bias=False,
+ pre_norm=False,
final_norm=False,
interpolate_mode='bicubic',
num_fcs=2,
norm_eval=False,
with_cp=False,
+ frozen_exclude=['all'],
pretrained=None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
@@ -227,6 +242,8 @@ def __init__(self,
self.norm_eval = norm_eval
self.with_cp = with_cp
self.pretrained = pretrained
+ self.out_origin = out_origin
+ self.frozen_exclude = frozen_exclude
self.patch_embed = PatchEmbed(
in_channels=in_channels,
@@ -234,7 +251,8 @@ def __init__(self,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
- padding='corner',
+ padding=patch_pad,
+ bias=patch_bias,
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None,
)
@@ -248,6 +266,12 @@ def __init__(self,
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dims))
self.drop_after_pos = nn.Dropout(p=drop_rate)
+ self.pre_norm = pre_norm
+
+ if self.pre_norm:
+ self.pre_ln_name, pre_ln = build_norm_layer(
+ norm_cfg, embed_dims, postfix='_pre')
+ self.add_module(self.pre_ln_name, pre_ln)
if isinstance(out_indices, int):
if out_indices == -1:
@@ -285,20 +309,36 @@ def __init__(self,
norm_cfg, embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
+ self._freeze()
+
+ @property
+ def pre_ln(self):
+ return getattr(self, self.pre_ln_name)
+
@property
def norm1(self):
return getattr(self, self.norm1_name)
def init_weights(self):
- if (isinstance(self.init_cfg, dict)
- and self.init_cfg.get('type') == 'Pretrained'):
+ if isinstance(self.init_cfg, dict) and \
+ self.init_cfg.get('type') in ['Pretrained', 'Pretrained_Part']:
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
- if 'state_dict' in checkpoint:
- state_dict = checkpoint['state_dict']
- else:
- state_dict = checkpoint
+ if self.init_cfg.get('type') == 'Pretrained':
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ else:
+ state_dict = checkpoint
+
+ elif self.init_cfg.get('type') == 'Pretrained_Part':
+ state_dict = checkpoint.copy()
+ para_prefix = 'image_encoder'
+ prefix_len = len(para_prefix) + 1
+ for k, v in checkpoint.items():
+ state_dict.pop(k)
+ if para_prefix in k:
+ state_dict[k[prefix_len:]] = v
if 'pos_embed' in state_dict.keys():
if self.pos_embed.shape != state_dict['pos_embed'].shape:
@@ -334,6 +374,13 @@ def init_weights(self):
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
constant_init(m, val=1.0, bias=0.)
+ def _freeze(self):
+ if 'all' in self.frozen_exclude:
+ return
+ for name, param in self.named_parameters():
+ if not any([exclude in name for exclude in self.frozen_exclude]):
+ param.requires_grad = False
+
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
"""Positioning embeding method.
@@ -409,7 +456,23 @@ def forward(self, inputs):
# Remove class token for transformer encoder input
x = x[:, 1:]
+ if self.pre_norm:
+ x = self.pre_ln(x)
+
outs = []
+ if self.out_origin:
+ if self.with_cls_token:
+ # Remove class token and reshape token for decoder head
+ out = x[:, 1:]
+ else:
+ out = x
+ B, _, C = out.shape
+ out = out.reshape(B, hw_shape[0], hw_shape[1],
+ C).permute(0, 3, 1, 2).contiguous()
+ if self.output_cls_token:
+ out = [out, x[:, 0]]
+ outs.append(out)
+
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1:
diff --git a/mmseg/models/backbones/vpd.py b/mmseg/models/backbones/vpd.py
new file mode 100644
index 00000000000..e0536d31c64
--- /dev/null
+++ b/mmseg/models/backbones/vpd.py
@@ -0,0 +1,395 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# ------------------------------------------------------------------------------
+# Adapted from https://github.com/wl-zhao/VPD/blob/main/vpd/models.py
+# Original licence: MIT License
+# ------------------------------------------------------------------------------
+
+import math
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.model import BaseModule
+from mmengine.runner import CheckpointLoader, load_checkpoint
+
+from mmseg.registry import MODELS
+from mmseg.utils import ConfigType, OptConfigType
+
+try:
+ from ldm.modules.diffusionmodules.util import timestep_embedding
+ from ldm.util import instantiate_from_config
+ has_ldm = True
+except ImportError:
+ has_ldm = False
+
+
+def register_attention_control(model, controller):
+ """Registers a control function to manage attention within a model.
+
+ Args:
+ model: The model to which attention is to be registered.
+ controller: The control function responsible for managing attention.
+ """
+
+ def ca_forward(self, place_in_unet):
+ """Custom forward method for attention.
+
+ Args:
+ self: Reference to the current object.
+ place_in_unet: The location in UNet (down/mid/up).
+
+ Returns:
+ The modified forward method.
+ """
+
+ def forward(x, context=None, mask=None):
+ h = self.heads
+ is_cross = context is not None
+            context = x if context is None else context  # `or` is ambiguous for tensors
+
+ q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
+ q, k, v = (
+ tensor.view(tensor.shape[0] * h, tensor.shape[1],
+ tensor.shape[2] // h) for tensor in [q, k, v])
+
+ sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+
+ if mask is not None:
+ mask = mask.flatten(1).unsqueeze(1).repeat(h, 1, 1)
+ max_neg_value = -torch.finfo(sim.dtype).max
+ sim.masked_fill_(~mask, max_neg_value)
+
+ attn = sim.softmax(dim=-1)
+ attn_mean = attn.view(h, attn.shape[0] // h,
+ *attn.shape[1:]).mean(0)
+ controller(attn_mean, is_cross, place_in_unet)
+
+ out = torch.matmul(attn, v)
+ out = out.view(out.shape[0] // h, out.shape[1], out.shape[2] * h)
+ return self.to_out(out)
+
+ return forward
+
+ def register_recr(net_, count, place_in_unet):
+ """Recursive function to register the custom forward method to all
+ CrossAttention layers.
+
+ Args:
+ net_: The network layer currently being processed.
+ count: The current count of layers processed.
+ place_in_unet: The location in UNet (down/mid/up).
+
+ Returns:
+ The updated count of layers processed.
+ """
+ if net_.__class__.__name__ == 'CrossAttention':
+ net_.forward = ca_forward(net_, place_in_unet)
+ return count + 1
+ if hasattr(net_, 'children'):
+ return sum(
+ register_recr(child, 0, place_in_unet)
+ for child in net_.children())
+ return count
+
+    cross_att_count = 0
+    for name, child in model.diffusion_model.named_children():
+        if 'input_blocks' in name:
+            cross_att_count += register_recr(child, 0, 'down')
+        elif 'output_blocks' in name:
+            cross_att_count += register_recr(child, 0, 'up')
+        elif 'middle_block' in name:
+            cross_att_count += register_recr(child, 0, 'mid')
+
+ controller.num_att_layers = cross_att_count
+
+
+class AttentionStore:
+ """A class for storing attention information in the UNet model.
+
+ Attributes:
+ base_size (int): Base size for storing attention information.
+ max_size (int): Maximum size for storing attention information.
+ """
+
+ def __init__(self, base_size=64, max_size=None):
+ """Initialize AttentionStore with default or custom sizes."""
+ self.reset()
+ self.base_size = base_size
+ self.max_size = max_size or (base_size // 2)
+ self.num_att_layers = -1
+
+ @staticmethod
+ def get_empty_store():
+ """Returns an empty store for holding attention values."""
+ return {
+ key: []
+ for key in [
+ 'down_cross', 'mid_cross', 'up_cross', 'down_self', 'mid_self',
+ 'up_self'
+ ]
+ }
+
+ def reset(self):
+ """Resets the step and attention stores to their initial states."""
+ self.cur_step = 0
+ self.cur_att_layer = 0
+ self.step_store = self.get_empty_store()
+ self.attention_store = {}
+
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
+ """Processes a single forward step, storing the attention.
+
+ Args:
+ attn: The attention tensor.
+ is_cross (bool): Whether it's cross attention.
+ place_in_unet (str): The location in UNet (down/mid/up).
+
+ Returns:
+ The unmodified attention tensor.
+ """
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
+ if attn.shape[1] <= (self.max_size)**2:
+ self.step_store[key].append(attn)
+ return attn
+
+ def between_steps(self):
+ """Processes and stores attention information between steps."""
+ if not self.attention_store:
+ self.attention_store = self.step_store
+ else:
+ for key in self.attention_store:
+ self.attention_store[key] = [
+ stored + step for stored, step in zip(
+ self.attention_store[key], self.step_store[key])
+ ]
+ self.step_store = self.get_empty_store()
+
+    def get_average_attention(self):
+        """Returns the attention maps stored for the current step, keyed by
+        location and attention type."""
+        return {
+            key: list(items)
+            for key, items in self.step_store.items()
+        }
+
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
+ """Allows the class instance to be callable."""
+ return self.forward(attn, is_cross, place_in_unet)
+
+ @property
+ def num_uncond_att_layers(self):
+ """Returns the number of unconditional attention layers (default is
+ 0)."""
+ return 0
+
+ def step_callback(self, x_t):
+ """A placeholder for a step callback.
+
+ Returns the input unchanged.
+ """
+ return x_t
+
+
+class UNetWrapper(nn.Module):
+ """A wrapper for UNet with optional attention mechanisms.
+
+ Args:
+ unet (nn.Module): The UNet model to wrap
+ use_attn (bool): Whether to use attention. Defaults to True
+ base_size (int): Base size for the attention store. Defaults to 512
+ max_attn_size (int, optional): Maximum size for the attention store.
+ Defaults to None
+ attn_selector (str): The types of attention to use.
+ Defaults to 'up_cross+down_cross'
+ """
+
+ def __init__(self,
+ unet,
+ use_attn=True,
+ base_size=512,
+ max_attn_size=None,
+ attn_selector='up_cross+down_cross'):
+ super().__init__()
+
+ assert has_ldm, 'To use UNetWrapper, please install required ' \
+ 'packages via `pip install -r requirements/optional.txt`.'
+
+ self.unet = unet
+ self.attention_store = AttentionStore(
+ base_size=base_size // 8, max_size=max_attn_size)
+ self.attn_selector = attn_selector.split('+')
+ self.use_attn = use_attn
+ self.init_sizes(base_size)
+ if self.use_attn:
+ register_attention_control(unet, self.attention_store)
+
+ def init_sizes(self, base_size):
+ """Initialize sizes based on the base size."""
+ self.size16 = base_size // 32
+ self.size32 = base_size // 16
+ self.size64 = base_size // 8
+
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+ """Forward pass through the model."""
+ diffusion_model = self.unet.diffusion_model
+ if self.use_attn:
+ self.attention_store.reset()
+ hs, emb, out_list = self._unet_forward(x, timesteps, context, y,
+ diffusion_model)
+ if self.use_attn:
+ self._append_attn_to_output(out_list)
+ return out_list[::-1]
+
+ def _unet_forward(self, x, timesteps, context, y, diffusion_model):
+ hs = []
+ t_emb = timestep_embedding(
+ timesteps, diffusion_model.model_channels, repeat_only=False)
+ emb = diffusion_model.time_embed(t_emb)
+ h = x.type(diffusion_model.dtype)
+ for module in diffusion_model.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = diffusion_model.middle_block(h, emb, context)
+ out_list = []
+ for i_out, module in enumerate(diffusion_model.output_blocks):
+ h = torch.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ if i_out in [1, 4, 7]:
+ out_list.append(h)
+ h = h.type(x.dtype)
+ out_list.append(h)
+ return hs, emb, out_list
+
+ def _append_attn_to_output(self, out_list):
+ avg_attn = self.attention_store.get_average_attention()
+ attns = {self.size16: [], self.size32: [], self.size64: []}
+ for k in self.attn_selector:
+ for up_attn in avg_attn[k]:
+ size = int(math.sqrt(up_attn.shape[1]))
+ up_attn = up_attn.transpose(-1, -2).reshape(
+ *up_attn.shape[:2], size, -1)
+ attns[size].append(up_attn)
+ attn16 = torch.stack(attns[self.size16]).mean(0)
+ attn32 = torch.stack(attns[self.size32]).mean(0)
+ attn64 = torch.stack(attns[self.size64]).mean(0) if len(
+ attns[self.size64]) > 0 else None
+ out_list[1] = torch.cat([out_list[1], attn16], dim=1)
+ out_list[2] = torch.cat([out_list[2], attn32], dim=1)
+ if attn64 is not None:
+ out_list[3] = torch.cat([out_list[3], attn64], dim=1)
+
+
+class TextAdapter(nn.Module):
+ """A PyTorch Module that serves as a text adapter.
+
+ This module takes text embeddings and adjusts them based on a scaling
+ factor gamma.
+ """
+
+ def __init__(self, text_dim=768):
+ super().__init__()
+ self.fc = nn.Sequential(
+ nn.Linear(text_dim, text_dim), nn.GELU(),
+ nn.Linear(text_dim, text_dim))
+
+ def forward(self, texts, gamma):
+ texts_after = self.fc(texts)
+ texts = texts + gamma * texts_after
+ return texts
+
+
+@MODELS.register_module()
+class VPD(BaseModule):
+ """VPD (Visual Perception Diffusion) model.
+
+ .. _`VPD`: https://arxiv.org/abs/2303.02153
+
+ Args:
+ diffusion_cfg (dict): Configuration for diffusion model.
+ class_embed_path (str): Path for class embeddings.
+ unet_cfg (dict, optional): Configuration for U-Net.
+ gamma (float, optional): Gamma for text adaptation. Defaults to 1e-4.
+ class_embed_select (bool, optional): If True, enables class embedding
+ selection. Defaults to False.
+ pad_shape (Optional[Union[int, List[int]]], optional): Padding shape.
+ Defaults to None.
+ pad_val (Union[int, List[int]], optional): Padding value.
+ Defaults to 0.
+ init_cfg (dict, optional): Configuration for network initialization.
+ """
+
+ def __init__(self,
+ diffusion_cfg: ConfigType,
+ class_embed_path: str,
+ unet_cfg: OptConfigType = dict(),
+ gamma: float = 1e-4,
+ class_embed_select=False,
+ pad_shape: Optional[Union[int, List[int]]] = None,
+ pad_val: Union[int, List[int]] = 0,
+ init_cfg: OptConfigType = None):
+
+ super().__init__(init_cfg=init_cfg)
+
+ assert has_ldm, 'To use VPD model, please install required packages' \
+ ' via `pip install -r requirements/optional.txt`.'
+
+ if pad_shape is not None:
+ if not isinstance(pad_shape, (list, tuple)):
+ pad_shape = (pad_shape, pad_shape)
+
+ self.pad_shape = pad_shape
+ self.pad_val = pad_val
+
+ # diffusion model
+ diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None)
+ sd_model = instantiate_from_config(diffusion_cfg)
+ if diffusion_checkpoint is not None:
+ load_checkpoint(sd_model, diffusion_checkpoint, strict=False)
+
+ self.encoder_vq = sd_model.first_stage_model
+ self.unet = UNetWrapper(sd_model.model, **unet_cfg)
+
+ # class embeddings & text adapter
+ class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path)
+ text_dim = class_embeddings.size(-1)
+ self.text_adapter = TextAdapter(text_dim=text_dim)
+ self.class_embed_select = class_embed_select
+ if class_embed_select:
+ class_embeddings = torch.cat(
+ (class_embeddings, class_embeddings.mean(dim=0,
+ keepdims=True)),
+ dim=0)
+ self.register_buffer('class_embeddings', class_embeddings)
+ self.gamma = nn.Parameter(torch.ones(text_dim) * gamma)
+
+ def forward(self, x):
+ """Extract features from images."""
+
+ # calculate cross-attn map
+ if self.class_embed_select:
+ if isinstance(x, (tuple, list)):
+ x, class_ids = x[:2]
+ class_ids = class_ids.tolist()
+ else:
+ class_ids = [-1] * x.size(0)
+ class_embeddings = self.class_embeddings[class_ids]
+ c_crossattn = self.text_adapter(class_embeddings, self.gamma)
+ c_crossattn = c_crossattn.unsqueeze(1)
+ else:
+ class_embeddings = self.class_embeddings
+ c_crossattn = self.text_adapter(class_embeddings, self.gamma)
+ c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1)
+
+ # pad to required input shape for pretrained diffusion model
+ if self.pad_shape is not None:
+ pad_width = max(0, self.pad_shape[1] - x.shape[-1])
+ pad_height = max(0, self.pad_shape[0] - x.shape[-2])
+ x = F.pad(x, (0, pad_width, 0, pad_height), value=self.pad_val)
+
+ # forward the denoising model
+ with torch.no_grad():
+ latents = self.encoder_vq.encode(x).mode().detach()
+ t = torch.ones((x.shape[0], ), device=x.device).long()
+ outs = self.unet(latents, t, context=c_crossattn)
+
+ return outs
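The `register_attention_control`/`AttentionStore` pair above works by monkey-patching each `CrossAttention.forward` so it reports its head-averaged attention map to a controller, which buckets maps by UNet location (`down`/`mid`/`up`) and attention type (`cross`/`self`). A dependency-free sketch of that storage pattern, with plain lists standing in for attention tensors (a simplified stand-in, not the mmseg implementation):

```python
# Minimal sketch of the AttentionStore bucketing pattern: patched layers
# call the store with (attn_map, is_cross, place_in_unet), and the store
# files each map under a "{place}_{cross|self}" key.

class MiniAttentionStore:
    KEYS = ('down_cross', 'mid_cross', 'up_cross',
            'down_self', 'mid_self', 'up_self')

    def __init__(self):
        self.reset()

    def reset(self):
        # one empty bucket per (location, attention-type) pair
        self.step_store = {k: [] for k in self.KEYS}

    def __call__(self, attn, is_cross, place_in_unet):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        self.step_store[key].append(attn)
        return attn


store = MiniAttentionStore()
store([[0.2, 0.8]], is_cross=True, place_in_unet='up')
store([[0.5, 0.5]], is_cross=False, place_in_unet='down')
print(len(store.step_store['up_cross']))   # 1
print(len(store.step_store['down_self']))  # 1
```

In the real module the same callable is invoked from inside every patched forward pass, which is why `UNetWrapper.forward` resets the store before each run.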
diff --git a/mmseg/models/data_preprocessor.py b/mmseg/models/data_preprocessor.py
index deef365a9e8..8d32bc647b7 100644
--- a/mmseg/models/data_preprocessor.py
+++ b/mmseg/models/data_preprocessor.py
@@ -132,9 +132,9 @@ def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
inputs, data_samples = self.batch_augments(
inputs, data_samples)
else:
- assert len(inputs) == 1, (
- 'Batch inference is not support currently, '
- 'as the image size might be different in a batch')
+ img_size = inputs[0].shape[1:]
+ assert all(input_.shape[1:] == img_size for input_ in inputs), \
+ 'The image size in a batch should be the same.'
# pad images when testing
if self.test_cfg:
inputs, padded_samples = stack_batch(
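The data_preprocessor hunk above relaxes the old "batch size must be 1" assertion: batched test-time inference is now allowed as long as every image in the batch shares the same spatial shape. A small sketch of that check, with plain tuples standing in for tensor `.shape[1:]` values (an assumption for brevity):

```python
# Sketch of the relaxed test-time batch check: instead of forbidding
# batched inference outright, only require that all (H, W) shapes match.

def check_batch_shapes(shapes):
    """Return True if every (H, W) shape in the batch equals the first."""
    if not shapes:
        return True
    first = shapes[0]
    return all(s == first for s in shapes)


print(check_batch_shapes([(512, 512), (512, 512)]))  # True
print(check_batch_shapes([(512, 512), (480, 640)]))  # False
```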
diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py
index 18235456bc9..4229763816e 100644
--- a/mmseg/models/decode_heads/__init__.py
+++ b/mmseg/models/decode_heads/__init__.py
@@ -4,6 +4,7 @@
from .aspp_head import ASPPHead
from .cc_head import CCHead
from .da_head import DAHead
+from .ddr_head import DDRHead
from .dm_head import DMHead
from .dnl_head import DNLHead
from .dpt_head import DPTHead
@@ -24,6 +25,7 @@
from .point_head import PointHead
from .psa_head import PSAHead
from .psp_head import PSPHead
+from .san_head import SideAdapterCLIPHead
from .segformer_head import SegformerHead
from .segmenter_mask_head import SegmenterMaskTransformerHead
from .sep_aspp_head import DepthwiseSeparableASPPHead
@@ -32,6 +34,7 @@
from .setr_up_head import SETRUPHead
from .stdc_head import STDCHead
from .uper_head import UPerHead
+from .vpd_depth_head import VPDDepthHead
__all__ = [
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
@@ -41,5 +44,5 @@
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead',
- 'LightHamHead', 'PIDHead'
+ 'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead', 'SideAdapterCLIPHead'
]
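The `__init__.py` edits above matter because `@MODELS.register_module()` only runs when the defining module is imported; without the import, `DDRHead`, `SideAdapterCLIPHead`, and `VPDDepthHead` would never enter the registry and config-driven builds would fail. A toy re-implementation of that pattern (the class names mirror this PR, but the `Registry` here is a simplified stand-in for mmengine's, not its real API):

```python
# Minimal registry sketch: registration happens at import time via the
# decorator, and `build` instantiates a class from a config dict.

class Registry:
    def __init__(self):
        self._modules = {}

    def register_module(self):
        def decorator(cls):
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg):
        cfg = dict(cfg)
        cls = self._modules[cfg.pop('type')]  # look up by 'type' key
        return cls(**cfg)                     # remaining keys are kwargs


MODELS = Registry()


@MODELS.register_module()
class DDRHead:  # toy stand-in for the real decode head
    def __init__(self, channels=64):
        self.channels = channels


head = MODELS.build(dict(type='DDRHead', channels=128))
print(head.channels)  # 128
```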
diff --git a/mmseg/models/decode_heads/ddr_head.py b/mmseg/models/decode_heads/ddr_head.py
new file mode 100644
index 00000000000..ba26d6503c0
--- /dev/null
+++ b/mmseg/models/decode_heads/ddr_head.py
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple, Union
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
+from torch import Tensor
+
+from mmseg.models.decode_heads.decode_head import BaseDecodeHead
+from mmseg.models.losses import accuracy
+from mmseg.models.utils import resize
+from mmseg.registry import MODELS
+from mmseg.utils import OptConfigType, SampleList
+
+
+@MODELS.register_module()
+class DDRHead(BaseDecodeHead):
+ """Decode head for DDRNet.
+
+ Args:
+ in_channels (int): Number of input channels.
+ channels (int): Number of output channels.
+ num_classes (int): Number of classes.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict, optional): Config dict for activation layer.
+ Default: dict(type='ReLU', inplace=True).
+ """
+
+ def __init__(self,
+ in_channels: int,
+ channels: int,
+ num_classes: int,
+ norm_cfg: OptConfigType = dict(type='BN'),
+ act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
+ **kwargs):
+ super().__init__(
+ in_channels,
+ channels,
+ num_classes=num_classes,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ **kwargs)
+
+ self.head = self._make_base_head(self.in_channels, self.channels)
+ self.aux_head = self._make_base_head(self.in_channels // 2,
+ self.channels)
+ self.aux_cls_seg = nn.Conv2d(
+ self.channels, self.out_channels, kernel_size=1)
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(
+ self,
+ inputs: Union[Tensor,
+ Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
+ if self.training:
+ c3_feat, c5_feat = inputs
+ x_c = self.head(c5_feat)
+ x_c = self.cls_seg(x_c)
+ x_s = self.aux_head(c3_feat)
+ x_s = self.aux_cls_seg(x_s)
+
+ return x_c, x_s
+ else:
+ x_c = self.head(inputs)
+ x_c = self.cls_seg(x_c)
+ return x_c
+
+ def _make_base_head(self, in_channels: int,
+ channels: int) -> nn.Sequential:
+ layers = [
+ ConvModule(
+ in_channels,
+ channels,
+ kernel_size=3,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ order=('norm', 'act', 'conv')),
+ build_norm_layer(self.norm_cfg, channels)[1],
+ build_activation_layer(self.act_cfg),
+ ]
+
+ return nn.Sequential(*layers)
+
+ def loss_by_feat(self, seg_logits: Tuple[Tensor],
+ batch_data_samples: SampleList) -> dict:
+ loss = dict()
+ context_logit, spatial_logit = seg_logits
+ seg_label = self._stack_batch_gt(batch_data_samples)
+
+ context_logit = resize(
+ context_logit,
+ size=seg_label.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ spatial_logit = resize(
+ spatial_logit,
+ size=seg_label.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ seg_label = seg_label.squeeze(1)
+
+ loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
+ loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
+ loss['acc_seg'] = accuracy(
+ context_logit, seg_label, ignore_index=self.ignore_index)
+
+ return loss
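`_make_base_head` above relies on `ConvModule(order=('norm', 'act', 'conv'))`, i.e. pre-activation ordering rather than the default conv-norm-act. A framework-free sketch of how such an `order` tuple drives the composition (the callables are stand-ins for real layers, used only to record call order):

```python
# Sketch of ConvModule-style `order` handling: the tuple decides the
# sequence in which the three sub-ops are applied to the input.

def make_block(order, ops):
    """Compose the callables in `ops` in the sequence given by `order`."""
    def block(x):
        for name in order:
            x = ops[name](x)
        return x
    return block


trace = []
ops = {
    'norm': lambda x: trace.append('norm') or x,  # record call, pass x on
    'act': lambda x: trace.append('act') or x,
    'conv': lambda x: trace.append('conv') or x,
}
make_block(('norm', 'act', 'conv'), ops)(0)
print(trace)  # ['norm', 'act', 'conv']
```

Pre-activation blocks like this normalize and activate the incoming features before the convolution, which is why `_make_base_head` appends a final norm + activation pair after the ConvModule.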
diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py
index 8bdbb24a1cc..4faf54559dc 100644
--- a/mmseg/models/decode_heads/decode_head.py
+++ b/mmseg/models/decode_heads/decode_head.py
@@ -45,7 +45,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
in_channels (int|Sequence[int]): Input channels.
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
- out_channels (int): Output channels of conv_seg.
+ out_channels (int): Output channels of conv_seg. Default: None.
threshold (float): Threshold for binary segmentation in the case of
`num_classes==1`. Default: None.
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
diff --git a/mmseg/models/decode_heads/ham_head.py b/mmseg/models/decode_heads/ham_head.py
index d80025f77d2..073d8011b05 100644
--- a/mmseg/models/decode_heads/ham_head.py
+++ b/mmseg/models/decode_heads/ham_head.py
@@ -5,6 +5,7 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
+from mmengine.device import get_device
from mmseg.registry import MODELS
from ..utils import resize
@@ -52,7 +53,7 @@ def __init__(self,
self.rand_init = rand_init
- def _build_bases(self, B, S, D, R, cuda=False):
+ def _build_bases(self, B, S, D, R, device=None):
raise NotImplementedError
def local_step(self, x, bases, coef):
@@ -80,14 +81,13 @@ def forward(self, x, return_bases=False):
D = C // self.S
N = H * W
x = x.view(B * self.S, D, N)
- cuda = 'cuda' in str(x.device)
if not self.rand_init and not hasattr(self, 'bases'):
- bases = self._build_bases(1, self.S, D, self.R, cuda=cuda)
+ bases = self._build_bases(1, self.S, D, self.R, device=x.device)
self.register_buffer('bases', bases)
# (S, D, R) -> (B * S, D, R)
if self.rand_init:
- bases = self._build_bases(B, self.S, D, self.R, cuda=cuda)
+ bases = self._build_bases(B, self.S, D, self.R, device=x.device)
else:
bases = self.bases.repeat(B, 1, 1)
@@ -116,13 +116,11 @@ def __init__(self, args=dict()):
self.inv_t = 1
- def _build_bases(self, B, S, D, R, cuda=False):
+ def _build_bases(self, B, S, D, R, device=None):
"""Build bases in initialization."""
- if cuda:
- bases = torch.rand((B * S, D, R)).cuda()
- else:
- bases = torch.rand((B * S, D, R))
-
+ if device is None:
+ device = get_device()
+ bases = torch.rand((B * S, D, R)).to(device)
bases = F.normalize(bases, dim=1)
return bases
diff --git a/mmseg/models/decode_heads/san_head.py b/mmseg/models/decode_heads/san_head.py
new file mode 100644
index 00000000000..03dedf2e495
--- /dev/null
+++ b/mmseg/models/decode_heads/san_head.py
@@ -0,0 +1,733 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import partial
+from typing import Dict, List, Tuple
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, build_norm_layer
+from mmcv.cnn.bricks.transformer import BaseTransformerLayer
+from mmcv.ops import point_sample
+from mmengine.dist import all_reduce
+from mmengine.model.weight_init import (caffe2_xavier_init, normal_init,
+ trunc_normal_)
+from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
+from mmengine.structures import InstanceData
+from torch import Tensor
+from torch.nn import functional as F
+
+from mmseg.models.backbones.vit import TransformerEncoderLayer
+from mmseg.registry import MODELS
+from mmseg.utils import (ConfigType, MatchMasks, SampleList,
+ seg_data_to_instance_data)
+from ..utils import (MLP, LayerNorm2d, PatchEmbed, cross_attn_layer,
+ get_uncertain_point_coords_with_randomness, resize)
+from .decode_head import BaseDecodeHead
+
+
+class MLPMaskDecoder(nn.Module):
+ """Module for decoding query and visual features with MLP layers to
+ generate the attention biases and the mask proposals."""
+
+ def __init__(
+ self,
+ *,
+ in_channels: int,
+ total_heads: int = 1,
+ total_layers: int = 1,
+ embed_channels: int = 256,
+ mlp_channels: int = 256,
+ mlp_num_layers: int = 3,
+ rescale_attn_bias: bool = False,
+ ):
+ super().__init__()
+ self.total_heads = total_heads
+ self.total_layers = total_layers
+
+ dense_affine_func = partial(nn.Conv2d, kernel_size=1)
+ # Query Branch
+ self.query_mlp = MLP(in_channels, mlp_channels, embed_channels,
+ mlp_num_layers)
+ # Pixel Branch
+ self.pix_mlp = MLP(
+ in_channels,
+ mlp_channels,
+ embed_channels,
+ mlp_num_layers,
+ affine_func=dense_affine_func,
+ )
+ # Attention Bias Branch
+ self.attn_mlp = MLP(
+ in_channels,
+ mlp_channels,
+ embed_channels * self.total_heads * self.total_layers,
+ mlp_num_layers,
+ affine_func=dense_affine_func,
+ )
+ if rescale_attn_bias:
+ self.bias_scaling = nn.Linear(1, 1)
+ else:
+ self.bias_scaling = nn.Identity()
+
+ def forward(self, query: torch.Tensor,
+ x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward function.
+ Args:
+ query (Tensor): Query Tokens [B,N,C].
+ x (Tensor): Visual features [B,C,H,W]
+
+ Return:
+ mask_preds (Tensor): Mask proposals.
+ attn_bias (List[Tensor]): List of attention bias.
+ """
+ query = self.query_mlp(query)
+ pix = self.pix_mlp(x)
+ b, c, h, w = pix.shape
+        # predict mask
+ mask_preds = torch.einsum('bqc,bchw->bqhw', query, pix)
+ # generate attn bias
+ attn = self.attn_mlp(x)
+ attn = attn.reshape(b, self.total_layers, self.total_heads, c, h, w)
+ attn_bias = torch.einsum('bqc,blnchw->blnqhw', query, attn)
+ attn_bias = self.bias_scaling(attn_bias[..., None]).squeeze(-1)
+ attn_bias = attn_bias.chunk(self.total_layers, dim=1)
+ attn_bias = [attn.squeeze(1) for attn in attn_bias]
+ return mask_preds, attn_bias
+
+
+class SideAdapterNetwork(nn.Module):
+ """Side Adapter Network for predicting mask proposals and attention bias.
+
+ Args:
+ in_channels (int): Number of input channels. Default: 3.
+ clip_channels (int): Number of channels of visual features.
+ Default: 768.
+ embed_dims (int): embedding dimension. Default: 240.
+ patch_size (int): The patch size. Default: 16.
+ patch_bias (bool): Whether use bias in patch embedding.
+ Default: True.
+ num_queries (int): Number of queries for mask proposals.
+ Default: 100.
+ fusion_index (List[int]): The layer number of the encode
+ transformer to fuse with the CLIP feature.
+ Default: [0, 1, 2, 3].
+ cfg_encoder (ConfigType): Configs for the encode layers.
+ cfg_decoder (ConfigType): Configs for the decode layers.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN').
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ clip_channels: int = 768,
+ embed_dims: int = 240,
+ patch_size: int = 16,
+ patch_bias: bool = True,
+ num_queries: int = 100,
+ fusion_index: list = [0, 1, 2, 3],
+ cfg_encoder: ConfigType = ...,
+ cfg_decoder: ConfigType = ...,
+ norm_cfg: dict = dict(type='LN'),
+ ):
+ super().__init__()
+
+ self.patch_embed = PatchEmbed(
+ in_channels=in_channels,
+ embed_dims=embed_dims,
+ conv_type='Conv2d',
+ kernel_size=patch_size,
+ stride=patch_size,
+ padding=0,
+ input_size=(640, 640),
+ bias=patch_bias,
+ norm_cfg=None,
+ init_cfg=None,
+ )
+ ori_h, ori_w = self.patch_embed.init_out_size
+ num_patches = ori_h * ori_w
+ self.pos_embed = nn.Parameter(
+ torch.randn(1, num_patches, embed_dims) * .02)
+ self.query_pos_embed = nn.Parameter(
+ torch.zeros(1, num_queries, embed_dims))
+ self.query_embed = nn.Parameter(
+ torch.zeros(1, num_queries, embed_dims))
+ encode_layers = []
+ for i in range(cfg_encoder.num_encode_layer):
+ encode_layers.append(
+ TransformerEncoderLayer(
+ embed_dims=embed_dims,
+ num_heads=cfg_encoder.num_heads,
+ feedforward_channels=cfg_encoder.mlp_ratio * embed_dims,
+ norm_cfg=norm_cfg))
+ self.encode_layers = nn.ModuleList(encode_layers)
+ conv_clips = []
+ for i in range(len(fusion_index)):
+ conv_clips.append(
+ nn.Sequential(
+ LayerNorm2d(clip_channels),
+ ConvModule(
+ clip_channels,
+ embed_dims,
+ kernel_size=1,
+ norm_cfg=None,
+ act_cfg=None)))
+ self.conv_clips = nn.ModuleList(conv_clips)
+ self.fusion_index = fusion_index
+ self.mask_decoder = MLPMaskDecoder(
+ in_channels=embed_dims,
+ total_heads=cfg_decoder.num_heads,
+ total_layers=cfg_decoder.num_layers,
+ embed_channels=cfg_decoder.embed_channels,
+ mlp_channels=cfg_decoder.mlp_channels,
+ mlp_num_layers=cfg_decoder.num_mlp,
+ rescale_attn_bias=cfg_decoder.rescale)
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.query_embed, std=0.02)
+ nn.init.normal_(self.query_pos_embed, std=0.02)
+ for i in range(len(self.conv_clips)):
+ caffe2_xavier_init(self.conv_clips[i][1].conv)
+
+ def fuse_clip(self, fused_index: int, x: torch.Tensor,
+ clip_feature: torch.Tensor, hwshape: Tuple[int,
+ int], L: int):
+ """Fuse CLIP feature and visual tokens."""
+ fused_clip = (resize(
+ self.conv_clips[fused_index](clip_feature.contiguous()),
+ size=hwshape,
+ mode='bilinear',
+ align_corners=False)).permute(0, 2, 3, 1).reshape(x[:, -L:,
+ ...].shape)
+ x = torch.cat([x[:, :-L, ...], x[:, -L:, ...] + fused_clip], dim=1)
+ return x
+
+ def encode_feature(self, image: torch.Tensor,
+ clip_features: List[torch.Tensor],
+ deep_supervision_idxs: List[int]) -> List[List]:
+ """Encode images by a lightweight vision transformer."""
+ assert len(self.fusion_index) == len(clip_features)
+ x, hwshape = self.patch_embed(image)
+ ori_h, ori_w = self.patch_embed.init_out_size
+ pos_embed = self.pos_embed
+ if self.pos_embed.shape[1] != x.shape[1]:
+ # resize the position embedding
+ pos_embed = (
+ resize(
+ self.pos_embed.reshape(1, ori_h, ori_w,
+ -1).permute(0, 3, 1, 2),
+ size=hwshape,
+ mode='bicubic',
+ align_corners=False,
+ ).flatten(2).permute(0, 2, 1))
+ pos_embed = torch.cat([
+ self.query_pos_embed.expand(pos_embed.shape[0], -1, -1), pos_embed
+ ],
+ dim=1)
+ x = torch.cat([self.query_embed.expand(x.shape[0], -1, -1), x], dim=1)
+ x = x + pos_embed
+ L = hwshape[0] * hwshape[1]
+ fused_index = 0
+ if self.fusion_index[fused_index] == 0:
+ x = self.fuse_clip(fused_index, x, clip_features[0][0], hwshape, L)
+ fused_index += 1
+ outs = []
+ for index, block in enumerate(self.encode_layers, start=1):
+ x = block(x)
+ if index < len(self.fusion_index
+ ) and index == self.fusion_index[fused_index]:
+ x = self.fuse_clip(fused_index, x,
+ clip_features[fused_index][0], hwshape, L)
+ fused_index += 1
+ x_query = x[:, :-L, ...]
+ x_feat = x[:, -L:, ...].permute(0, 2, 1)\
+ .reshape(x.shape[0], x.shape[-1], hwshape[0], hwshape[1])
+
+ if index in deep_supervision_idxs or index == len(
+ self.encode_layers):
+ outs.append({'query': x_query, 'x': x_feat})
+
+ if index < len(self.encode_layers):
+ x = x + pos_embed
+ return outs
+
+ def decode_feature(self, features):
+ mask_embeds = []
+ attn_biases = []
+ for feature in features:
+ mask_embed, attn_bias = self.mask_decoder(**feature)
+ mask_embeds.append(mask_embed)
+ attn_biases.append(attn_bias)
+ return mask_embeds, attn_biases
+
+ def forward(
+ self, image: torch.Tensor, clip_features: List[torch.Tensor],
+ deep_supervision_idxs: List[int]
+ ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
+ """Forward function."""
+ features = self.encode_feature(image, clip_features,
+ deep_supervision_idxs)
+ mask_embeds, attn_biases = self.decode_feature(features)
+ return mask_embeds, attn_biases
+
+
+class RecWithAttnbias(nn.Module):
+ """Mask recognition module by applying the attention biases to rest deeper
+ CLIP layers.
+
+ Args:
+ sos_token_format (str): The format of sos token. It should be
+ chosen from ["cls_token", "learnable_token", "pos_embedding"].
+ Default: 'cls_token'.
+        sos_token_num (int): Number of sos tokens. It should be equal to
+            the number of queries. Default: 100.
+        num_layers (int): Number of remaining CLIP layers used for mask
+            recognition. Default: 3.
+        cross_attn (bool): Whether to use cross attention to update the
+            sos tokens. Default: False.
+ embed_dims (int): The feature dimension of CLIP layers.
+ Default: 768.
+        num_heads (int): Parallel attention heads of CLIP layers.
+            Default: 12.
+ mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
+ Default: 4.
+ qkv_bias (bool): Whether to use bias in multihead-attention.
+ Default: True.
+ out_dims (int): Number of channels of the output mask proposals.
+ It should be equal to the out_dims of text_encoder.
+ Default: 512.
+        final_norm (bool): Whether to use norm layer for sos token.
+            Default: True.
+ act_cfg (dict): The activation config for FFNs.
+ Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN').
+ frozen_exclude (List): List of parameters that are not to be frozen.
+ """
+
+ def __init__(self,
+ sos_token_format: str = 'cls_token',
+ sos_token_num: int = 100,
+ num_layers: int = 3,
+ cross_attn: bool = False,
+ embed_dims: int = 768,
+ num_heads: int = 12,
+ mlp_ratio: int = 4,
+ num_fcs: int = 2,
+ qkv_bias: bool = True,
+ out_dims: int = 512,
+ final_norm: bool = True,
+ act_cfg: dict = dict(type='GELU'),
+ norm_cfg: dict = dict(type='LN'),
+ frozen_exclude: List = []):
+ super().__init__()
+
+ assert sos_token_format in [
+ 'cls_token', 'learnable_token', 'pos_embedding'
+ ]
+ self.sos_token_format = sos_token_format
+ self.sos_token_num = sos_token_num
+        # copy to avoid mutating the shared mutable default argument
+        self.frozen_exclude = list(frozen_exclude)
+ self.cross_attn = cross_attn
+ self.num_layers = num_layers
+ self.num_heads = num_heads
+        if sos_token_format in ['learnable_token', 'pos_embedding']:
+            self.sos_token = nn.Parameter(
+                torch.randn(sos_token_num, 1, embed_dims))
+            self.frozen_exclude.append('sos_token')
+
+ layers = []
+ for i in range(num_layers):
+ layers.append(
+ BaseTransformerLayer(
+ attn_cfgs=dict(
+ type='MultiheadAttention',
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ batch_first=False,
+ bias=qkv_bias),
+ ffn_cfgs=dict(
+ type='FFN',
+ embed_dims=embed_dims,
+ feedforward_channels=mlp_ratio * embed_dims,
+ act_cfg=act_cfg),
+ operation_order=('norm', 'self_attn', 'norm', 'ffn')))
+ self.layers = nn.ModuleList(layers)
+
+ self.ln_post = build_norm_layer(norm_cfg, embed_dims)[1]
+ self.proj = nn.Linear(embed_dims, out_dims, bias=False)
+
+ self.final_norm = final_norm
+ self._freeze()
+
+ def init_weights(self, rec_state_dict):
+ if hasattr(self, 'sos_token'):
+ normal_init(self.sos_token, std=0.02)
+ if rec_state_dict is not None:
+ load_state_dict(self, rec_state_dict, strict=False, logger=None)
+ else:
+ super().init_weights()
+
+ def _freeze(self):
+ if 'all' in self.frozen_exclude:
+ return
+ for name, param in self.named_parameters():
+ if not any([exclude in name for exclude in self.frozen_exclude]):
+ param.requires_grad = False
+
+ def _build_attn_biases(self, attn_biases, target_shape):
+ formatted_attn_biases = []
+ for attn_bias in attn_biases:
+ # convert it to proper format: N*num_head,L,L
+ # attn_bias: [N, num_head/1, num_sos,H,W]
+ n, num_head, num_sos, h, w = attn_bias.shape
+ # reshape and downsample
+ attn_bias = F.adaptive_max_pool2d(
+ attn_bias.reshape(n, num_head * num_sos, h, w),
+ output_size=target_shape)
+ attn_bias = attn_bias.reshape(n, num_head, num_sos, *target_shape)
+
+ true_num_head = self.num_heads
+ assert (num_head == 1 or num_head
+ == true_num_head), f'num_head={num_head} is not supported.'
+ if num_head == 1:
+ attn_bias = attn_bias.repeat(1, true_num_head, 1, 1, 1)
+ attn_bias = attn_bias.reshape(n * true_num_head, num_sos, -1)
+ L = attn_bias.shape[-1]
+ if self.cross_attn:
+ # [n*num_head, num_sos, L]
+ formatted_attn_biases.append(attn_bias)
+ else:
+ # [n*num_head, num_sos+1+L, num_sos+1+L]
+ new_attn_bias = attn_bias.new_zeros(num_sos + 1 + L,
+ num_sos + 1 + L)
+ new_attn_bias[:, :num_sos] = -100
+ new_attn_bias[torch.arange(num_sos), torch.arange(num_sos)] = 0
+ new_attn_bias[:num_sos, num_sos] = -100
+ new_attn_bias = (
+ new_attn_bias[None, ...].expand(n * true_num_head, -1,
+ -1).clone())
+ new_attn_bias[..., :num_sos, -L:] = attn_bias
+ formatted_attn_biases.append(new_attn_bias)
+
+ if len(formatted_attn_biases) == 1:
+ formatted_attn_biases = [
+ formatted_attn_biases[0] for _ in range(self.num_layers)
+ ]
+ return formatted_attn_biases
+
+ def forward(self, bias: List[Tensor], feature: List[Tensor]):
+ """Forward function to recognize the category of masks
+ Args:
+ bias (List[Tensor]): Attention bias for transformer layers
+ feature (List[Tensor]): Output of the image encoder,
+ including cls_token and img_feature.
+ """
+ cls_token = feature[1].unsqueeze(0)
+ img_feature = feature[0]
+ b, c, h, w = img_feature.shape
+ # construct clip shadow features
+ x = torch.cat(
+ [cls_token,
+ img_feature.reshape(b, c, -1).permute(2, 0, 1)])
+
+ # construct sos token
+ if self.sos_token_format == 'cls_token':
+ sos_token = cls_token.repeat(self.sos_token_num, 1, 1)
+ elif self.sos_token_format == 'learnable_token':
+ sos_token = self.sos_token.expand(-1, b, -1)
+ elif self.sos_token_format == 'pos_embedding':
+ sos_token = self.sos_token.expand(-1, b, -1) + cls_token
+
+ # construct attn bias
+ attn_biases = self._build_attn_biases(bias, target_shape=(h, w))
+
+        if self.cross_attn:
+            for i, block in enumerate(self.layers):
+                sos_token = cross_attn_layer(
+                    block,
+                    sos_token,
+                    x[1:],
+                    attn_biases[i],
+                )
+                if i < len(self.layers) - 1:
+                    x = block(x)
+ else:
+ x = torch.cat([sos_token, x], dim=0)
+ for i, block in enumerate(self.layers):
+ x = block(x, attn_masks=[attn_biases[i]])
+ sos_token = x[:self.sos_token_num]
+
+ sos_token = sos_token.permute(1, 0, 2) # LND -> NLD
+ sos_token = self.ln_post(sos_token)
+ sos_token = self.proj(sos_token)
+ if self.final_norm:
+ sos_token = F.normalize(sos_token, dim=-1)
+ return sos_token
+
+
+@MODELS.register_module()
+class SideAdapterCLIPHead(BaseDecodeHead):
+ """Side Adapter Network (SAN) for open-vocabulary semantic segmentation
+ with pre-trained vision-language model.
+
+    This decode head is the implementation of `Side Adapter Network
+    for Open-Vocabulary Semantic Segmentation`.
+ Modified from https://github.com/MendelXu/SAN/blob/main/san/model/side_adapter/side_adapter.py # noqa:E501
+ Copyright (c) 2023 MendelXu.
+ Licensed under the MIT License
+
+ Args:
+        num_classes (int): The number of classes.
+        san_cfg (ConfigType): Configs for SideAdapterNetwork module.
+        maskgen_cfg (ConfigType): Configs for RecWithAttnbias module.
+        deep_supervision_idxs (List[int]): Layer indices at which deep
+            supervision losses are computed during training.
+        train_cfg (ConfigType): Training config.
+ """
+
+ def __init__(self, num_classes: int, san_cfg: ConfigType,
+ maskgen_cfg: ConfigType, deep_supervision_idxs: List[int],
+ train_cfg: ConfigType, **kwargs):
+ super().__init__(
+ in_channels=san_cfg.in_channels,
+ channels=san_cfg.embed_dims,
+ num_classes=num_classes,
+ **kwargs)
+ assert san_cfg.num_queries == maskgen_cfg.sos_token_num, \
+ 'num_queries in san_cfg should be equal to sos_token_num ' \
+ 'in maskgen_cfg'
+ del self.conv_seg
+ self.side_adapter_network = SideAdapterNetwork(**san_cfg)
+ self.rec_with_attnbias = RecWithAttnbias(**maskgen_cfg)
+ self.deep_supervision_idxs = deep_supervision_idxs
+ self.train_cfg = train_cfg
+ if train_cfg:
+ self.match_masks = MatchMasks(
+ num_points=train_cfg.num_points,
+ num_queries=san_cfg.num_queries,
+ num_classes=num_classes,
+ assigner=train_cfg.assigner)
+
+ def init_weights(self):
+
+ rec_state_dict = None
+ if isinstance(self.init_cfg, dict) and \
+ self.init_cfg.get('type') == 'Pretrained_Part':
+ checkpoint = CheckpointLoader.load_checkpoint(
+ self.init_cfg['checkpoint'], logger=None, map_location='cpu')
+
+ rec_state_dict = checkpoint.copy()
+ para_prefix = 'decode_head.rec_with_attnbias'
+ prefix_len = len(para_prefix) + 1
+ for k, v in checkpoint.items():
+ rec_state_dict.pop(k)
+ if para_prefix in k:
+ rec_state_dict[k[prefix_len:]] = v
+
+ self.side_adapter_network.init_weights()
+ self.rec_with_attnbias.init_weights(rec_state_dict)
+
+ def forward(self, inputs: Tuple[Tensor],
+ deep_supervision_idxs) -> Tuple[List]:
+ """Forward function.
+
+ Args:
+ inputs (Tuple[Tensor]): A triplet including images,
+ list of multi-level visual features from image encoder and
+ class embeddings from text_encoder.
+
+ Returns:
+ mask_props (List[Tensor]): Mask proposals predicted by SAN.
+ mask_logits (List[Tensor]): Class logits of mask proposals.
+ """
+ imgs, clip_feature, class_embeds = inputs
+ # predict mask proposals and attention bias
+ mask_props, attn_biases = self.side_adapter_network(
+ imgs, clip_feature, deep_supervision_idxs)
+
+ # mask recognition with attention bias
+ mask_embeds = [
+ self.rec_with_attnbias(att_bias, clip_feature[-1])
+ for att_bias in attn_biases
+ ]
+ # Obtain class prediction of masks by comparing the similarity
+ # between the image token and the text embedding of class names.
+ mask_logits = [
+ torch.einsum('bqc,nc->bqn', mask_embed, class_embeds)
+ for mask_embed in mask_embeds
+ ]
+ return mask_props, mask_logits
+
+ def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
+ test_cfg: ConfigType) -> Tensor:
+ """Forward function for prediction.
+
+ Args:
+ inputs (Tuple[Tensor]): Images, visual features from image encoder
+ and class embedding from text encoder.
+            batch_img_metas (List[dict]): List of image info dicts, where
+                each dict may contain: 'img_shape', 'scale_factor', 'flip',
+                'img_path', 'ori_shape', and 'pad_shape'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
+ test_cfg (dict): The testing config.
+
+ Returns:
+ Tensor: Outputs segmentation logits map.
+ """
+ mask_props, mask_logits = self.forward(inputs, [])
+
+ return self.predict_by_feat([mask_props[-1], mask_logits[-1]],
+ batch_img_metas)
+
+ def predict_by_feat(self, seg_logits: List[Tensor],
+ batch_img_metas: List[dict]) -> Tensor:
+ """1. Transform a batch of mask proposals to the input shape.
+ 2. Generate segmentation map with mask proposals and class logits.
+ """
+ mask_pred = seg_logits[0]
+ cls_score = seg_logits[1]
+ if 'pad_shape' in batch_img_metas[0]:
+ size = batch_img_metas[0]['pad_shape']
+ else:
+ size = batch_img_metas[0]['img_shape']
+ # upsample mask
+ mask_pred = F.interpolate(
+ mask_pred, size=size, mode='bilinear', align_corners=False)
+
+ mask_cls = F.softmax(cls_score, dim=-1)[..., :-1]
+ mask_pred = mask_pred.sigmoid()
+ seg_logits = torch.einsum('bqc,bqhw->bchw', mask_cls, mask_pred)
+ return seg_logits
+
+ def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
+ train_cfg: ConfigType) -> dict:
+ """Perform forward propagation and loss calculation of the decoder head
+ on the features of the upstream network.
+
+ Args:
+ x (tuple[Tensor]): Multi-level features from the upstream
+ network, each is a 4D-tensor.
+ batch_data_samples (List[:obj:`SegDataSample`]): The Data
+ Samples. It usually includes information such as
+ `gt_sem_seg`.
+ train_cfg (ConfigType): Training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components.
+ """
+ # batch SegDataSample to InstanceDataSample
+ batch_gt_instances = seg_data_to_instance_data(self.ignore_index,
+ batch_data_samples)
+
+ # forward
+ all_mask_props, all_mask_logits = self.forward(
+ x, self.deep_supervision_idxs)
+
+ # loss
+ losses = self.loss_by_feat(all_mask_logits, all_mask_props,
+ batch_gt_instances)
+
+ return losses
+
+ def loss_by_feat(
+ self, all_cls_scores: Tensor, all_mask_preds: Tensor,
+ batch_gt_instances: List[InstanceData]) -> Dict[str, Tensor]:
+ """Loss function.
+
+ Args:
+ all_cls_scores (Tensor): Classification scores for all decoder
+ layers with shape (num_decoder, batch_size, num_queries,
+                cls_out_channels). Note `cls_out_channels` should include
+ background.
+ all_mask_preds (Tensor): Mask scores for all decoder layers with
+ shape (num_decoder, batch_size, num_queries, h, w).
+            batch_gt_instances (list[:obj:`InstanceData`]): each contains
+ ``labels`` and ``masks``.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ num_dec_layers = len(all_cls_scores)
+ batch_gt_instances_list = [
+ batch_gt_instances for _ in range(num_dec_layers)
+ ]
+
+ losses = []
+ for i in range(num_dec_layers):
+ cls_scores = all_cls_scores[i]
+ mask_preds = all_mask_preds[i]
+ # matching N mask predictions to K category labels
+ (labels, mask_targets, mask_weights,
+ avg_factor) = self.match_masks.get_targets(
+ cls_scores, mask_preds, batch_gt_instances_list[i])
+ cls_scores = cls_scores.flatten(0, 1)
+ labels = labels.flatten(0, 1)
+ num_total_masks = cls_scores.new_tensor([avg_factor],
+ dtype=torch.float)
+ all_reduce(num_total_masks, op='mean')
+ num_total_masks = max(num_total_masks, 1)
+
+ # extract positive ones
+ # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
+ mask_preds = mask_preds[mask_weights > 0]
+
+ if mask_targets.shape[0] != 0:
+ with torch.no_grad():
+ points_coords = get_uncertain_point_coords_with_randomness(
+ mask_preds.unsqueeze(1), None,
+ self.train_cfg.num_points,
+ self.train_cfg.oversample_ratio,
+ self.train_cfg.importance_sample_ratio)
+ # shape (num_total_gts, h, w)
+ # -> (num_total_gts, num_points)
+ mask_point_targets = point_sample(
+ mask_targets.unsqueeze(1).float(),
+ points_coords).squeeze(1)
+ # shape (num_queries, h, w) -> (num_queries, num_points)
+ mask_point_preds = point_sample(
+ mask_preds.unsqueeze(1), points_coords).squeeze(1)
+
+ if not isinstance(self.loss_decode, nn.ModuleList):
+ losses_decode = [self.loss_decode]
+ else:
+ losses_decode = self.loss_decode
+ loss = dict()
+ for loss_decode in losses_decode:
+ if 'loss_cls' in loss_decode.loss_name:
+ if loss_decode.loss_name == 'loss_cls_ce':
+ loss[loss_decode.loss_name] = loss_decode(
+ cls_scores, labels)
+ else:
+ assert False, "Only support 'CrossEntropyLoss' in" \
+ ' classification loss'
+
+ elif 'loss_mask' in loss_decode.loss_name:
+ if mask_targets.shape[0] == 0:
+ loss[loss_decode.loss_name] = mask_preds.sum()
+ elif loss_decode.loss_name == 'loss_mask_ce':
+ loss[loss_decode.loss_name] = loss_decode(
+ mask_point_preds,
+ mask_point_targets,
+ avg_factor=num_total_masks *
+ self.train_cfg.num_points)
+ elif loss_decode.loss_name == 'loss_mask_dice':
+ loss[loss_decode.loss_name] = loss_decode(
+ mask_point_preds,
+ mask_point_targets,
+ avg_factor=num_total_masks)
+ else:
+ assert False, "Only support 'CrossEntropyLoss' and" \
+ " 'DiceLoss' in mask loss"
+ else:
+ assert False, "Only support for 'loss_cls' and 'loss_mask'"
+
+ losses.append(loss)
+
+ loss_dict = dict()
+ # loss from the last decoder layer
+ loss_dict.update(losses[-1])
+ # loss from other decoder layers
+ for i, loss in enumerate(losses[:-1]):
+ for k, v in loss.items():
+ loss_dict[f'd{self.deep_supervision_idxs[i]}.{k}'] = v
+ return loss_dict
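The two einsum contractions used above (class logits in `forward`, and the segmentation map composition in `predict_by_feat`) can be sanity-checked on toy shapes; all sizes below are made up for illustration:

```python
import torch

# Hypothetical sizes: batch b, queries q, embed dim c, classes n, map h x w.
b, q, c, n, h, w = 2, 5, 16, 3, 8, 8

mask_embeds = torch.randn(b, q, c)    # one embedding per mask query
class_embeds = torch.randn(n, c)      # one text embedding per class name
mask_pred = torch.randn(b, q, h, w)   # mask proposal logits

# Class logits: similarity between each query embedding and each class.
mask_logits = torch.einsum('bqc,nc->bqn', mask_embeds, class_embeds)
assert mask_logits.shape == (b, q, n)

# Combine per-query class scores with per-query masks into a seg map.
mask_cls = mask_logits.softmax(dim=-1)
seg_logits = torch.einsum('bqn,bqhw->bnhw', mask_cls, mask_pred.sigmoid())
assert seg_logits.shape == (b, n, h, w)
```

Each output pixel is a class-score-weighted sum over the query masks covering it, which is how per-query predictions become a dense semantic map.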
diff --git a/mmseg/models/decode_heads/vpd_depth_head.py b/mmseg/models/decode_heads/vpd_depth_head.py
new file mode 100644
index 00000000000..0c54c2da1b1
--- /dev/null
+++ b/mmseg/models/decode_heads/vpd_depth_head.py
@@ -0,0 +1,254 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Sequence, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
+from mmengine.model import BaseModule
+from torch import Tensor
+
+from mmseg.registry import MODELS
+from mmseg.utils import SampleList
+from ..builder import build_loss
+from ..utils import resize
+from .decode_head import BaseDecodeHead
+
+
+class VPDDepthDecoder(BaseModule):
+ """VPD Depth Decoder class.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ num_deconv_layers (int): Number of deconvolution layers.
+ num_deconv_filters (List[int]): List of output channels for
+ deconvolution layers.
+ init_cfg (Optional[Union[Dict, List[Dict]]], optional): Configuration
+ for weight initialization. Defaults to Normal for Conv2d and
+ ConvTranspose2d layers.
+ """
+
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ num_deconv_layers: int,
+ num_deconv_filters: List[int],
+ init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
+ type='Normal',
+ std=0.001,
+ layer=['Conv2d', 'ConvTranspose2d'])):
+ super().__init__(init_cfg=init_cfg)
+ self.in_channels = in_channels
+
+ self.deconv_layers = self._make_deconv_layer(
+ num_deconv_layers,
+ num_deconv_filters,
+ )
+
+ conv_layers = []
+ conv_layers.append(
+ build_conv_layer(
+ dict(type='Conv2d'),
+ in_channels=num_deconv_filters[-1],
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1))
+ conv_layers.append(build_norm_layer(dict(type='BN'), out_channels)[1])
+ conv_layers.append(nn.ReLU(inplace=True))
+ self.conv_layers = nn.Sequential(*conv_layers)
+
+ self.up_sample = nn.Upsample(
+ scale_factor=2, mode='bilinear', align_corners=False)
+
+ def forward(self, x):
+ """Forward pass through the decoder network."""
+ out = self.deconv_layers(x)
+ out = self.conv_layers(out)
+
+ out = self.up_sample(out)
+ out = self.up_sample(out)
+
+ return out
+
+ def _make_deconv_layer(self, num_layers, num_deconv_filters):
+ """Make deconv layers."""
+
+ layers = []
+ in_channels = self.in_channels
+ for i in range(num_layers):
+
+ num_channels = num_deconv_filters[i]
+ layers.append(
+ build_upsample_layer(
+ dict(type='deconv'),
+ in_channels=in_channels,
+ out_channels=num_channels,
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ output_padding=0,
+ bias=False))
+ layers.append(nn.BatchNorm2d(num_channels))
+ layers.append(nn.ReLU(inplace=True))
+ in_channels = num_channels
+
+ return nn.Sequential(*layers)
+
+
+@MODELS.register_module()
+class VPDDepthHead(BaseDecodeHead):
+ """Depth Prediction Head for VPD.
+
+ .. _`VPD`: https://arxiv.org/abs/2303.02153
+
+ Args:
+ max_depth (float): Maximum depth value. Defaults to 10.0.
+ in_channels (Sequence[int]): Number of input channels for each
+ convolutional layer.
+ embed_dim (int): Dimension of embedding. Defaults to 192.
+ feature_dim (int): Dimension of aggregated feature. Defaults to 1536.
+ num_deconv_layers (int): Number of deconvolution layers in the
+ decoder. Defaults to 3.
+ num_deconv_filters (Sequence[int]): Number of filters for each deconv
+ layer. Defaults to (32, 32, 32).
+ fmap_border (Union[int, Sequence[int]]): Feature map border for
+ cropping. Defaults to 0.
+ align_corners (bool): Flag for align_corners in interpolation.
+ Defaults to False.
+ loss_decode (dict): Configurations for the loss function. Defaults to
+ dict(type='SiLogLoss').
+ init_cfg (dict): Initialization configurations. Defaults to
+ dict(type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']).
+ """
+
+ num_classes = 1
+ out_channels = 1
+ input_transform = None
+
+ def __init__(
+ self,
+ max_depth: float = 10.0,
+ in_channels: Sequence[int] = [320, 640, 1280, 1280],
+ embed_dim: int = 192,
+ feature_dim: int = 1536,
+ num_deconv_layers: int = 3,
+ num_deconv_filters: Sequence[int] = (32, 32, 32),
+ fmap_border: Union[int, Sequence[int]] = 0,
+ align_corners: bool = False,
+ loss_decode: dict = dict(type='SiLogLoss'),
+ init_cfg=dict(
+ type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']),
+ ):
+
+ super(BaseDecodeHead, self).__init__(init_cfg=init_cfg)
+
+ # initialize parameters
+ self.in_channels = in_channels
+ self.max_depth = max_depth
+ self.align_corners = align_corners
+
+ # feature map border
+ if isinstance(fmap_border, int):
+ fmap_border = (fmap_border, fmap_border)
+ self.fmap_border = fmap_border
+
+ # define network layers
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1),
+ nn.GroupNorm(16, in_channels[0]),
+ nn.ReLU(),
+ nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1),
+ )
+ self.conv2 = nn.Conv2d(
+ in_channels[1], in_channels[1], 3, stride=2, padding=1)
+
+ self.conv_aggregation = nn.Sequential(
+ nn.Conv2d(sum(in_channels), feature_dim, 1),
+ nn.GroupNorm(16, feature_dim),
+ nn.ReLU(),
+ )
+
+ self.decoder = VPDDepthDecoder(
+ in_channels=embed_dim * 8,
+ out_channels=embed_dim,
+ num_deconv_layers=num_deconv_layers,
+ num_deconv_filters=num_deconv_filters)
+
+ self.depth_pred_layer = nn.Sequential(
+ nn.Conv2d(
+ embed_dim, embed_dim, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(inplace=False),
+ nn.Conv2d(embed_dim, 1, kernel_size=3, stride=1, padding=1))
+
+ # build loss
+ if isinstance(loss_decode, dict):
+ self.loss_decode = build_loss(loss_decode)
+ elif isinstance(loss_decode, (list, tuple)):
+ self.loss_decode = nn.ModuleList()
+ for loss in loss_decode:
+ self.loss_decode.append(build_loss(loss))
+ else:
+            raise TypeError('loss_decode must be a dict or sequence of '
+                            f'dict, but got {type(loss_decode)}')
+
+ def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
+ gt_depth_maps = [
+ data_sample.gt_depth_map.data for data_sample in batch_data_samples
+ ]
+ return torch.stack(gt_depth_maps, dim=0)
+
+ def forward(self, x):
+ x = [
+ x[0], x[1],
+ torch.cat([x[2], F.interpolate(x[3], scale_factor=2)], dim=1)
+ ]
+ x = torch.cat([self.conv1(x[0]), self.conv2(x[1]), x[2]], dim=1)
+ x = self.conv_aggregation(x)
+
+ x = x[:, :, :x.size(2) - self.fmap_border[0], :x.size(3) -
+ self.fmap_border[1]].contiguous()
+ x = self.decoder(x)
+ out = self.depth_pred_layer(x)
+
+ depth = torch.sigmoid(out) * self.max_depth
+
+ return depth
+
+ def loss_by_feat(self, pred_depth_map: Tensor,
+ batch_data_samples: SampleList) -> dict:
+ """Compute depth estimation loss.
+
+ Args:
+ pred_depth_map (Tensor): The output from decode head forward
+ function.
+ batch_data_samples (List[:obj:`SegDataSample`]): The seg
+ data samples. It usually includes information such
+                as `metainfo` and `gt_depth_map`.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+
+ gt_depth_map = self._stack_batch_gt(batch_data_samples)
+ loss = dict()
+ pred_depth_map = resize(
+ input=pred_depth_map,
+ size=gt_depth_map.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+
+ if not isinstance(self.loss_decode, nn.ModuleList):
+ losses_decode = [self.loss_decode]
+ else:
+ losses_decode = self.loss_decode
+ for loss_decode in losses_decode:
+ if loss_decode.loss_name not in loss:
+ loss[loss_decode.loss_name] = loss_decode(
+ pred_depth_map, gt_depth_map)
+ else:
+ loss[loss_decode.loss_name] += loss_decode(
+ pred_depth_map, gt_depth_map)
+
+ return loss
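The decoder above doubles the spatial resolution once per deconv layer and twice more via `up_sample`, and the head bounds depth to `(0, max_depth)` with a sigmoid. A minimal sketch of that arithmetic, with assumed toy channel counts (not the real VPD configuration):

```python
import torch
import torch.nn as nn

# Assumed toy sizes for illustration only.
embed_dim, n_deconv = 8, 3
layers = []
in_ch = embed_dim
for out_ch in [4] * n_deconv:
    # kernel 2, stride 2 -> each deconv exactly doubles H and W
    layers += [nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2, bias=False),
               nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)]
    in_ch = out_ch
deconv = nn.Sequential(*layers)
up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

x = torch.randn(1, embed_dim, 4, 4)
y = up(up(deconv(x)))  # 3 deconvs (x2 each) + two extra x2 upsamples
assert y.shape == (1, 4, 4 * 2 ** (n_deconv + 2))[:3] + (128,)

# sigmoid bounds the raw output to (0, 1); scaling bounds it to (0, max_depth)
max_depth = 10.0
depth = torch.sigmoid(torch.randn(1, 1, 4, 4)) * max_depth
assert float(depth.min()) > 0.0 and float(depth.max()) < max_depth
```

So with `num_deconv_layers=3` the decoder upsamples its input by a total factor of 2^(3+2) = 32.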
diff --git a/mmseg/models/losses/__init__.py b/mmseg/models/losses/__init__.py
index 2f7e39cb28b..0467cb3ad89 100644
--- a/mmseg/models/losses/__init__.py
+++ b/mmseg/models/losses/__init__.py
@@ -5,8 +5,10 @@
cross_entropy, mask_cross_entropy)
from .dice_loss import DiceLoss
from .focal_loss import FocalLoss
+from .huasdorff_distance_loss import HuasdorffDisstanceLoss
from .lovasz_loss import LovaszLoss
from .ohem_cross_entropy_loss import OhemCrossEntropy
+from .silog_loss import SiLogLoss
from .tversky_loss import TverskyLoss
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
@@ -14,5 +16,6 @@
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
- 'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss'
+ 'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss',
+ 'HuasdorffDisstanceLoss', 'SiLogLoss'
]
diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py
index 770b9974861..65553472c0f 100644
--- a/mmseg/models/losses/cross_entropy_loss.py
+++ b/mmseg/models/losses/cross_entropy_loss.py
@@ -53,8 +53,22 @@ def cross_entropy(pred,
# average loss over non-ignored elements
# pytorch's official cross_entropy average loss over non-ignored elements
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
- if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
- avg_factor = label.numel() - (label == ignore_index).sum().item()
+ if (avg_factor is None) and reduction == 'mean':
+ if class_weight is None:
+ if avg_non_ignore:
+ avg_factor = label.numel() - (label
+ == ignore_index).sum().item()
+ else:
+ avg_factor = label.numel()
+
+ else:
+ # the average factor should take the class weights into account
+ label_weights = torch.tensor([class_weight[cls] for cls in label],
+ device=class_weight.device)
+ if avg_non_ignore:
+ label_weights[label == ignore_index] = 0
+ avg_factor = label_weights.sum()
+
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(
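The class-weighted `avg_factor` computed in this hunk mirrors how PyTorch's own weighted cross-entropy averages: under `reduction='mean'` with a `weight=` argument, the summed loss is divided by the sum of the per-element class weights, not by the element count. A toy check (all values arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pred = torch.randn(6, 3)                       # 6 samples, 3 classes
label = torch.tensor([0, 1, 2, 0, 1, 2])
class_weight = torch.tensor([1.0, 2.0, 0.5])

# Per-element losses already carry the class weight.
per_elem = F.cross_entropy(pred, label, weight=class_weight,
                           reduction='none')

# Average factor that takes the class weights into account.
avg_factor = class_weight[label].sum()
manual_mean = per_elem.sum() / avg_factor

builtin_mean = F.cross_entropy(pred, label, weight=class_weight,
                               reduction='mean')
assert torch.allclose(manual_mean, builtin_mean)
```

This is why the patched branch sums `label_weights` instead of counting elements when `class_weight` is given.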
@@ -124,7 +138,7 @@ def binary_cross_entropy(pred,
assert label[label != ignore_index].max() <= 1, \
'For pred with shape [N, 1, H, W], its label must have at ' \
'most 2 classes'
- pred = pred.squeeze()
+ pred = pred.squeeze(1)
if pred.dim() != label.dim():
assert (pred.dim() == 2 and label.dim() == 1) or (
pred.dim() == 4 and label.dim() == 3), \
diff --git a/mmseg/models/losses/dice_loss.py b/mmseg/models/losses/dice_loss.py
index 2ee89a81f4e..fb2ffdba8da 100644
--- a/mmseg/models/losses/dice_loss.py
+++ b/mmseg/models/losses/dice_loss.py
@@ -1,125 +1,190 @@
# Copyright (c) OpenMMLab. All rights reserved.
-"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/
-segmentron/solver/loss.py (Apache-2.0 License)"""
+from typing import Union
+
import torch
import torch.nn as nn
-import torch.nn.functional as F
from mmseg.registry import MODELS
-from .utils import get_class_weight, weighted_loss
-
-
-@weighted_loss
-def dice_loss(pred,
- target,
- valid_mask,
- smooth=1,
- exponent=2,
- class_weight=None,
- ignore_index=255):
- assert pred.shape[0] == target.shape[0]
- total_loss = 0
- num_classes = pred.shape[1]
- for i in range(num_classes):
- if i != ignore_index:
- dice_loss = binary_dice_loss(
- pred[:, i],
- target[..., i],
- valid_mask=valid_mask,
- smooth=smooth,
- exponent=exponent)
- if class_weight is not None:
- dice_loss *= class_weight[i]
- total_loss += dice_loss
- return total_loss / num_classes
+from .utils import weight_reduce_loss
-@weighted_loss
-def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
- assert pred.shape[0] == target.shape[0]
- pred = pred.reshape(pred.shape[0], -1)
- target = target.reshape(target.shape[0], -1)
- valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
+def _expand_onehot_labels_dice(pred: torch.Tensor,
+ target: torch.Tensor) -> torch.Tensor:
+ """Expand onehot labels to match the size of prediction.
- num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
- den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth
+ Args:
+ pred (torch.Tensor): The prediction, has a shape (N, num_class, H, W).
+ target (torch.Tensor): The learning label of the prediction,
+ has a shape (N, H, W).
- return 1 - num / den
+ Returns:
+ torch.Tensor: The target after one-hot encoding,
+ has a shape (N, num_class, H, W).
+ """
+ num_classes = pred.shape[1]
+ one_hot_target = torch.clamp(target, min=0, max=num_classes)
+ one_hot_target = torch.nn.functional.one_hot(one_hot_target,
+ num_classes + 1)
+ one_hot_target = one_hot_target[..., :num_classes].permute(0, 3, 1, 2)
+ return one_hot_target
+
+
+def dice_loss(pred: torch.Tensor,
+ target: torch.Tensor,
+ weight: Union[torch.Tensor, None],
+ eps: float = 1e-3,
+ reduction: Union[str, None] = 'mean',
+ naive_dice: Union[bool, None] = False,
+ avg_factor: Union[int, None] = None,
+ ignore_index: Union[int, None] = 255) -> float:
+    """Calculate dice loss. Two forms of dice loss are supported:
+
+    - the one proposed in `V-Net: Fully Convolutional Neural
+      Networks for Volumetric Medical Image Segmentation`_.
+ - the dice loss in which the power of the number in the
+ denominator is the first power instead of the second
+ power.
+
+ Args:
+ pred (torch.Tensor): The prediction, has a shape (n, *)
+ target (torch.Tensor): The learning label of the prediction,
+        shape (n, *), the same shape as pred.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction, has a shape (n,). Defaults to None.
+ eps (float): Avoid dividing by zero. Default: 1e-3.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'.
+ Options are "none", "mean" and "sum".
+        naive_dice (bool, optional): If False, use the dice loss defined
+            in the V-Net paper; otherwise, use the naive dice loss, in
+            which the denominator uses first powers instead of squared
+            terms. Defaults to False.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ ignore_index (int, optional): The label index to be ignored.
+ Defaults to 255.
+ """
+ if ignore_index is not None:
+ num_classes = pred.shape[1]
+ pred = pred[:, torch.arange(num_classes) != ignore_index, :, :]
+ target = target[:, torch.arange(num_classes) != ignore_index, :, :]
+ assert pred.shape[1] != 0 # if the ignored index is the only class
+ input = pred.flatten(1)
+ target = target.flatten(1).float()
+ a = torch.sum(input * target, 1)
+ if naive_dice:
+ b = torch.sum(input, 1)
+ c = torch.sum(target, 1)
+ d = (2 * a + eps) / (b + c + eps)
+ else:
+ b = torch.sum(input * input, 1) + eps
+ c = torch.sum(target * target, 1) + eps
+ d = (2 * a) / (b + c)
+
+ loss = 1 - d
+ if weight is not None:
+ assert weight.ndim == loss.ndim
+ assert len(weight) == len(pred)
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
@MODELS.register_module()
class DiceLoss(nn.Module):
- """DiceLoss.
-
- This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
- Volumetric Medical Image Segmentation `_.
-
- Args:
- smooth (float): A float number to smooth loss, and avoid NaN error.
- Default: 1
- exponent (float): An float number to calculate denominator
- value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
- reduction (str, optional): The method used to reduce the loss. Options
- are "none", "mean" and "sum". This parameter only works when
- per_image is True. Default: 'mean'.
- class_weight (list[float] | str, optional): Weight of each class. If in
- str format, read them from a file. Defaults to None.
- loss_weight (float, optional): Weight of the loss. Default to 1.0.
- ignore_index (int | None): The label index to be ignored. Default: 255.
- loss_name (str, optional): Name of the loss item. If you want this loss
- item to be included into the backward graph, `loss_` must be the
- prefix of the name. Defaults to 'loss_dice'.
- """
def __init__(self,
- smooth=1,
- exponent=2,
+ use_sigmoid=True,
+ activate=True,
reduction='mean',
- class_weight=None,
+ naive_dice=False,
loss_weight=1.0,
ignore_index=255,
- loss_name='loss_dice',
- **kwards):
+ eps=1e-3,
+ loss_name='loss_dice'):
+ """Compute dice loss.
+
+ Args:
+            use_sigmoid (bool, optional): Whether the prediction is
+                activated by sigmoid instead of softmax. Defaults to True.
+ activate (bool): Whether to activate the predictions inside,
+ this will disable the inside sigmoid operation.
+ Defaults to True.
+ reduction (str, optional): The method used
+ to reduce the loss. Options are "none",
+ "mean" and "sum". Defaults to 'mean'.
+ naive_dice (bool, optional): If false, use the dice
+ loss defined in the V-Net paper, otherwise, use the
+ naive dice loss in which the power of the number in the
+ denominator is the first power instead of the second
+ power. Defaults to False.
+ loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+ ignore_index (int, optional): The label index to be ignored.
+ Default: 255.
+ eps (float): Avoid dividing by zero. Defaults to 1e-3.
+ loss_name (str, optional): Name of the loss item. If you want this
+ loss item to be included into the backward graph, `loss_` must
+ be the prefix of the name. Defaults to 'loss_dice'.
+ """
+
super().__init__()
- self.smooth = smooth
- self.exponent = exponent
+ self.use_sigmoid = use_sigmoid
self.reduction = reduction
- self.class_weight = get_class_weight(class_weight)
+ self.naive_dice = naive_dice
self.loss_weight = loss_weight
+ self.eps = eps
+ self.activate = activate
self.ignore_index = ignore_index
self._loss_name = loss_name
def forward(self,
pred,
target,
+ weight=None,
avg_factor=None,
reduction_override=None,
- **kwards):
+ ignore_index=255,
+ **kwargs):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction, has a shape (n, *).
+ target (torch.Tensor): The label of the prediction,
+                shape (n, *), the same shape as pred.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction, has a shape (n,). Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Options are "none", "mean" and "sum".
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ one_hot_target = target
+ if (pred.shape != target.shape):
+ one_hot_target = _expand_onehot_labels_dice(pred, target)
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
- if self.class_weight is not None:
- class_weight = pred.new_tensor(self.class_weight)
- else:
- class_weight = None
-
- pred = F.softmax(pred, dim=1)
- num_classes = pred.shape[1]
- one_hot_target = F.one_hot(
- torch.clamp(target.long(), 0, num_classes - 1),
- num_classes=num_classes)
- valid_mask = (target != self.ignore_index).long()
-
+ if self.activate:
+ if self.use_sigmoid:
+ pred = pred.sigmoid()
+ elif pred.shape[1] != 1:
+ # softmax does not work when there is only 1 class
+ pred = pred.softmax(dim=1)
loss = self.loss_weight * dice_loss(
pred,
one_hot_target,
- valid_mask=valid_mask,
+ weight,
+ eps=self.eps,
reduction=reduction,
+ naive_dice=self.naive_dice,
avg_factor=avg_factor,
- smooth=self.smooth,
- exponent=self.exponent,
- class_weight=class_weight,
ignore_index=self.ignore_index)
+
return loss
@property
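The two dice variants distinguished by `naive_dice` above can be checked numerically on flattened predictions; the values below are made up for illustration and follow the same `eps` placement as `dice_loss`:

```python
import torch

eps = 1e-3
pred = torch.tensor([[0.9, 0.1, 0.8, 0.2]])    # toy sigmoid outputs
target = torch.tensor([[1.0, 0.0, 1.0, 0.0]])  # toy binary targets

a = (pred * target).sum(1)
# V-Net form: squared terms in the denominator, eps added to each sum.
vnet = 1 - 2 * a / (pred.pow(2).sum(1) + eps + target.pow(2).sum(1) + eps)
# Naive form: first-power denominator, eps in numerator and denominator.
naive = 1 - (2 * a + eps) / (pred.sum(1) + target.sum(1) + eps)

# Both are small for this near-perfect toy prediction.
assert 0 < vnet.item() < 1 and 0 < naive.item() < 1
```

With confident predictions the squared denominator shrinks faster than the first-power one, so the V-Net form rewards confident overlap more strongly here.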
diff --git a/mmseg/models/losses/focal_loss.py b/mmseg/models/losses/focal_loss.py
index 104d6602c80..6507ed7a911 100644
--- a/mmseg/models/losses/focal_loss.py
+++ b/mmseg/models/losses/focal_loss.py
@@ -271,7 +271,13 @@ def forward(self,
num_classes = pred.size(1)
if torch.cuda.is_available() and pred.is_cuda:
if target.dim() == 1:
- one_hot_target = F.one_hot(target, num_classes=num_classes)
+ one_hot_target = F.one_hot(
+ target, num_classes=num_classes + 1)
+ if num_classes == 1:
+ one_hot_target = one_hot_target[:, 1]
+ target = 1 - target
+ else:
+ one_hot_target = one_hot_target[:, :num_classes]
else:
one_hot_target = target
target = target.argmax(dim=1)
@@ -280,7 +286,11 @@ def forward(self,
else:
one_hot_target = None
if target.dim() == 1:
- target = F.one_hot(target, num_classes=num_classes)
+ target = F.one_hot(target, num_classes=num_classes + 1)
+ if num_classes == 1:
+ target = target[:, 1]
+ else:
+                    target = target[:, :num_classes]
else:
valid_mask = (target.argmax(dim=1) != ignore_index).view(
-1, 1)
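The `num_classes + 1` one-hot trick in this hunk exists because `F.one_hot(target, num_classes=1)` raises for label 1 in the binary case; encoding with one extra class and then slicing recovers the intended columns. A toy illustration:

```python
import torch
import torch.nn.functional as F

# Binary case (num_classes == 1): a plain one_hot over 1 class cannot
# represent label 1, so encode over 2 classes and take the foreground column.
target = torch.tensor([0, 1, 1, 0])
num_classes = 1
oh = F.one_hot(target, num_classes=num_classes + 1)
pos = oh[:, 1]                       # foreground-class column
assert torch.equal(pos, target)

# Multi-class case: slice back to the first num_classes columns.
num_classes = 3
target3 = torch.tensor([0, 2, 1])
oh3 = F.one_hot(target3, num_classes=num_classes + 1)[:, :num_classes]
assert oh3.shape == (3, 3)
```

Slicing with `[:, :num_classes]` (rather than indexing a single column) is what the multi-class branch needs, which is the bug fixed above.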
diff --git a/mmseg/models/losses/huasdorff_distance_loss.py b/mmseg/models/losses/huasdorff_distance_loss.py
new file mode 100644
index 00000000000..d950ba728f8
--- /dev/null
+++ b/mmseg/models/losses/huasdorff_distance_loss.py
@@ -0,0 +1,160 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Modified from https://github.com/JunMa11/SegWithDistMap/blob/
+master/code/train_LA_HD.py (Apache-2.0 License)"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from scipy.ndimage import distance_transform_edt as distance
+from torch import Tensor
+
+from mmseg.registry import MODELS
+from .utils import get_class_weight, weighted_loss
+
+
+def compute_dtm(img_gt: Tensor, pred: Tensor) -> Tensor:
+    """Compute the distance transform map of the foreground in the mask.
+
+    Args:
+        img_gt: Ground truth of the image, (b, h, w)
+        pred: Predictions of the segmentation head after softmax, (b, c, h, w)
+
+    Returns:
+        output: the foreground distance transform map, where
+            dtm(x) = 0 if x lies on the segmentation boundary, and
+            dtm(x) = inf_{y in boundary} |x - y| if x lies inside the
+            segmentation.
+    """
+
+ fg_dtm = torch.zeros_like(pred)
+ out_shape = pred.shape
+ for b in range(out_shape[0]): # batch size
+ for c in range(1, out_shape[1]): # default 0 channel is background
+ posmask = img_gt[b].byte()
+ if posmask.any():
+ posdis = distance(posmask)
+ fg_dtm[b][c] = torch.from_numpy(posdis)
+
+ return fg_dtm
+
+
+@weighted_loss
+def hd_loss(seg_soft: Tensor,
+ gt: Tensor,
+ seg_dtm: Tensor,
+ gt_dtm: Tensor,
+ class_weight=None,
+ ignore_index=255) -> Tensor:
+    """Compute Hausdorff distance loss for segmentation.
+
+    Args:
+ seg_soft: softmax results, shape=(b,c,x,y)
+ gt: ground truth, shape=(b,x,y)
+ seg_dtm: segmentation distance transform map, shape=(b,c,x,y)
+ gt_dtm: ground truth distance transform map, shape=(b,c,x,y)
+
+ Returns:
+ output: hd_loss
+ """
+ assert seg_soft.shape[0] == gt.shape[0]
+ total_loss = 0
+ num_class = seg_soft.shape[1]
+ if class_weight is not None:
+ assert len(class_weight) == num_class
+ for i in range(1, num_class):
+ if i != ignore_index:
+ delta_s = (seg_soft[:, i, ...] - gt.float())**2
+ s_dtm = seg_dtm[:, i, ...]**2
+ g_dtm = gt_dtm[:, i, ...]**2
+ dtm = s_dtm + g_dtm
+ multiplied = torch.einsum('bxy, bxy->bxy', delta_s, dtm)
+ loss_hd = multiplied.mean()
+ if class_weight is not None:
+ loss_hd *= class_weight[i]
+ total_loss += loss_hd
+
+ return total_loss / num_class
+
+
+@MODELS.register_module()
+class HuasdorffDisstanceLoss(nn.Module):
+ """HuasdorffDisstanceLoss. This loss is proposed in `How Distance Transform
+ Maps Boost Segmentation CNNs: An Empirical Study.
+
+ `_.
+ Args:
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'.
+ class_weight (list[float] | str, optional): Weight of each class. If in
+ str format, read them from a file. Defaults to None.
+ loss_weight (float): Weight of the loss. Defaults to 1.0.
+ ignore_index (int | None): The label index to be ignored. Default: 255.
+ loss_name (str): Name of the loss item. If you want this loss
+ item to be included into the backward graph, `loss_` must be the
+ prefix of the name. Defaults to 'loss_huasdorff_disstance'.
+ """
+
+ def __init__(self,
+ reduction='mean',
+ class_weight=None,
+ loss_weight=1.0,
+ ignore_index=255,
+ loss_name='loss_huasdorff_disstance',
+ **kwargs):
+ super().__init__()
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.class_weight = get_class_weight(class_weight)
+ self._loss_name = loss_name
+ self.ignore_index = ignore_index
+
+ def forward(self,
+ pred: Tensor,
+ target: Tensor,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs) -> Tensor:
+ """Forward function.
+
+ Args:
+ pred (Tensor): Predictions of the segmentation head. (B, C, H, W)
+ target (Tensor): Ground truth of the image. (B, H, W)
+ avg_factor (int, optional): Average factor that is used to
+ average the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used
+ to override the original reduction method of the loss.
+ Options are "none", "mean" and "sum".
+ Returns:
+ Tensor: Loss tensor.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.class_weight is not None:
+ class_weight = pred.new_tensor(self.class_weight)
+ else:
+ class_weight = None
+
+ pred_soft = F.softmax(pred, dim=1)
+ valid_mask = (target != self.ignore_index).long()
+ target = target * valid_mask
+
+ with torch.no_grad():
+ gt_dtm = compute_dtm(target.cpu(), pred_soft)
+ gt_dtm = gt_dtm.float()
+ seg_dtm2 = compute_dtm(
+ pred_soft.argmax(dim=1, keepdim=False).cpu(), pred_soft)
+ seg_dtm2 = seg_dtm2.float()
+
+ loss_hd = self.loss_weight * hd_loss(
+ pred_soft,
+ target,
+ seg_dtm=seg_dtm2,
+ gt_dtm=gt_dtm,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ class_weight=class_weight,
+ ignore_index=self.ignore_index)
+ return loss_hd
+
+ @property
+ def loss_name(self):
+ return self._loss_name
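The `compute_dtm` helper above leans on `scipy.ndimage.distance_transform_edt`. As a dependency-free illustration (a brute-force sketch for intuition, not the implementation mmseg uses), the same quantity for a small binary mask:

```python
import math


def distance_transform(mask):
    """Brute-force Euclidean distance transform of a binary mask.

    For each foreground pixel (value 1), returns the distance to the
    nearest background pixel (value 0); background pixels get 0.
    Matches the convention of scipy.ndimage.distance_transform_edt.
    """
    h, w = len(mask), len(mask[0])
    bg = [(i, j) for i in range(h) for j in range(w) if mask[i][j] == 0]
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            if mask[i][j]:
                out[i][j] = min(
                    math.hypot(i - bi, j - bj) for (bi, bj) in bg)
    return out


mask = [
    [0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
]
dtm = distance_transform(mask)
print(dtm[2][2])  # 2.0 -- the centre pixel is deepest inside the foreground
```

In `hd_loss`, squaring these maps for both the prediction and the ground truth weights the per-pixel error by how far each pixel sits from the respective boundaries.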
diff --git a/mmseg/models/losses/kldiv_loss.py b/mmseg/models/losses/kldiv_loss.py
new file mode 100644
index 00000000000..496ef9713f0
--- /dev/null
+++ b/mmseg/models/losses/kldiv_loss.py
@@ -0,0 +1,99 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmseg.registry import MODELS
+
+
+@MODELS.register_module()
+class KLDivLoss(nn.Module):
+
+ def __init__(self,
+ temperature: float = 1.0,
+ reduction: str = 'mean',
+ loss_name: str = 'loss_kld'):
+ """Kullback-Leibler divergence Loss.
+
+ Args:
+ temperature (float, optional): Temperature parameter used to scale
+ the logits. Defaults to 1.0.
+ reduction (str, optional): The method to reduce the loss into a
+ scalar. Options are "none", "sum" and "mean". Defaults to
+ "mean".
+ loss_name (str, optional): Name of the loss item. Defaults to
+ 'loss_kld'.
+ """
+
+ assert isinstance(temperature, (float, int)), \
+ 'Expected temperature to be ' \
+ f'float or int, but got {temperature.__class__.__name__} instead'
+ assert temperature != 0., 'Temperature must not be zero'
+
+ assert reduction in ['mean', 'none', 'sum'], \
+ 'Reduction must be one of the options ("mean", ' \
+ f'"sum", "none"), but got {reduction}'
+
+ super().__init__()
+ self.temperature = temperature
+ self.reduction = reduction
+ self._loss_name = loss_name
+
+ def forward(self, input: torch.Tensor, target: torch.Tensor):
+ """Forward function. Calculate KL divergence Loss.
+
+ Args:
+ input (Tensor): Logit tensor,
+ the data type is float32 or float64.
+ The shape is (N, C) where N is batchsize and C is number of
+ channels.
+ If there are more than 2 dimensions, the shape is (N, C, D1, D2,
+ ..., Dk), k >= 1
+ target (Tensor): Logit tensor,
+ the data type is float32 or float64.
+ input and target must be with the same shape.
+
+ Returns:
+ (Tensor): Reduced loss.
+ """
+ assert isinstance(input, torch.Tensor), 'Expected input to ' \
+ f'be Tensor, but got {input.__class__.__name__} instead'
+ assert isinstance(target, torch.Tensor), 'Expected target to ' \
+ f'be Tensor, but got {target.__class__.__name__} instead'
+
+ assert input.shape == target.shape, 'Input and target ' \
+ 'must have the same shape, ' \
+ f'but got shapes {input.shape} and {target.shape}'
+
+ # F.kl_div expects log-probabilities for its first argument
+ input = F.log_softmax(input / self.temperature, dim=1)
+ target = F.softmax(target / self.temperature, dim=1)
+
+ loss = F.kl_div(input, target, reduction='none', log_target=False)
+ loss = loss * self.temperature**2
+
+ batch_size = input.shape[0]
+
+ if self.reduction == 'sum':
+ # Change view to calculate instance-wise sum
+ loss = loss.view(batch_size, -1)
+ return torch.sum(loss, dim=1)
+
+ elif self.reduction == 'mean':
+ # Change view to calculate instance-wise mean
+ loss = loss.view(batch_size, -1)
+ return torch.mean(loss, dim=1)
+
+ return loss
+
+ @property
+ def loss_name(self):
+ """Loss Name.
+
+ This function must be implemented and will return the name of this
+ loss function. This name will be used to combine different loss items
+ by simple sum operation. In addition, if you want this loss item to be
+ included into the backward graph, `loss_` must be the prefix of the
+ name.
+ Returns:
+ str: The name of this loss item.
+ """
+ return self._loss_name
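For intuition, the per-sample quantity `KLDivLoss` computes can be written out in plain Python (a toy sketch over one logit vector; the class above does the same with `log_softmax`/`kl_div` over the channel dimension and then reduces per instance):

```python
import math


def softmax(logits, t=1.0):
    """Temperature-scaled softmax over a flat list of logits."""
    exps = [math.exp(v / t) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kld(input_logits, target_logits, t=1.0):
    """KL(target || input) with temperature t, scaled by t**2.

    Mirrors F.kl_div(log_softmax(input / t), softmax(target / t)),
    summed over classes, as used in distillation-style losses.
    """
    p = softmax(target_logits, t)  # target distribution
    q = softmax(input_logits, t)   # predicted distribution
    return t * t * sum(pi * (math.log(pi) - math.log(qi))
                       for pi, qi in zip(p, q))


print(kld([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # 0.0 for identical logits
print(kld([0.0, 0.0], [2.0, -2.0]) > 0)       # True: divergence is positive
```

The `t**2` factor compensates for the `1/t**2` scaling that the softened gradients pick up, which is why it appears in the loss as well.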
diff --git a/mmseg/models/losses/silog_loss.py b/mmseg/models/losses/silog_loss.py
new file mode 100644
index 00000000000..ecc07aac424
--- /dev/null
+++ b/mmseg/models/losses/silog_loss.py
@@ -0,0 +1,122 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from mmseg.registry import MODELS
+from .utils import weight_reduce_loss
+
+
+def silog_loss(pred: Tensor,
+ target: Tensor,
+ weight: Optional[Tensor] = None,
+ eps: float = 1e-4,
+ reduction: Union[str, None] = 'mean',
+ avg_factor: Optional[int] = None) -> Tensor:
+ """Computes the Scale-Invariant Logarithmic (SI-Log) loss between
+ prediction and target.
+
+ Args:
+ pred (Tensor): Predicted output.
+ target (Tensor): Ground truth.
+ weight (Optional[Tensor]): Optional weight to apply on the loss.
+ eps (float): Epsilon value to avoid division and log(0).
+ reduction (Union[str, None]): Specifies the reduction to apply to the
+ output: 'mean', 'sum' or None.
+ avg_factor (Optional[int]): Optional average factor for the loss.
+
+ Returns:
+ Tensor: The calculated SI-Log loss.
+ """
+ pred, target = pred.flatten(1), target.flatten(1)
+
+ diff_log = torch.log(target.clamp(min=eps)) - torch.log(
+ pred.clamp(min=eps))
+
+ valid_mask = (target > eps).detach() & (~torch.isnan(diff_log))
+ diff_log[~valid_mask] = 0.0
+ valid_mask = valid_mask.float()
+
+ diff_log_sq_mean = (diff_log.pow(2) * valid_mask).sum(
+ dim=1) / valid_mask.sum(dim=1).clamp(min=eps)
+ diff_log_mean = (diff_log * valid_mask).sum(dim=1) / valid_mask.sum(
+ dim=1).clamp(min=eps)
+
+ loss = torch.sqrt(diff_log_sq_mean - 0.5 * diff_log_mean.pow(2))
+
+ if weight is not None:
+ weight = weight.float()
+
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+
+@MODELS.register_module()
+class SiLogLoss(nn.Module):
+ """Compute SiLog loss.
+
+ Args:
+ reduction (str, optional): The method used
+ to reduce the loss. Options are "none",
+ "mean" and "sum". Defaults to 'mean'.
+ loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+ eps (float): Avoid dividing by zero. Defaults to 1e-6.
+ loss_name (str, optional): Name of the loss item. If you want this
+ loss item to be included into the backward graph, `loss_` must
+ be the prefix of the name. Defaults to 'loss_silog'.
+ """
+
+ def __init__(self,
+ reduction='mean',
+ loss_weight=1.0,
+ eps=1e-6,
+ loss_name='loss_silog'):
+ super().__init__()
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.eps = eps
+ self._loss_name = loss_name
+
+ def forward(
+ self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ ):
+
+ assert pred.shape == target.shape, 'the shapes of pred ' \
+ f'({pred.shape}) and target ({target.shape}) do not match'
+
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+
+ loss = self.loss_weight * silog_loss(
+ pred,
+ target,
+ weight,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ )
+
+ return loss
+
+ @property
+ def loss_name(self):
+ """Loss Name.
+
+ This function must be implemented and will return the name of this
+ loss function. This name will be used to combine different loss items
+ by simple sum operation. In addition, if you want this loss item to be
+ included into the backward graph, `loss_` must be the prefix of the
+ name.
+ Returns:
+ str: The name of this loss item.
+ """
+ return self._loss_name
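For one flattened depth map, the arithmetic inside `silog_loss` reduces to the following (a plain-Python sketch under the same 0.5 variance weighting; the real function additionally masks NaNs and applies `weight_reduce_loss`):

```python
import math


def silog(pred, target, eps=1e-4):
    """Scale-invariant log loss over paired depth lists.

    d_i = log(t_i) - log(p_i) over valid pixels (t_i > eps);
    loss = sqrt(mean(d^2) - 0.5 * mean(d)^2).
    """
    d = [math.log(max(t, eps)) - math.log(max(p, eps))
         for p, t in zip(pred, target) if t > eps]
    mean_sq = sum(x * x for x in d) / len(d)
    mean = sum(d) / len(d)
    return math.sqrt(mean_sq - 0.5 * mean * mean)


print(silog([1.0, 2.0, 4.0], [1.0, 2.0, 4.0]))  # 0.0 for a perfect prediction
# a constant scale error s leaves only half the variance term:
# loss = sqrt(0.5) * |log s|
print(round(silog([2.0, 4.0, 8.0], [1.0, 2.0, 4.0]), 4))  # 0.4901
```

With a weighting of 1.0 on `mean(d)^2` the loss would be fully scale-invariant; the 0.5 used here keeps a partial penalty on global scale errors.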
diff --git a/mmseg/models/losses/utils.py b/mmseg/models/losses/utils.py
index f74efcf35ce..04780347331 100644
--- a/mmseg/models/losses/utils.py
+++ b/mmseg/models/losses/utils.py
@@ -25,7 +25,7 @@ def get_class_weight(class_weight):
return class_weight
-def reduce_loss(loss, reduction):
+def reduce_loss(loss, reduction) -> torch.Tensor:
"""Reduce loss as specified.
Args:
@@ -45,7 +45,10 @@ def reduce_loss(loss, reduction):
return loss.sum()
-def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
+def weight_reduce_loss(loss,
+ weight=None,
+ reduction='mean',
+ avg_factor=None) -> torch.Tensor:
"""Apply element-wise weight and reduce loss.
Args:
diff --git a/mmseg/models/segmentors/__init__.py b/mmseg/models/segmentors/__init__.py
index fec0d52c3a4..59b012f4172 100644
--- a/mmseg/models/segmentors/__init__.py
+++ b/mmseg/models/segmentors/__init__.py
@@ -1,9 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseSegmentor
from .cascade_encoder_decoder import CascadeEncoderDecoder
+from .depth_estimator import DepthEstimator
from .encoder_decoder import EncoderDecoder
+from .multimodal_encoder_decoder import MultimodalEncoderDecoder
from .seg_tta import SegTTAModel
__all__ = [
- 'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel'
+ 'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel',
+ 'MultimodalEncoderDecoder', 'DepthEstimator'
]
diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py
index 25487de5ab8..17a0bb2b33e 100644
--- a/mmseg/models/segmentors/base.py
+++ b/mmseg/models/segmentors/base.py
@@ -187,6 +187,7 @@ def postprocess_result(self,
if C > 1:
i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True)
else:
+ i_seg_logits = i_seg_logits.sigmoid()
i_seg_pred = (i_seg_logits >
self.decode_head.threshold).to(i_seg_logits)
data_samples[i].set_data({
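The `base.py` hunk above applies `sigmoid()` before thresholding when the head has a single output channel, so the threshold is compared against probabilities rather than raw logits. A minimal sketch of that post-processing step (the `0.3` default here is illustrative; the real value comes from `decode_head.threshold`):

```python
import math


def binary_pred(logits, threshold=0.3):
    """Single-channel (C == 1) post-processing: sigmoid, then threshold.

    Raw logits are squashed to probabilities in (0, 1) so the threshold
    is interpreted on the probability scale.
    """
    probs = [1.0 / (1.0 + math.exp(-v)) for v in logits]
    return [1 if p > threshold else 0 for p in probs]


print(binary_pred([-2.0, 0.0, 2.0]))  # [0, 1, 1]
```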
diff --git a/mmseg/models/segmentors/depth_estimator.py b/mmseg/models/segmentors/depth_estimator.py
new file mode 100644
index 00000000000..1020637e737
--- /dev/null
+++ b/mmseg/models/segmentors/depth_estimator.py
@@ -0,0 +1,392 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.logging import print_log
+from mmengine.structures import PixelData
+from torch import Tensor
+
+from mmseg.registry import MODELS
+from mmseg.structures import SegDataSample
+from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
+ OptSampleList, SampleList, add_prefix)
+from ..utils import resize
+from .encoder_decoder import EncoderDecoder
+
+
+@MODELS.register_module()
+class DepthEstimator(EncoderDecoder):
+ """Encoder Decoder depth estimator.
+
+ EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
+ Note that auxiliary_head is only used for deep supervision during training,
+ which could be dumped during inference.
+
+ 1. The ``loss`` method is used to calculate the loss of model,
+ which includes two steps: (1) Extracts features to obtain the feature maps
+ (2) Call the decode head loss function to forward decode head model and
+ calculate losses.
+
+ .. code:: text
+
+ loss(): extract_feat() -> _decode_head_forward_train() -> _auxiliary_head_forward_train (optional)
+ _decode_head_forward_train(): decode_head.loss()
+ _auxiliary_head_forward_train(): auxiliary_head.loss (optional)
+
+ 2. The ``predict`` method is used to predict depth estimation results,
+ which includes two steps: (1) Run inference function to obtain the list of
+ depth (2) Call post-processing function to obtain list of
+ ``SegDataSample`` including ``pred_depth_map``.
+
+ .. code:: text
+
+ predict(): inference() -> postprocess_result()
+ inference(): whole_inference()/slide_inference()
+ whole_inference()/slide_inference(): encode_decode()
+ encode_decode(): extract_feat() -> decode_head.predict()
+
+ 3. The ``_forward`` method is used to output the tensor by running the model,
+ which includes two steps: (1) Extracts features to obtain the feature maps
+ (2) Call the decode head forward function to forward decode head model.
+
+ .. code:: text
+
+ _forward(): extract_feat() -> _decode_head.forward()
+
+ Args:
+
+ backbone (ConfigType): The config for the backbone of depth estimator.
+ decode_head (ConfigType): The config for the decode head of depth estimator.
+ neck (OptConfigType): The config for the neck of depth estimator.
+ Defaults to None.
+ auxiliary_head (OptConfigType): The config for the auxiliary head of
+ depth estimator. Defaults to None.
+ train_cfg (OptConfigType): The config for training. Defaults to None.
+ test_cfg (OptConfigType): The config for testing. Defaults to None.
+ data_preprocessor (dict, optional): The pre-process config of
+ :class:`BaseDataPreprocessor`.
+ pretrained (str, optional): The path for pretrained model.
+ Defaults to None.
+ init_cfg (dict, optional): The weight initialized config for
+ :class:`BaseModule`.
+ """ # noqa: E501
+
+ def __init__(self,
+ backbone: ConfigType,
+ decode_head: ConfigType,
+ neck: OptConfigType = None,
+ auxiliary_head: OptConfigType = None,
+ train_cfg: OptConfigType = None,
+ test_cfg: OptConfigType = None,
+ data_preprocessor: OptConfigType = None,
+ pretrained: Optional[str] = None,
+ init_cfg: OptMultiConfig = None):
+ super().__init__(
+ backbone=backbone,
+ decode_head=decode_head,
+ neck=neck,
+ auxiliary_head=auxiliary_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ data_preprocessor=data_preprocessor,
+ pretrained=pretrained,
+ init_cfg=init_cfg)
+
+ def extract_feat(self,
+ inputs: Tensor,
+ batch_img_metas: Optional[List[dict]] = None) -> Tensor:
+ """Extract features from images."""
+
+ if getattr(self.backbone, 'class_embed_select', False) and \
+ isinstance(batch_img_metas, list) and \
+ 'category_id' in batch_img_metas[0]:
+ cat_ids = [meta['category_id'] for meta in batch_img_metas]
+ cat_ids = torch.tensor(cat_ids).to(inputs.device)
+ inputs = (inputs, cat_ids)
+
+ x = self.backbone(inputs)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def encode_decode(self, inputs: Tensor,
+ batch_img_metas: List[dict]) -> Tensor:
+ """Encode images with backbone and decode into a depth map of the same
+ size as input."""
+ x = self.extract_feat(inputs, batch_img_metas)
+ depth = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
+
+ return depth
+
+ def _decode_head_forward_train(self, inputs: List[Tensor],
+ data_samples: SampleList) -> dict:
+ """Run forward function and calculate loss for decode head in
+ training."""
+ losses = dict()
+ loss_decode = self.decode_head.loss(inputs, data_samples,
+ self.train_cfg)
+
+ losses.update(add_prefix(loss_decode, 'decode'))
+ return losses
+
+ def _auxiliary_head_forward_train(self, inputs: List[Tensor],
+ data_samples: SampleList) -> dict:
+ """Run forward function and calculate loss for auxiliary head in
+ training."""
+ losses = dict()
+ if isinstance(self.auxiliary_head, nn.ModuleList):
+ for idx, aux_head in enumerate(self.auxiliary_head):
+ loss_aux = aux_head.loss(inputs, data_samples, self.train_cfg)
+ losses.update(add_prefix(loss_aux, f'aux_{idx}'))
+ else:
+ loss_aux = self.auxiliary_head.loss(inputs, data_samples,
+ self.train_cfg)
+ losses.update(add_prefix(loss_aux, 'aux'))
+
+ return losses
+
+ def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
+ """Calculate losses from a batch of inputs and data samples.
+
+ Args:
+ inputs (Tensor): Input images.
+ data_samples (list[:obj:`SegDataSample`]): The seg data samples.
+ It usually includes information such as `metainfo` and
+ `gt_depth_map`.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ if data_samples is not None:
+ batch_img_metas = [
+ data_sample.metainfo for data_sample in data_samples
+ ]
+ else:
+ batch_img_metas = [
+ dict(
+ ori_shape=inputs.shape[2:],
+ img_shape=inputs.shape[2:],
+ pad_shape=inputs.shape[2:],
+ padding_size=[0, 0, 0, 0])
+ ] * inputs.shape[0]
+
+ x = self.extract_feat(inputs, batch_img_metas)
+
+ losses = dict()
+
+ loss_decode = self._decode_head_forward_train(x, data_samples)
+ losses.update(loss_decode)
+
+ if self.with_auxiliary_head:
+ loss_aux = self._auxiliary_head_forward_train(x, data_samples)
+ losses.update(loss_aux)
+
+ return losses
+
+ def predict(self,
+ inputs: Tensor,
+ data_samples: OptSampleList = None) -> SampleList:
+ """Predict results from a batch of inputs and data samples with post-
+ processing.
+
+ Args:
+ inputs (Tensor): Inputs with shape (N, C, H, W).
+ data_samples (List[:obj:`SegDataSample`], optional): The seg data
+ samples. It usually includes information such as `metainfo`
+ and `gt_depth_map`.
+
+ Returns:
+ list[:obj:`SegDataSample`]: Depth estimation results of the
+ input images. Each SegDataSample usually contain:
+
+ - ``pred_depth_map``(PixelData): Prediction of depth estimation.
+ """
+ if data_samples is not None:
+ batch_img_metas = [
+ data_sample.metainfo for data_sample in data_samples
+ ]
+ else:
+ batch_img_metas = [
+ dict(
+ ori_shape=inputs.shape[2:],
+ img_shape=inputs.shape[2:],
+ pad_shape=inputs.shape[2:],
+ padding_size=[0, 0, 0, 0])
+ ] * inputs.shape[0]
+
+ depth = self.inference(inputs, batch_img_metas)
+
+ return self.postprocess_result(depth, data_samples)
+
+ def _forward(self,
+ inputs: Tensor,
+ data_samples: OptSampleList = None) -> Tensor:
+ """Network forward process.
+
+ Args:
+ inputs (Tensor): Inputs with shape (N, C, H, W).
+ data_samples (List[:obj:`SegDataSample`]): The seg
+ data samples. It usually includes information such
+ as `metainfo` and `gt_depth_map`.
+
+ Returns:
+ Tensor: Forward output of model without any post-processes.
+ """
+ x = self.extract_feat(inputs)
+ return self.decode_head.forward(x)
+
+ def slide_flip_inference(self, inputs: Tensor,
+ batch_img_metas: List[dict]) -> Tensor:
+ """Inference by sliding-window with overlap and flip.
+
+ If h_crop > h_img or w_crop > w_img, the small patch will be used to
+ decode without padding.
+
+ Args:
+ inputs (tensor): the tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ batch_img_metas (List[dict]): List of image metainfo where each may
+ also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
+ 'ori_shape', and 'pad_shape'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
+
+ Returns:
+ Tensor: The depth estimation results.
+ """
+
+ h_stride, w_stride = self.test_cfg.stride
+ h_crop, w_crop = self.test_cfg.crop_size
+ batch_size, _, h_img, w_img = inputs.size()
+ out_channels = self.out_channels
+ h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+ w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+ preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
+ count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
+ for h_idx in range(h_grids):
+ for w_idx in range(w_grids):
+ y1 = h_idx * h_stride
+ x1 = w_idx * w_stride
+ y2 = min(y1 + h_crop, h_img)
+ x2 = min(x1 + w_crop, w_img)
+ y1 = max(y2 - h_crop, 0)
+ x1 = max(x2 - w_crop, 0)
+ crop_img = inputs[:, :, y1:y2, x1:x2]
+ # change the image shape to patch shape
+ batch_img_metas[0]['img_shape'] = crop_img.shape[2:]
+ # the output of encode_decode is depth tensor map
+ # with shape [N, C, H, W]
+ crop_depth_map = self.encode_decode(crop_img, batch_img_metas)
+
+ # average out the original and flipped prediction
+ crop_depth_map_flip = self.encode_decode(
+ crop_img.flip(dims=(3, )), batch_img_metas)
+ crop_depth_map_flip = crop_depth_map_flip.flip(dims=(3, ))
+ crop_depth_map = (crop_depth_map + crop_depth_map_flip) / 2.0
+
+ preds += F.pad(crop_depth_map,
+ (int(x1), int(preds.shape[3] - x2), int(y1),
+ int(preds.shape[2] - y2)))
+
+ count_mat[:, :, y1:y2, x1:x2] += 1
+ assert (count_mat == 0).sum() == 0
+ depth = preds / count_mat
+
+ return depth
+
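The window placement used by `slide_flip_inference` (and `slide_inference`) can be isolated into a small helper to see how crops are laid on a stride grid and clamped at the image borders (a sketch of the same index arithmetic):

```python
def slide_windows(h_img, w_img, crop, stride):
    """Enumerate the crop windows used by slide inference.

    Windows are placed every stride pixels and clamped so each crop
    lies fully inside the image, matching the y1/y2/x1/x2 arithmetic
    in the sliding-window loop above.
    """
    h_crop, w_crop = crop
    h_stride, w_stride = stride
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    wins = []
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y2 = min(h_idx * h_stride + h_crop, h_img)
            x2 = min(w_idx * w_stride + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            wins.append((y1, x1, y2, x2))
    return wins


# 6x6 image, 4x4 crops, stride 2: four overlapping windows cover every pixel
print(slide_windows(6, 6, (4, 4), (2, 2)))
```

Because windows overlap, `count_mat` accumulates how many crops touched each pixel so the summed predictions can be averaged back to one map.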
+ def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
+ """Inference with slide/whole style.
+
+ Args:
+ inputs (Tensor): The input image of shape (N, 3, H, W).
+ batch_img_metas (List[dict]): List of image metainfo where each may
+ also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
+ 'ori_shape', 'pad_shape', and 'padding_size'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
+
+ Returns:
+ Tensor: The depth estimation results.
+ """
+ assert self.test_cfg.get('mode', 'whole') in ['slide', 'whole',
+ 'slide_flip'], \
+ f'Only "slide", "slide_flip" or "whole" test mode are ' \
+ f'supported, but got {self.test_cfg["mode"]}.'
+ ori_shape = batch_img_metas[0]['ori_shape']
+ if not all(_['ori_shape'] == ori_shape for _ in batch_img_metas):
+ print_log(
+ 'Image shapes are different in the batch.',
+ logger='current',
+ level=logging.WARN)
+ if self.test_cfg.mode == 'slide':
+ depth_map = self.slide_inference(inputs, batch_img_metas)
+ elif self.test_cfg.mode == 'slide_flip':
+ depth_map = self.slide_flip_inference(inputs, batch_img_metas)
+ else:
+ depth_map = self.whole_inference(inputs, batch_img_metas)
+
+ return depth_map
+
+ def postprocess_result(self,
+ depth: Tensor,
+ data_samples: OptSampleList = None) -> SampleList:
+ """ Convert results list to `SegDataSample`.
+ Args:
+ depth (Tensor): The depth estimation results.
+ data_samples (list[:obj:`SegDataSample`]): The seg data samples.
+ It usually includes information such as `metainfo` and
+ `gt_depth_map`. Default to None.
+ Returns:
+ list[:obj:`SegDataSample`]: Depth estimation results of the
+ input images. Each SegDataSample usually contain:
+
+ - ``pred_depth_map``(PixelData): Prediction of depth estimation.
+ """
+ batch_size, C, H, W = depth.shape
+
+ if data_samples is None:
+ data_samples = [SegDataSample() for _ in range(batch_size)]
+ only_prediction = True
+ else:
+ only_prediction = False
+
+ for i in range(batch_size):
+ if not only_prediction:
+ img_meta = data_samples[i].metainfo
+ # remove padding area
+ if 'img_padding_size' not in img_meta:
+ padding_size = img_meta.get('padding_size', [0] * 4)
+ else:
+ padding_size = img_meta['img_padding_size']
+ padding_left, padding_right, padding_top, padding_bottom =\
+ padding_size
+ # i_depth shape is 1, C, H, W after remove padding
+ i_depth = depth[i:i + 1, :, padding_top:H - padding_bottom,
+ padding_left:W - padding_right]
+
+ flip = img_meta.get('flip', None)
+ if flip:
+ flip_direction = img_meta.get('flip_direction', None)
+ assert flip_direction in ['horizontal', 'vertical']
+ if flip_direction == 'horizontal':
+ i_depth = i_depth.flip(dims=(3, ))
+ else:
+ i_depth = i_depth.flip(dims=(2, ))
+
+ # resize as original shape
+ i_depth = resize(
+ i_depth,
+ size=img_meta['ori_shape'],
+ mode='bilinear',
+ align_corners=self.align_corners,
+ warning=False).squeeze(0)
+ else:
+ i_depth = depth[i]
+
+ data_samples[i].set_data(
+ {'pred_depth_map': PixelData(**{'data': i_depth})})
+
+ return data_samples
diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py
index 0a8db3ec7de..fa4050e0b73 100644
--- a/mmseg/models/segmentors/encoder_decoder.py
+++ b/mmseg/models/segmentors/encoder_decoder.py
@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import logging
from typing import List, Optional
import torch.nn as nn
import torch.nn.functional as F
+from mmengine.logging import print_log
from torch import Tensor
from mmseg.registry import MODELS
@@ -33,7 +35,7 @@ class EncoderDecoder(BaseSegmentor):
2. The ``predict`` method is used to predict segmentation results,
which includes two steps: (1) Run inference function to obtain the list of
seg_logits (2) Call post-processing function to obtain list of
- ``SegDataSampel`` including ``pred_sem_seg`` and ``seg_logits``.
+ ``SegDataSample`` including ``pred_sem_seg`` and ``seg_logits``.
.. code:: text
@@ -260,10 +262,10 @@ def slide_inference(self, inputs: Tensor,
h_stride, w_stride = self.test_cfg.stride
h_crop, w_crop = self.test_cfg.crop_size
batch_size, _, h_img, w_img = inputs.size()
- num_classes = self.num_classes
+ out_channels = self.out_channels
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
- preds = inputs.new_zeros((batch_size, num_classes, h_img, w_img))
+ preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
for h_idx in range(h_grids):
for w_idx in range(w_grids):
@@ -326,10 +328,15 @@ def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
Tensor: The segmentation results, seg_logits from model of each
input image.
"""
-
- assert self.test_cfg.mode in ['slide', 'whole']
+ assert self.test_cfg.get('mode', 'whole') in ['slide', 'whole'], \
+ f'Only "slide" or "whole" test mode are supported, but got ' \
+ f'{self.test_cfg["mode"]}.'
ori_shape = batch_img_metas[0]['ori_shape']
- assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas)
+ if not all(_['ori_shape'] == ori_shape for _ in batch_img_metas):
+ print_log(
+ 'Image shapes are different in the batch.',
+ logger='current',
+ level=logging.WARN)
if self.test_cfg.mode == 'slide':
seg_logit = self.slide_inference(inputs, batch_img_metas)
else:
diff --git a/mmseg/models/segmentors/multimodal_encoder_decoder.py b/mmseg/models/segmentors/multimodal_encoder_decoder.py
new file mode 100644
index 00000000000..75aa8b9b176
--- /dev/null
+++ b/mmseg/models/segmentors/multimodal_encoder_decoder.py
@@ -0,0 +1,350 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional
+
+import torch.nn.functional as F
+from torch import Tensor
+
+from mmseg.registry import MODELS
+from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig,
+ OptSampleList, SampleList, add_prefix)
+from .base import BaseSegmentor
+
+
+@MODELS.register_module()
+class MultimodalEncoderDecoder(BaseSegmentor):
+ """Multimodal Encoder-Decoder segmentors.
+
+ Multimodal segmentation architecture is used for open-vocabulary
+ semantic segmentation with combining the visual and language
+ pretrain models. It consists of a image_encoder (backbone) to extract
+ visual feature, a text encoder to extract text feature, and a decode
+ head to generate semantic maps.
+ Note that the deep supervision during training is implemented in decode head.
+
+ 1. The ``loss`` method is used to calculate the loss of model,
+ which includes two steps: (1) Extracts features to obtain the feature maps
+ (2) Call the decode head loss function to forward decode head model and
+ calculate losses.
+
+ .. code:: text
+
+ loss(): extract_feat() -> _decode_head_forward_train()
+ _decode_head_forward_train(): decode_head.loss()
+
+ 2. The ``predict`` method is used to predict segmentation results,
+ which includes two steps: (1) Run inference function to obtain the list of
+ seg_logits (2) Call post-processing function to obtain list of
+ ``SegDataSample`` including ``pred_sem_seg`` and ``seg_logits``.
+
+ .. code:: text
+
+ predict(): inference() -> postprocess_result()
+ inference(): whole_inference()/slide_inference()
+ whole_inference()/slide_inference(): encode_decode()
+ encode_decode(): extract_feat() -> decode_head.predict()
+
+ 3. The ``_forward`` method is used to output the tensor by running the model,
+ which includes two steps: (1) Extracts features to obtain the feature maps
+ (2) Call the decode head forward function to forward decode head model.
+
+ .. code:: text
+
+ _forward(): extract_feat() -> _decode_head.forward()
+
+ Args:
+
+ image_encoder (ConfigType): The config for the visual encoder of segmentor.
+ text_encoder (ConfigType): The config for the text encoder of segmentor.
+ decode_head (ConfigType): The config for the decode head of segmentor.
+ train_cfg (OptConfigType): The config for training. Defaults to None.
+ test_cfg (OptConfigType): The config for testing. Defaults to None.
+ data_preprocessor (dict, optional): The pre-process config of
+ :class:`BaseDataPreprocessor`.
+ pretrained (str, optional): The path for pretrained model.
+ Defaults to None.
+ asymetric_input (bool): whether to use inputs of different sizes for the
+ image encoder and the decode head. Defaults to True.
+ encoder_resolution (float): resize scale of input images for image encoder.
+ Defaults to None.
+ init_cfg (dict, optional): The weight initialized config for
+ :class:`BaseModule`.
+ """ # noqa: E501
+
+ def __init__(self,
+ image_encoder: ConfigType,
+ text_encoder: ConfigType,
+ decode_head: ConfigType,
+ train_cfg: OptConfigType = None,
+ test_cfg: OptConfigType = None,
+ data_preprocessor: OptConfigType = None,
+ pretrained: Optional[str] = None,
+ asymetric_input: bool = True,
+ encoder_resolution: float = None,
+ init_cfg: OptMultiConfig = None):
+ super().__init__(
+ data_preprocessor=data_preprocessor, init_cfg=init_cfg)
+ if pretrained is not None:
+ image_encoder.init_cfg = dict(
+ type='Pretrained_Part', checkpoint=pretrained)
+ text_encoder.init_cfg = dict(
+ type='Pretrained_Part', checkpoint=pretrained)
+ decode_head.init_cfg = dict(
+ type='Pretrained_Part', checkpoint=pretrained)
+
+ if asymetric_input:
+ assert encoder_resolution is not None, \
+ 'if asymetric_input is set True, ' \
+ 'encoder_resolution must be a certain value'
+ self.asymetric_input = asymetric_input
+ self.encoder_resolution = encoder_resolution
+ self.image_encoder = MODELS.build(image_encoder)
+ self.text_encoder = MODELS.build(text_encoder)
+ self._init_decode_head(decode_head)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ assert self.with_decode_head
+
+ def _init_decode_head(self, decode_head: ConfigType) -> None:
+ """Initialize ``decode_head``"""
+ self.decode_head = MODELS.build(decode_head)
+ self.align_corners = self.decode_head.align_corners
+ self.num_classes = self.decode_head.num_classes
+ self.out_channels = self.decode_head.out_channels
+
+ def extract_feat(self, inputs: Tensor) -> List[Tensor]:
+ """Extract visual features from images."""
+ x = self.image_encoder(inputs)
+ return x
+
+ def encode_decode(self, inputs: Tensor,
+ batch_img_metas: List[dict]) -> Tensor:
+ """Encode the name of classes with text_encoder and encode images with
+ image_encoder.
+
+ Then decode the class embedding and visual feature into a semantic
+ segmentation map of the same size as input.
+ """
+ classifier_embeds = self.text_encoder()
+ clip_inputs = inputs
+ if self.asymetric_input:
+ clip_inputs = F.interpolate(
+ inputs, scale_factor=self.encoder_resolution, mode='bilinear')
+ x = self.image_encoder(clip_inputs)
+ seg_logits = self.decode_head.predict([inputs, x, classifier_embeds],
+ batch_img_metas, self.test_cfg)
+
+ return seg_logits
+
+ def _decode_head_forward_train(self, inputs: List[Tensor],
+ data_samples: SampleList) -> dict:
+ """Run forward function and calculate loss for decode head in
+ training."""
+ losses = dict()
+ loss_decode = self.decode_head.loss(inputs, data_samples,
+ self.train_cfg)
+
+ losses.update(add_prefix(loss_decode, 'decode'))
+ return losses
+
+ def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
+ """Calculate losses from a batch of inputs and data samples.
+
+ Args:
+ inputs (Tensor): Input images.
+ data_samples (list[:obj:`SegDataSample`]): The seg data samples.
+ It usually includes information such as `metainfo` and
+ `gt_sem_seg`.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ classifier_embeds = self.text_encoder()
+ clip_inputs = inputs
+ if self.asymetric_input:
+ clip_inputs = F.interpolate(
+ inputs, scale_factor=self.encoder_resolution, mode='bilinear')
+ x = self.image_encoder(clip_inputs)
+
+ losses = dict()
+
+ loss_decode = self._decode_head_forward_train(
+ [inputs, x, classifier_embeds], data_samples)
+ losses.update(loss_decode)
+
+ return losses
+
+ def predict(self,
+ inputs: Tensor,
+ data_samples: OptSampleList = None) -> SampleList:
+ """Predict results from a batch of inputs and data samples with post-
+ processing.
+
+ Args:
+ inputs (Tensor): Inputs with shape (N, C, H, W).
+ data_samples (List[:obj:`SegDataSample`], optional): The seg data
+ samples. It usually includes information such as `metainfo`
+ and `gt_sem_seg`.
+
+ Returns:
+ list[:obj:`SegDataSample`]: Segmentation results of the
+                input images. Each SegDataSample usually contains:
+
+ - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
+ - ``seg_logits``(PixelData): Predicted logits of semantic
+ segmentation before normalization.
+ """
+ if data_samples is not None:
+ batch_img_metas = [
+ data_sample.metainfo for data_sample in data_samples
+ ]
+ else:
+ batch_img_metas = [
+ dict(
+ ori_shape=inputs.shape[2:],
+ img_shape=inputs.shape[2:],
+ pad_shape=inputs.shape[2:],
+ padding_size=[0, 0, 0, 0])
+ ] * inputs.shape[0]
+
+ seg_logits = self.inference(inputs, batch_img_metas)
+
+ return self.postprocess_result(seg_logits, data_samples)
+
+ def _forward(self,
+ inputs: Tensor,
+ data_samples: OptSampleList = None) -> Tensor:
+ """Network forward process.
+
+ Args:
+ inputs (Tensor): Inputs with shape (N, C, H, W).
+ data_samples (List[:obj:`SegDataSample`]): The seg
+ data samples. It usually includes information such
+ as `metainfo` and `gt_sem_seg`.
+
+ Returns:
+ Tensor: Forward output of model without any post-processes.
+ """
+ x = self.extract_feat(inputs)
+ return self.decode_head.forward(x)
+
+ def slide_inference(self, inputs: Tensor,
+ batch_img_metas: List[dict]) -> Tensor:
+ """Inference by sliding-window with overlap.
+
+ If h_crop > h_img or w_crop > w_img, the small patch will be used to
+ decode without padding.
+
+ Args:
+ inputs (tensor): the tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ batch_img_metas (List[dict]): List of image metainfo where each may
+ also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
+ 'ori_shape', and 'pad_shape'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
+
+ Returns:
+ Tensor: The segmentation results, seg_logits from model of each
+ input image.
+ """
+
+ h_stride, w_stride = self.test_cfg.stride
+ h_crop, w_crop = self.test_cfg.crop_size
+ batch_size, _, h_img, w_img = inputs.size()
+ out_channels = self.out_channels
+ h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+ w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+ preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
+ count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
+ for h_idx in range(h_grids):
+ for w_idx in range(w_grids):
+ y1 = h_idx * h_stride
+ x1 = w_idx * w_stride
+ y2 = min(y1 + h_crop, h_img)
+ x2 = min(x1 + w_crop, w_img)
+ y1 = max(y2 - h_crop, 0)
+ x1 = max(x2 - w_crop, 0)
+ crop_img = inputs[:, :, y1:y2, x1:x2]
+ # change the image shape to patch shape
+ batch_img_metas[0]['img_shape'] = crop_img.shape[2:]
+ # the output of encode_decode is seg logits tensor map
+ # with shape [N, C, H, W]
+ crop_seg_logit = self.encode_decode(crop_img, batch_img_metas)
+ preds += F.pad(crop_seg_logit,
+ (int(x1), int(preds.shape[3] - x2), int(y1),
+ int(preds.shape[2] - y2)))
+
+ count_mat[:, :, y1:y2, x1:x2] += 1
+ assert (count_mat == 0).sum() == 0
+ seg_logits = preds / count_mat
+
+ return seg_logits
+
+ def whole_inference(self, inputs: Tensor,
+ batch_img_metas: List[dict]) -> Tensor:
+ """Inference with full image.
+
+ Args:
+ inputs (Tensor): The tensor should have a shape NxCxHxW, which
+ contains all images in the batch.
+ batch_img_metas (List[dict]): List of image metainfo where each may
+ also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
+ 'ori_shape', and 'pad_shape'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
+
+ Returns:
+ Tensor: The segmentation results, seg_logits from model of each
+ input image.
+ """
+
+ seg_logits = self.encode_decode(inputs, batch_img_metas)
+
+ return seg_logits
+
+ def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
+ """Inference with slide/whole style.
+
+ Args:
+ inputs (Tensor): The input image of shape (N, 3, H, W).
+ batch_img_metas (List[dict]): List of image metainfo where each may
+ also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
+ 'ori_shape', 'pad_shape', and 'padding_size'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
+
+ Returns:
+ Tensor: The segmentation results, seg_logits from model of each
+ input image.
+ """
+
+ assert self.test_cfg.mode in ['slide', 'whole']
+ ori_shape = batch_img_metas[0]['ori_shape']
+ assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas)
+ if self.test_cfg.mode == 'slide':
+ seg_logit = self.slide_inference(inputs, batch_img_metas)
+ else:
+ seg_logit = self.whole_inference(inputs, batch_img_metas)
+
+ return seg_logit
+
+ def aug_test(self, inputs, batch_img_metas, rescale=True):
+ """Test with augmentations.
+
+ Only rescale=True is supported.
+ """
+ # aug_test rescale all imgs back to ori_shape for now
+ assert rescale
+ # to save memory, we get augmented seg logit inplace
+        seg_logit = self.inference(inputs[0], batch_img_metas[0])
+        for i in range(1, len(inputs)):
+            cur_seg_logit = self.inference(inputs[i], batch_img_metas[i])
+ seg_logit += cur_seg_logit
+ seg_logit /= len(inputs)
+ seg_pred = seg_logit.argmax(dim=1)
+ # unravel batch dim
+ seg_pred = list(seg_pred)
+ return seg_pred
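The window arithmetic in `slide_inference` above is easy to check in isolation. The sketch below reproduces just the grid and crop-coordinate computation in plain Python; `slide_windows` is a hypothetical helper for illustration, not part of mmseg:

```python
def slide_windows(h_img, w_img, h_crop, w_crop, h_stride, w_stride):
    """Enumerate (y1, y2, x1, x2) crop windows covering an image,
    mirroring the grid arithmetic used by slide_inference."""
    # ceil((img - crop) / stride) + 1 grid positions per axis
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    windows = []
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y2 = min(h_idx * h_stride + h_crop, h_img)
            x2 = min(w_idx * w_stride + w_crop, w_img)
            # shift the window back so every crop keeps its full size
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            windows.append((y1, y2, x1, x2))
    return windows

# a 512x512 image with 384x384 crops and stride 256 yields a 2x2 grid
windows = slide_windows(512, 512, 384, 384, 256, 256)
```

Because the last window in each axis is shifted back rather than padded, adjacent crops overlap, which is why `slide_inference` divides the accumulated logits by `count_mat`.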
diff --git a/mmseg/models/segmentors/seg_tta.py b/mmseg/models/segmentors/seg_tta.py
index eacb6c00a9a..63ef61d223a 100644
--- a/mmseg/models/segmentors/seg_tta.py
+++ b/mmseg/models/segmentors/seg_tta.py
@@ -6,7 +6,6 @@
from mmengine.structures import PixelData
from mmseg.registry import MODELS
-from mmseg.structures import SegDataSample
from mmseg.utils import SampleList
@@ -39,10 +38,10 @@ def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList:
).to(logits).squeeze(1)
else:
seg_pred = logits.argmax(dim=0)
- data_sample = SegDataSample(
- **{
- 'pred_sem_seg': PixelData(data=seg_pred),
- 'gt_sem_seg': data_samples[0].gt_sem_seg
- })
+ data_sample.set_data({'pred_sem_seg': PixelData(data=seg_pred)})
+ if hasattr(data_samples[0], 'gt_sem_seg'):
+ data_sample.set_data(
+ {'gt_sem_seg': data_samples[0].gt_sem_seg})
+ data_sample.set_metainfo({'img_path': data_samples[0].img_path})
predictions.append(data_sample)
return predictions
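The reduction performed per sample by `merge_preds` above (mean of per-pixel logits over augmented views, then argmax) can be sketched without tensors. Below is a minimal illustration with nested lists standing in for `[num_classes][num_pixels]` logit maps; `merge_tta_logits` is a hypothetical name:

```python
def merge_tta_logits(logits_list):
    """Average class logits over augmented views, then take the
    per-pixel argmax, as merge_preds does for each data sample."""
    n_views = len(logits_list)
    num_classes = len(logits_list[0])
    num_pixels = len(logits_list[0][0])
    # mean over views for every (class, pixel) entry
    mean = [[sum(view[c][i] for view in logits_list) / n_views
             for i in range(num_pixels)] for c in range(num_classes)]
    # per-pixel argmax over classes
    return [max(range(num_classes), key=lambda c: mean[c][i])
            for i in range(num_pixels)]
```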
diff --git a/mmseg/models/text_encoder/__init__.py b/mmseg/models/text_encoder/__init__.py
new file mode 100644
index 00000000000..199856d9d79
--- /dev/null
+++ b/mmseg/models/text_encoder/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .clip_text_encoder import CLIPTextEncoder
+
+__all__ = ['CLIPTextEncoder']
diff --git a/mmseg/models/text_encoder/clip_text_encoder.py b/mmseg/models/text_encoder/clip_text_encoder.py
new file mode 100644
index 00000000000..1a18b86395e
--- /dev/null
+++ b/mmseg/models/text_encoder/clip_text_encoder.py
@@ -0,0 +1,229 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.transformer import BaseTransformerLayer
+from mmengine.model import BaseModule, ModuleList
+from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
+from torch.nn import functional as F
+
+from mmseg.registry import MODELS
+from mmseg.utils import get_classes, get_predefined_templates, tokenizer
+
+
+@MODELS.register_module()
+class CLIPTextEncoder(BaseModule):
+ """A text encoder with transformer architecture to encode the label text.
+
+ Modified from https://github.com/MendelXu/SAN/blob/main/san/model/clip_utils/classifier.py # noqa:E501
+ Copyright (c) 2023 MendelXu.
+ Licensed under the MIT License
+
+ Args:
+        dataset_name (str|None): The name of the dataset to which
+            the data belongs. Default: None.
+        vocabulary (List[str]|None): The list of class names. Default: None.
+        templates (List[str]|str): The prompt templates used for labels, or
+            the name of a predefined template set. Default: 'vild'.
+        total_vocab_size (int): Number of all words used by the pre-trained
+            model. Default: 49408 (CLIP).
+        context_length (int): The max length of prompt text.
+            Default: 77 (CLIP).
+        embed_dims (int): Width of transformer model. Default: 512.
+        num_layers (int): Depth of transformer. Default: 12.
+        num_heads (int): Number of attention heads in transformer.
+            Default: 8.
+        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim in
+            transformer. Default: 4.
+        output_dims (int): Dim of output text embeddings. Default: 512.
+        cache_feature (bool): Whether to save class embeddings in cache.
+            Default: True.
+        cat_bg (bool): Whether to add background embedding. Default: True.
+        norm_cfg (dict|None): Config for norm layer. Default: dict(type='LN').
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
+ """
+
+ def __init__(self,
+ dataset_name: str = None,
+ vocabulary: List[str] = None,
+ templates: str = 'vild',
+ total_vocab_size: int = 49408,
+ context_length: int = 77,
+ embed_dims: int = 512,
+ num_layers: int = 12,
+ num_heads: int = 8,
+ mlp_ratio: int = 4,
+ output_dims: int = 512,
+ cache_feature: bool = True,
+ cat_bg: bool = True,
+ norm_cfg: dict = dict(type='LN'),
+ init_cfg: dict = None):
+ super().__init__(init_cfg)
+        if isinstance(templates, list):
+ self.templates = templates
+ else:
+ self.templates = get_predefined_templates(templates)
+
+ assert dataset_name is not None or vocabulary is not None, \
+ "text_encoder required either 'dataset_name' or 'vocabulary'"
+ assert dataset_name is None or vocabulary is None, \
+ "there is conflict between 'dataset_name' and 'vocabulary'"
+ self.dataset_name = dataset_name
+ self.vocabulary = vocabulary
+ self.num_pos = context_length
+ self.token_embedding = nn.Embedding(total_vocab_size, embed_dims)
+ self.positional_embedding = nn.Parameter(
+ torch.empty(context_length, embed_dims))
+ self.text_projection = nn.Parameter(
+ torch.empty(embed_dims, output_dims))
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+ self.transformer = ModuleList()
+ self.register_buffer(
+ 'attn_mask', self.build_attention_mask(), persistent=False)
+ for i in range(num_layers):
+ self.transformer.append(
+ BaseTransformerLayer(
+ attn_cfgs=dict(
+ type='MultiheadAttention',
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ batch_first=False,
+ bias=True),
+ ffn_cfgs=dict(
+ type='FFN',
+ embed_dims=embed_dims,
+ feedforward_channels=mlp_ratio * embed_dims,
+ act_cfg=dict(type='QuickGELU')),
+ operation_order=('norm', 'self_attn', 'norm', 'ffn')))
+ self.ln_final = build_norm_layer(
+ norm_cfg, embed_dims, postfix='_final')[1]
+
+ self.cache_feature = cache_feature
+ if self.cache_feature:
+ self.cache = {}
+
+ self._freeze()
+
+ self.cat_bg = cat_bg
+ if self.cat_bg:
+ self.bg_embed = nn.Parameter(
+ torch.randn(1, self.text_projection.shape[1]))
+
+ def build_attention_mask(self):
+ """lazily create causal attention mask, with full attention between the
+ tokens.
+
+ pytorch uses additive attention mask; fill with -inf
+ """
+ mask = torch.empty(self.num_pos, self.num_pos)
+ mask.fill_(float('-inf'))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ def _freeze(self):
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def init_weights(self):
+ if self.cat_bg:
+ nn.init.normal_(
+ self.bg_embed,
+ std=self.bg_embed.shape[1]**-0.5,
+ )
+ if isinstance(self.init_cfg, dict) and \
+ self.init_cfg.get('type') == 'Pretrained_Part':
+ checkpoint = CheckpointLoader.load_checkpoint(
+ self.init_cfg['checkpoint'], logger=None, map_location='cpu')
+
+ state_dict = checkpoint.copy()
+ para_prefix = 'text_encoder'
+ prefix_len = len(para_prefix) + 1
+ for k, v in checkpoint.items():
+ state_dict.pop(k)
+ if para_prefix in k:
+ state_dict[k[prefix_len:]] = v
+
+ load_state_dict(self, state_dict, strict=False, logger=None)
+
+ else:
+ super().init_weights()
+
+ @torch.no_grad()
+ def encode_text(self, text, normalize=False):
+ """encode class token."""
+
+ embed_device = self.token_embedding.weight.device
+ x = self.token_embedding(
+ text.to(embed_device)) # [batch_size, n_ctx, d_model]
+ x = x + self.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ for block in self.transformer:
+ x = block(query=x, attn_masks=self.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding
+ # (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]),
+ text.argmax(dim=-1)] @ self.text_projection
+ return F.normalize(x, dim=-1) if normalize else x
+
+ def template_encode(self, vocabulary):
+ """Prompt engineering."""
+ text_embed_bucket = []
+ for template in self.templates:
+ text_inputs = tokenizer.tokenize(
+ [template.format(noun) for noun in vocabulary])
+ text_embed = self.encode_text(text_inputs, normalize=True)
+ text_embed_bucket.append(text_embed)
+ text_embed = torch.stack(text_embed_bucket).mean(dim=0)
+ text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
+ return text_embed
+
+ def forward(self):
+ """Forward function."""
+ if self.dataset_name is None: # encoding vocabulary directly
+ class_names = self.vocabulary
+ if self.cache_feature:
+ new_classes = [
+ word for word in class_names if word not in self.cache
+ ]
+ if len(new_classes) > 0:
+ class_embeds = self.template_encode(new_classes)
+ self.cache.update(dict(zip(new_classes, class_embeds)))
+ class_embeds = torch.stack(
+ [self.cache[word] for word in class_names])
+ else:
+ class_embeds = self.template_encode(class_names)
+
+ else: # encoding the classes of the dataset
+ class_names = get_classes(self.dataset_name)
+ if class_names[0] == 'background':
+ class_names = class_names[1:]
+ if self.cache_feature:
+ if self.dataset_name not in self.cache:
+ class_embeds = self.template_encode(class_names)
+ self.cache[self.dataset_name] = class_embeds
+ else:
+ class_embeds = self.cache[self.dataset_name]
+ else:
+ class_embeds = self.template_encode(class_names)
+
+ if self.cat_bg:
+ class_embeds = torch.cat([class_embeds, self.bg_embed])
+ class_embeds = F.normalize(class_embeds, p=2, dim=-1)
+ return self.logit_scale.exp() * class_embeds
+
+
+@MODELS.register_module()
+class QuickGELU(nn.Module):
+ # From https://github.com/openai/CLIP/blob/main/clip/model.py
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
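The causal mask built by `build_attention_mask` in the encoder above is a standard additive mask: zero on and below the diagonal, `-inf` strictly above it, so token `i` attends only to positions `j <= i`. A pure-Python sketch of the same construction (a hypothetical stand-in for the tensor version):

```python
def build_causal_mask(num_pos):
    """Additive causal attention mask: 0.0 where attention is allowed
    (j <= i), -inf where a token would attend to a later position."""
    neg_inf = float('-inf')
    return [[0.0 if j <= i else neg_inf for j in range(num_pos)]
            for i in range(num_pos)]
```

Row `i` of this mask is added to the attention logits for query position `i`, so softmax assigns zero weight to the masked positions.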
diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py
index fc142f16fc9..c0751b17c02 100644
--- a/mmseg/models/utils/__init__.py
+++ b/mmseg/models/utils/__init__.py
@@ -4,6 +4,7 @@
from .encoding import Encoding
from .inverted_residual import InvertedResidual, InvertedResidualV3
from .make_divisible import make_divisible
+from .point_sample import get_uncertain_point_coords_with_randomness
from .ppm import DAPPM, PAPPM
from .res_layer import ResLayer
from .se_layer import SELayer
@@ -11,11 +12,16 @@
from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc,
nlc_to_nchw)
from .up_conv_block import UpConvBlock
+
+# isort: off
from .wrappers import Upsample, resize
+from .san_layers import MLP, LayerNorm2d, cross_attn_layer
__all__ = [
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed',
'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'Encoding',
- 'Upsample', 'resize', 'DAPPM', 'PAPPM', 'BasicBlock', 'Bottleneck'
+ 'Upsample', 'resize', 'DAPPM', 'PAPPM', 'BasicBlock', 'Bottleneck',
+ 'cross_attn_layer', 'LayerNorm2d', 'MLP',
+ 'get_uncertain_point_coords_with_randomness'
]
diff --git a/mmseg/models/utils/point_sample.py b/mmseg/models/utils/point_sample.py
new file mode 100644
index 00000000000..1afc957f3da
--- /dev/null
+++ b/mmseg/models/utils/point_sample.py
@@ -0,0 +1,88 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.ops import point_sample
+from torch import Tensor
+
+
+def get_uncertainty(mask_preds: Tensor, labels: Tensor) -> Tensor:
+ """Estimate uncertainty based on pred logits.
+
+ We estimate uncertainty as L1 distance between 0.0 and the logits
+ prediction in 'mask_preds' for the foreground class in `classes`.
+
+ Args:
+        mask_preds (Tensor): mask prediction logits, shape (num_rois,
+ num_classes, mask_height, mask_width).
+
+ labels (Tensor): Either predicted or ground truth label for
+ each predicted mask, of length num_rois.
+
+ Returns:
+ scores (Tensor): Uncertainty scores with the most uncertain
+ locations having the highest uncertainty score,
+ shape (num_rois, 1, mask_height, mask_width)
+ """
+ if mask_preds.shape[1] == 1:
+ gt_class_logits = mask_preds.clone()
+ else:
+ inds = torch.arange(mask_preds.shape[0], device=mask_preds.device)
+ gt_class_logits = mask_preds[inds, labels].unsqueeze(1)
+ return -torch.abs(gt_class_logits)
+
+
+def get_uncertain_point_coords_with_randomness(
+ mask_preds: Tensor, labels: Tensor, num_points: int,
+ oversample_ratio: float, importance_sample_ratio: float) -> Tensor:
+ """Get ``num_points`` most uncertain points with random points during
+ train.
+
+ Sample points in [0, 1] x [0, 1] coordinate space based on their
+ uncertainty. The uncertainties are calculated for each point using
+ 'get_uncertainty()' function that takes point's logit prediction as
+ input.
+
+ Args:
+ mask_preds (Tensor): A tensor of shape (num_rois, num_classes,
+ mask_height, mask_width) for class-specific or class-agnostic
+ prediction.
+ labels (Tensor): The ground truth class for each instance.
+ num_points (int): The number of points to sample.
+ oversample_ratio (float): Oversampling parameter.
+        importance_sample_ratio (float): Ratio of points that are sampled
+            via importance sampling.
+
+ Returns:
+        point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
+            that contains the coordinates of the sampled points.
+ """
+ assert oversample_ratio >= 1
+ assert 0 <= importance_sample_ratio <= 1
+ batch_size = mask_preds.shape[0]
+ num_sampled = int(num_points * oversample_ratio)
+ point_coords = torch.rand(
+ batch_size, num_sampled, 2, device=mask_preds.device)
+ point_logits = point_sample(mask_preds, point_coords)
+ # It is crucial to calculate uncertainty based on the sampled
+ # prediction value for the points. Calculating uncertainties of the
+ # coarse predictions first and sampling them for points leads to
+ # incorrect results. To illustrate this: assume uncertainty func(
+ # logits)=-abs(logits), a sampled point between two coarse
+ # predictions with -1 and 1 logits has 0 logits, and therefore 0
+ # uncertainty value. However, if we calculate uncertainties for the
+ # coarse predictions first, both will have -1 uncertainty,
+ # and sampled point will get -1 uncertainty.
+ point_uncertainties = get_uncertainty(point_logits, labels)
+ num_uncertain_points = int(importance_sample_ratio * num_points)
+ num_random_points = num_points - num_uncertain_points
+ idx = torch.topk(
+ point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
+ shift = num_sampled * torch.arange(
+ batch_size, dtype=torch.long, device=mask_preds.device)
+ idx += shift[:, None]
+ point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
+ batch_size, num_uncertain_points, 2)
+ if num_random_points > 0:
+ rand_roi_coords = torch.rand(
+ batch_size, num_random_points, 2, device=mask_preds.device)
+ point_coords = torch.cat((point_coords, rand_roi_coords), dim=1)
+ return point_coords
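The sampling split in `get_uncertain_point_coords_with_randomness` above has two knobs: `oversample_ratio` controls how many candidates are drawn, and `importance_sample_ratio` decides how many of the final points come from the most uncertain candidates versus fresh random draws. A minimal sketch of that split with scalar logits and the same `-|logit|` uncertainty (hypothetical helper, plain Python):

```python
def split_uncertain_points(cand_logits, num_points, importance_sample_ratio):
    """Return the indices of the most uncertain candidates plus how many
    extra purely random points to draw. Uncertainty is -|logit|, so the
    most uncertain candidates have logits closest to zero."""
    num_uncertain = int(importance_sample_ratio * num_points)
    num_random = num_points - num_uncertain
    # top-k by uncertainty == the k smallest |logit| values
    order = sorted(range(len(cand_logits)),
                   key=lambda i: abs(cand_logits[i]))
    return order[:num_uncertain], num_random
```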
diff --git a/mmseg/models/utils/san_layers.py b/mmseg/models/utils/san_layers.py
new file mode 100644
index 00000000000..2267686daf6
--- /dev/null
+++ b/mmseg/models/utils/san_layers.py
@@ -0,0 +1,418 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from https://github.com/MendelXu/SAN/blob/main/san/model/attn_helper.py # noqa: E501
+# Copyright (c) 2023 MendelXu.
+# Licensed under the MIT License
+
+import warnings
+from typing import Optional
+
+import torch
+from mmcv.cnn.bricks.transformer import BaseTransformerLayer
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+
+def cross_attn_with_self_bias(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ embed_dim_to_check: int,
+ num_heads: int,
+ in_proj_weight: Tensor,
+ in_proj_bias: Tensor,
+ bias_k: Optional[Tensor],
+ bias_v: Optional[Tensor],
+ add_zero_attn: bool,
+ dropout_p: float,
+ out_proj_weight: Tensor,
+ out_proj_bias: Tensor,
+ training: bool = True,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ use_separate_proj_weight: bool = False,
+ q_proj_weight: Optional[Tensor] = None,
+ k_proj_weight: Optional[Tensor] = None,
+ v_proj_weight: Optional[Tensor] = None,
+ static_k: Optional[Tensor] = None,
+ static_v: Optional[Tensor] = None,
+):
+ """Forward function of multi-head attention. Modified from
+ multi_head_attention_forward in
+ https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py.
+
+ Args:
+ query, key, value: map a query and a set of key-value pairs to an output.
+ See "Attention Is All You Need" for more details.
+ embed_dim_to_check: total dimension of the model.
+ num_heads: parallel attention heads.
+ in_proj_weight, in_proj_bias: input projection weight and bias.
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
+ add_zero_attn: add a new batch of zeros to the key and
+ value sequences at dim=1.
+ dropout_p: probability of an element to be zeroed.
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
+ training: apply dropout if is ``True``.
+        key_padding_mask: if provided, specified padding elements in the key
+            will be ignored by the attention. This is a binary mask. When the
+            value is True, the corresponding value on the attention layer will
+            be filled with -inf.
+        need_weights: output attn_output_weights. Default: `True`.
+            Note: `need_weights` defaults to `True`, but should be set to
+            `False` for best performance when attention weights are not
+            needed. *Setting need_weights to `True` leads to a significant
+            performance degradation.*
+ attn_mask: 2D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+ use_separate_proj_weight: the function accept the proj. weights for query, key,
+ and value in different forms. If false, in_proj_weight will be used, which is
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
+ static_k, static_v: static key and value used for attention operators.
+ """ # noqa: E501
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == embed_dim_to_check
+ # allow MHA to have different sizes for the feature dimension
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+
+ head_dim = embed_dim // num_heads
+ assert head_dim * num_heads == embed_dim, \
+ 'embed_dim must be divisible by num_heads'
+ scaling = float(head_dim)**-0.5
+
+ if not use_separate_proj_weight:
+ if (query is key or torch.equal(
+ query, key)) and (key is value or torch.equal(key, value)):
+ # self-attention
+ raise NotImplementedError('self-attention is not implemented')
+
+ elif key is value or torch.equal(key, value):
+ # encoder-decoder attention
+ # This is inline in_proj function
+ # with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = 0
+ _end = embed_dim
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ q = F.linear(query, _w, _b)
+
+ if key is None:
+ assert value is None
+ k = None
+ v = None
+ q_k = None
+ q_v = None
+ else:
+ # This is inline in_proj function with
+ # in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim
+ _end = None
+ _w = in_proj_weight[_start:, :]
+ if _b is not None:
+ _b = _b[_start:]
+ k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
+ q_k, q_v = F.linear(query, _w, _b).chunk(2, dim=-1)
+ else:
+ # This is inline in_proj function with
+ # in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = 0
+ _end = embed_dim
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ q = F.linear(query, _w, _b)
+
+ # This is inline in_proj function with
+ # in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim
+ _end = embed_dim * 2
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ k = F.linear(key, _w, _b)
+ q_k = F.linear(query, _w, _b)
+ # This is inline in_proj function with
+ # in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim * 2
+ _end = None
+ _w = in_proj_weight[_start:, :]
+ if _b is not None:
+ _b = _b[_start:]
+ v = F.linear(value, _w, _b)
+ q_v = F.linear(query, _w, _b)
+ else:
+ q_proj_weight_non_opt = \
+ torch.jit._unwrap_optional(q_proj_weight)
+ len1, len2 = q_proj_weight_non_opt.size()
+ assert len1 == embed_dim and len2 == query.size(-1)
+
+ k_proj_weight_non_opt = \
+ torch.jit._unwrap_optional(k_proj_weight)
+ len1, len2 = k_proj_weight_non_opt.size()
+ assert len1 == embed_dim and len2 == key.size(-1)
+
+ v_proj_weight_non_opt = \
+ torch.jit._unwrap_optional(v_proj_weight)
+ len1, len2 = v_proj_weight_non_opt.size()
+ assert len1 == embed_dim and len2 == value.size(-1)
+
+ if in_proj_bias is not None:
+ q = F.linear(query, q_proj_weight_non_opt,
+ in_proj_bias[0:embed_dim])
+ k = F.linear(key, k_proj_weight_non_opt,
+ in_proj_bias[embed_dim:(embed_dim * 2)])
+ v = F.linear(value, v_proj_weight_non_opt,
+ in_proj_bias[(embed_dim * 2):])
+ else:
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
+ k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
+ q = q * scaling
+
+ if attn_mask is not None:
+ assert (
+ attn_mask.dtype == torch.float32
+ or attn_mask.dtype == torch.float64
+ or attn_mask.dtype == torch.float16
+ or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool
+ ), 'Only float, byte, and bool types are supported for ' \
+ 'attn_mask, not {}'.format(attn_mask.dtype)
+ if attn_mask.dtype == torch.uint8:
+ warnings.warn('Byte tensor for attn_mask in nn.MultiheadAttention '
+ 'is deprecated. Use bool tensor instead.')
+ attn_mask = attn_mask.to(torch.bool)
+
+ if attn_mask.dim() == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
+ raise RuntimeError(
+ 'The size of the 2D attn_mask is not correct.')
+ elif attn_mask.dim() == 3:
+ if list(attn_mask.size()) != [
+ bsz * num_heads,
+ query.size(0), key.size(0)
+ ]:
+ raise RuntimeError(
+ 'The size of the 3D attn_mask is not correct.')
+ else:
+ raise RuntimeError(
+ "attn_mask's dimension {} is not supported".format(
+ attn_mask.dim()))
+ # attn_mask's dim is 3 now.
+
+ # convert ByteTensor key_padding_mask to bool
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+ warnings.warn(
+ 'Byte tensor for key_padding_mask in nn.MultiheadAttention '
+ 'is deprecated. Use bool tensor instead.')
+ key_padding_mask = key_padding_mask.to(torch.bool)
+
+ if bias_k is not None and bias_v is not None:
+ if static_k is None and static_v is None:
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = F.pad(attn_mask, (0, 1))
+ if key_padding_mask is not None:
+ key_padding_mask = F.pad(key_padding_mask, (0, 1))
+ else:
+ assert static_k is None, 'bias cannot be added to static key.'
+ assert static_v is None, 'bias cannot be added to static value.'
+ else:
+ assert bias_k is None
+ assert bias_v is None
+
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+ q_k = q_k.contiguous().view(tgt_len, bsz * num_heads,
+ head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+ q_v = q_v.contiguous().view(tgt_len, bsz * num_heads,
+ head_dim).transpose(0, 1)
+
+ if static_k is not None:
+ assert static_k.size(0) == bsz * num_heads
+ assert static_k.size(2) == head_dim
+ k = static_k
+
+ if static_v is not None:
+ assert static_v.size(0) == bsz * num_heads
+ assert static_v.size(2) == head_dim
+ v = static_v
+
+ src_len = k.size(1)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if add_zero_attn:
+ src_len += 1
+ k = torch.cat(
+ [
+ k,
+ torch.zeros(
+ (k.size(0), 1) + k.size()[2:],
+ dtype=k.dtype,
+ device=k.device),
+ ],
+ dim=1,
+ )
+ v = torch.cat(
+ [
+ v,
+ torch.zeros(
+ (v.size(0), 1) + v.size()[2:],
+ dtype=v.dtype,
+ device=v.device),
+ ],
+ dim=1,
+ )
+ if attn_mask is not None:
+ attn_mask = F.pad(attn_mask, (0, 1))
+ if key_padding_mask is not None:
+ key_padding_mask = F.pad(key_padding_mask, (0, 1))
+
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
+ assert list(
+ attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_output_weights.masked_fill_(attn_mask, float('-inf'))
+ else:
+ attn_output_weights += attn_mask
+
+ if key_padding_mask is not None:
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len,
+ src_len)
+ attn_output_weights = attn_output_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ float('-inf'),
+ )
+ attn_output_weights = attn_output_weights.view(bsz * num_heads,
+ tgt_len, src_len)
+ # attn_out_weights: [bsz * num_heads, tgt_len, src_len]
+ # ->[bsz * num_heads, tgt_len, src_len+1]
+ self_weight = (q * q_k).sum(
+ dim=-1, keepdim=True) # [bsz * num_heads, tgt_len, 1]
+ total_attn_output_weights = torch.cat([attn_output_weights, self_weight],
+ dim=-1)
+ total_attn_output_weights = F.softmax(total_attn_output_weights, dim=-1)
+ total_attn_output_weights = F.dropout(
+ total_attn_output_weights, p=dropout_p, training=training)
+ attn_output_weights = \
+ total_attn_output_weights[:, :, : -1]
+ # [bsz * num_heads, tgt_len, src_len]
+ self_weight = \
+ total_attn_output_weights[:, :, -1:] # [bsz * num_heads, tgt_len, 1]
+
+ attn_output = torch.bmm(attn_output_weights,
+ v) # [bsz * num_heads, tgt_len, head_dim]
+ attn_output = (attn_output + self_weight * q_v
+ ) # [bsz * num_heads, tgt_len, head_dim]
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+ attn_output = attn_output.transpose(0, 1).contiguous().view(
+ tgt_len, bsz, embed_dim)
+ attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
+
+ if need_weights:
+ # average attention weights over heads
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len,
+ src_len)
+ return attn_output, attn_output_weights # .sum(dim=1) / num_heads
+ else:
+ return attn_output, None
+
+
+def cross_attn_layer(tf_layer: BaseTransformerLayer, x, mem, attn_bias):
+    """Implementation of a transformer layer with cross attention. The cross
+    attention shares the embedding weights with the self-attention of
+    tf_layer.
+
+    Args:
+        tf_layer (BaseTransformerLayer): The module of the transformer layer.
+        x (Tensor): query [K,N,C]
+        mem (Tensor): key and value [L,N,C]
+        attn_bias (Tensor): attention bias [N*num_head,K,L]
+
+    Returns:
+        x (Tensor): cross attention output [K,N,C]
+    """
+ self_attn_layer = tf_layer.attentions[0].attn
+ attn_layer_paras = {
+ 'embed_dim_to_check': self_attn_layer.embed_dim,
+ 'num_heads': self_attn_layer.num_heads,
+ 'in_proj_weight': self_attn_layer.in_proj_weight,
+ 'in_proj_bias': self_attn_layer.in_proj_bias,
+ 'bias_k': self_attn_layer.bias_k,
+ 'bias_v': self_attn_layer.bias_v,
+ 'add_zero_attn': self_attn_layer.add_zero_attn,
+ 'dropout_p': self_attn_layer.dropout,
+ 'out_proj_weight': self_attn_layer.out_proj.weight,
+ 'out_proj_bias': self_attn_layer.out_proj.bias,
+ 'training': self_attn_layer.training
+ }
+
+ q_x = tf_layer.norms[0](x)
+ k_x = v_x = tf_layer.norms[0](mem)
+ x = x + cross_attn_with_self_bias(
+ q_x,
+ k_x,
+ v_x,
+ attn_mask=attn_bias,
+ need_weights=False,
+ **attn_layer_paras)[0]
+ x = tf_layer.ffns[0](tf_layer.norms[1](x), identity=x)
+ return x
+
+
+class LayerNorm2d(nn.Module):
+ """A LayerNorm variant, popularized by Transformers, that performs point-
+ wise mean and variance normalization over the channel dimension for inputs
+ that have shape (batch_size, channels, height, width).
+
+ https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
+ """
+
+ def __init__(self, normalized_shape, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.normalized_shape = (normalized_shape, )
+
+ def forward(self, x: torch.Tensor):
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
+class MLP(nn.Module):
+ """Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self,
+ input_dim,
+ hidden_dim,
+ output_dim,
+ num_layers,
+ affine_func=nn.Linear):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ affine_func(n, k)
+ for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x: torch.Tensor):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
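The `LayerNorm2d` variant above normalizes each spatial position over the channel dimension of an NCHW tensor. A minimal NumPy sketch of the same arithmetic (identity affine parameters; shapes are illustrative, not from this PR) shows each per-position channel vector coming out zero-mean and roughly unit-variance:

```python
import numpy as np

def layer_norm_2d(x, weight, bias, eps=1e-6):
    # x: (N, C, H, W); normalize each (n, :, h, w) vector over channels,
    # then apply per-channel affine parameters, as LayerNorm2d.forward does.
    u = x.mean(axis=1, keepdims=True)
    s = ((x - u) ** 2).mean(axis=1, keepdims=True)
    x = (x - u) / np.sqrt(s + eps)
    return weight[None, :, None, None] * x + bias[None, :, None, None]

x = np.random.default_rng(0).normal(size=(2, 8, 4, 4))
y = layer_norm_2d(x, np.ones(8), np.zeros(8))
```

This is the "channels-first LayerNorm" trick from ConvNeXt: unlike BatchNorm, the statistics depend only on the single input sample, so train and eval behave identically.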
diff --git a/mmseg/registry/registry.py b/mmseg/registry/registry.py
index 32684e758f9..37b6a776095 100644
--- a/mmseg/registry/registry.py
+++ b/mmseg/registry/registry.py
@@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
-"""MMSegmentation provides 17 registry nodes to support using modules across
+"""MMSegmentation provides 21 registry nodes to support using modules across
projects. Each node is a child of the root registry in MMEngine.
More details can be found at
-https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
+https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html.
"""
from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
@@ -46,10 +46,7 @@
# manage data-related modules
DATASETS = Registry(
'dataset', parent=MMENGINE_DATASETS, locations=['mmseg.datasets'])
-DATA_SAMPLERS = Registry(
- 'data sampler',
- parent=MMENGINE_DATA_SAMPLERS,
- locations=['mmseg.datasets.samplers'])
+DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
TRANSFORMS = Registry(
'transform',
parent=MMENGINE_TRANSFORMS,
diff --git a/mmseg/utils/__init__.py b/mmseg/utils/__init__.py
index cb1436c1980..0a2af58c6e0 100644
--- a/mmseg/utils/__init__.py
+++ b/mmseg/utils/__init__.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
-from .class_names import (ade_classes, ade_palette, cityscapes_classes,
+from .class_names import (ade_classes, ade_palette, bdd100k_classes,
+ bdd100k_palette, cityscapes_classes,
cityscapes_palette, cocostuff_classes,
cocostuff_palette, dataset_aliases, get_classes,
get_palette, isaid_classes, isaid_palette,
@@ -10,22 +11,60 @@
vaihingen_palette, voc_classes, voc_palette)
# yapf: enable
from .collect_env import collect_env
+from .get_templates import get_predefined_templates
from .io import datafrombytes
from .misc import add_prefix, stack_batch
from .set_env import register_all_modules
+from .tokenizer import tokenize
from .typing_utils import (ConfigType, ForwardResults, MultiConfig,
OptConfigType, OptMultiConfig, OptSampleList,
SampleList, TensorDict, TensorList)
+# isort: off
+from .mask_classification import MatchMasks, seg_data_to_instance_data
+
__all__ = [
- 'collect_env', 'register_all_modules', 'stack_batch', 'add_prefix',
- 'ConfigType', 'OptConfigType', 'MultiConfig', 'OptMultiConfig',
- 'SampleList', 'OptSampleList', 'TensorDict', 'TensorList',
- 'ForwardResults', 'cityscapes_classes', 'ade_classes', 'voc_classes',
- 'cocostuff_classes', 'loveda_classes', 'potsdam_classes',
- 'vaihingen_classes', 'isaid_classes', 'stare_classes',
- 'cityscapes_palette', 'ade_palette', 'voc_palette', 'cocostuff_palette',
- 'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette',
- 'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette',
- 'datafrombytes', 'synapse_palette', 'synapse_classes'
+ 'collect_env',
+ 'register_all_modules',
+ 'stack_batch',
+ 'add_prefix',
+ 'ConfigType',
+ 'OptConfigType',
+ 'MultiConfig',
+ 'OptMultiConfig',
+ 'SampleList',
+ 'OptSampleList',
+ 'TensorDict',
+ 'TensorList',
+ 'ForwardResults',
+ 'cityscapes_classes',
+ 'ade_classes',
+ 'voc_classes',
+ 'cocostuff_classes',
+ 'loveda_classes',
+ 'potsdam_classes',
+ 'vaihingen_classes',
+ 'isaid_classes',
+ 'stare_classes',
+ 'cityscapes_palette',
+ 'ade_palette',
+ 'voc_palette',
+ 'cocostuff_palette',
+ 'loveda_palette',
+ 'potsdam_palette',
+ 'vaihingen_palette',
+ 'isaid_palette',
+ 'stare_palette',
+ 'dataset_aliases',
+ 'get_classes',
+ 'get_palette',
+ 'datafrombytes',
+ 'synapse_palette',
+ 'synapse_classes',
+ 'get_predefined_templates',
+ 'tokenize',
+ 'seg_data_to_instance_data',
+ 'MatchMasks',
+ 'bdd100k_classes',
+ 'bdd100k_palette',
]
diff --git a/mmseg/utils/bpe_simple_vocab_16e6.txt.gz b/mmseg/utils/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 00000000000..7b5088a527f
Binary files /dev/null and b/mmseg/utils/bpe_simple_vocab_16e6.txt.gz differ
diff --git a/mmseg/utils/class_names.py b/mmseg/utils/class_names.py
index 961a08520d2..5ab35f99dca 100644
--- a/mmseg/utils/class_names.py
+++ b/mmseg/utils/class_names.py
@@ -52,6 +52,21 @@ def voc_classes():
]
+def pcontext_classes():
+ """Pascal Context class names for external use."""
+ return [
+ 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird',
+ 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet', 'car', 'cat',
+ 'ceiling', 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain',
+ 'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground',
+ 'horse', 'keyboard', 'light', 'motorbike', 'mountain', 'mouse',
+ 'person', 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
+ 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 'track',
+ 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window',
+ 'wood'
+ ]
+
+
def cocostuff_classes():
"""CocoStuff class names for external use."""
return [
@@ -306,6 +321,25 @@ def voc_palette():
[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
+def pcontext_palette():
+ """Pascal Context palette for external use."""
+ return [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
+ [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
+ [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
+ [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
+ [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
+ [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
+ [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
+ [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
+ [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
+ [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
+ [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
+ [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
+ [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
+ [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
+ [0, 235, 255], [0, 173, 255], [31, 0, 255]]
+
+
def cocostuff_palette():
"""CocoStuff palette for external use."""
return [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
@@ -419,10 +453,31 @@ def lip_palette():
]
+def bdd100k_classes():
+    """BDD100K class names for external use (the class names are compatible
+    with Cityscapes)."""
+ return [
+ 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+ 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
+ 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+ 'bicycle'
+ ]
+
+
+def bdd100k_palette():
+    """BDD100K palette for external use (same as Cityscapes)."""
+ return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+ [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
+ [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
+ [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
+ [0, 0, 230], [119, 11, 32]]
+
+
dataset_aliases = {
'cityscapes': ['cityscapes'],
'ade': ['ade', 'ade20k'],
'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'],
+ 'pcontext': ['pcontext', 'pascal_context', 'voc2010'],
'loveda': ['loveda'],
'potsdam': ['potsdam'],
'vaihingen': ['vaihingen'],
@@ -435,7 +490,8 @@ def lip_palette():
'stare': ['stare', 'STARE'],
'lip': ['LIP', 'lip'],
'mapillary_v1': ['mapillary_v1'],
- 'mapillary_v2': ['mapillary_v2']
+ 'mapillary_v2': ['mapillary_v2'],
+ 'bdd100k': ['bdd100k']
}
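The `dataset_aliases` table above lets helpers such as `get_classes`/`get_palette` accept any registered spelling of a dataset name. A trimmed, stand-alone sketch of that lookup (only two entries copied here; `resolve_alias` is an illustrative helper, not mmseg API):

```python
# Trimmed copy of the alias table for illustration.
dataset_aliases = {
    'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'],
    'bdd100k': ['bdd100k'],
}

def resolve_alias(dataset):
    # Map any registered alias to its canonical dataset key.
    for canonical, aliases in dataset_aliases.items():
        if dataset in aliases:
            return canonical
    raise KeyError(f'Unrecognized dataset: {dataset}')

canonical = resolve_alias('pascal_voc')
```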
diff --git a/mmseg/utils/get_templates.py b/mmseg/utils/get_templates.py
new file mode 100644
index 00000000000..7e9032ba96c
--- /dev/null
+++ b/mmseg/utils/get_templates.py
@@ -0,0 +1,109 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+PREDEFINED_TEMPLATES = {
+ 'imagenet': [
+ 'a bad photo of a {}.',
+ 'a photo of many {}.',
+ 'a sculpture of a {}.',
+ 'a photo of the hard to see {}.',
+ 'a low resolution photo of the {}.',
+ 'a rendering of a {}.',
+ 'graffiti of a {}.',
+ 'a bad photo of the {}.',
+ 'a cropped photo of the {}.',
+ 'a tattoo of a {}.',
+ 'the embroidered {}.',
+ 'a photo of a hard to see {}.',
+ 'a bright photo of a {}.',
+ 'a photo of a clean {}.',
+ 'a photo of a dirty {}.',
+ 'a dark photo of the {}.',
+ 'a drawing of a {}.',
+ 'a photo of my {}.',
+ 'the plastic {}.',
+ 'a photo of the cool {}.',
+ 'a close-up photo of a {}.',
+ 'a black and white photo of the {}.',
+ 'a painting of the {}.',
+ 'a painting of a {}.',
+ 'a pixelated photo of the {}.',
+ 'a sculpture of the {}.',
+ 'a bright photo of the {}.',
+ 'a cropped photo of a {}.',
+ 'a plastic {}.',
+ 'a photo of the dirty {}.',
+ 'a jpeg corrupted photo of a {}.',
+ 'a blurry photo of the {}.',
+ 'a photo of the {}.',
+ 'a good photo of the {}.',
+ 'a rendering of the {}.',
+ 'a {} in a video game.',
+ 'a photo of one {}.',
+ 'a doodle of a {}.',
+ 'a close-up photo of the {}.',
+ 'a photo of a {}.',
+ 'the origami {}.',
+ 'the {} in a video game.',
+ 'a sketch of a {}.',
+ 'a doodle of the {}.',
+ 'a origami {}.',
+ 'a low resolution photo of a {}.',
+ 'the toy {}.',
+ 'a rendition of the {}.',
+ 'a photo of the clean {}.',
+ 'a photo of a large {}.',
+ 'a rendition of a {}.',
+ 'a photo of a nice {}.',
+ 'a photo of a weird {}.',
+ 'a blurry photo of a {}.',
+ 'a cartoon {}.',
+ 'art of a {}.',
+ 'a sketch of the {}.',
+ 'a embroidered {}.',
+ 'a pixelated photo of a {}.',
+ 'itap of the {}.',
+ 'a jpeg corrupted photo of the {}.',
+ 'a good photo of a {}.',
+ 'a plushie {}.',
+ 'a photo of the nice {}.',
+ 'a photo of the small {}.',
+ 'a photo of the weird {}.',
+ 'the cartoon {}.',
+ 'art of the {}.',
+ 'a drawing of the {}.',
+ 'a photo of the large {}.',
+ 'a black and white photo of a {}.',
+ 'the plushie {}.',
+ 'a dark photo of a {}.',
+ 'itap of a {}.',
+ 'graffiti of the {}.',
+ 'a toy {}.',
+ 'itap of my {}.',
+ 'a photo of a cool {}.',
+ 'a photo of a small {}.',
+ 'a tattoo of the {}.',
+ ],
+ 'vild': [
+ 'a photo of a {}.',
+ 'This is a photo of a {}',
+ 'There is a {} in the scene',
+ 'There is the {} in the scene',
+ 'a photo of a {} in the scene',
+ 'a photo of a small {}.',
+ 'a photo of a medium {}.',
+ 'a photo of a large {}.',
+ 'This is a photo of a small {}.',
+ 'This is a photo of a medium {}.',
+ 'This is a photo of a large {}.',
+ 'There is a small {} in the scene.',
+ 'There is a medium {} in the scene.',
+ 'There is a large {} in the scene.',
+ ],
+}
+
+
+def get_predefined_templates(template_set_name: str) -> List[str]:
+ if template_set_name not in PREDEFINED_TEMPLATES:
+ raise ValueError(f'Template set {template_set_name} not found')
+ return PREDEFINED_TEMPLATES[template_set_name]
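`get_predefined_templates` simply indexes the prompt table; each template string is then filled with a class name to build CLIP text prompts. A trimmed, self-contained sketch (two of the 'vild' templates copied from above):

```python
# Trimmed copy of the 'vild' prompt templates for illustration.
PREDEFINED_TEMPLATES = {
    'vild': ['a photo of a {}.', 'There is a {} in the scene'],
}

def get_predefined_templates(name):
    if name not in PREDEFINED_TEMPLATES:
        raise ValueError(f'Template set {name} not found')
    return PREDEFINED_TEMPLATES[name]

# Expand the templates for one class name.
prompts = [t.format('car') for t in get_predefined_templates('vild')]
```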
diff --git a/mmseg/utils/io.py b/mmseg/utils/io.py
index d03517401c5..7029c3cddda 100644
--- a/mmseg/utils/io.py
+++ b/mmseg/utils/io.py
@@ -3,6 +3,7 @@
import io
import pickle
+import cv2
import numpy as np
@@ -12,7 +13,7 @@ def datafrombytes(content: bytes, backend: str = 'numpy') -> np.ndarray:
Args:
content (bytes): The data bytes got from files or other streams.
backend (str): The data decoding backend type. Options are 'numpy',
- 'nifti' and 'pickle'. Defaults to 'numpy'.
+ 'nifti', 'cv2' and 'pickle'. Defaults to 'numpy'.
Returns:
numpy.ndarray: Loaded data array.
@@ -33,6 +34,9 @@ def datafrombytes(content: bytes, backend: str = 'numpy') -> np.ndarray:
data = Nifti1Image.from_bytes(data.to_bytes()).get_fdata()
elif backend == 'numpy':
data = np.load(f)
+ elif backend == 'cv2':
+ data = np.frombuffer(f.read(), dtype=np.uint8)
+ data = cv2.imdecode(data, cv2.IMREAD_UNCHANGED)
else:
raise ValueError
return data
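The `'numpy'` branch of `datafrombytes` wraps the raw bytes in a file-like object before handing them to `np.load`, while the new `'cv2'` branch decodes an image payload with `cv2.imdecode`. A dependency-light sketch of the numpy path (round-tripping an array through `.npy` bytes; the helper name is illustrative):

```python
import io

import numpy as np

def datafrombytes_numpy(content: bytes) -> np.ndarray:
    # Mirrors the 'numpy' branch above: wrap the raw bytes in a
    # file-like object and let np.load parse the .npy payload.
    with io.BytesIO(content) as f:
        return np.load(f)

buf = io.BytesIO()
np.save(buf, np.arange(6).reshape(2, 3))
arr = datafrombytes_numpy(buf.getvalue())
```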
diff --git a/mmseg/utils/mask_classification.py b/mmseg/utils/mask_classification.py
new file mode 100644
index 00000000000..205d5259754
--- /dev/null
+++ b/mmseg/utils/mask_classification.py
@@ -0,0 +1,205 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple
+
+import torch
+from mmcv.ops import point_sample
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmseg.registry import TASK_UTILS
+from mmseg.utils import ConfigType, SampleList
+
+
+def seg_data_to_instance_data(ignore_index: int,
+ batch_data_samples: SampleList):
+ """Convert the paradigm of ground truth from semantic segmentation to
+ instance segmentation.
+
+ Args:
+ ignore_index (int): The label index to be ignored.
+ batch_data_samples (List[SegDataSample]): The Data
+ Samples. It usually includes information such as
+ `gt_sem_seg`.
+
+ Returns:
+ tuple[Tensor]: A tuple contains two lists.
+ - batch_gt_instances (List[InstanceData]): Batch of
+ gt_instance. It usually includes ``labels``, each is
+ unique ground truth label id of images, with
+ shape (num_gt, ) and ``masks``, each is ground truth
+ masks of each instances of a image, shape (num_gt, h, w).
+ - batch_img_metas (List[Dict]): List of image meta information.
+ """
+ batch_gt_instances = []
+
+ for data_sample in batch_data_samples:
+ gt_sem_seg = data_sample.gt_sem_seg.data
+ classes = torch.unique(
+ gt_sem_seg,
+ sorted=False,
+ return_inverse=False,
+ return_counts=False)
+
+ # remove ignored region
+ gt_labels = classes[classes != ignore_index]
+
+ masks = []
+ for class_id in gt_labels:
+ masks.append(gt_sem_seg == class_id)
+
+ if len(masks) == 0:
+ gt_masks = torch.zeros(
+ (0, gt_sem_seg.shape[-2],
+ gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
+ else:
+ gt_masks = torch.stack(masks).squeeze(1).long()
+
+ instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
+ batch_gt_instances.append(instance_data)
+ return batch_gt_instances
+
+
+class MatchMasks:
+ """Match the predictions to category labels.
+
+ Args:
+ num_points (int): the number of sampled points to compute cost.
+ num_queries (int): the number of prediction masks.
+ num_classes (int): the number of classes.
+ assigner (BaseAssigner): the assigner to compute matching.
+ """
+
+ def __init__(self,
+ num_points: int,
+ num_queries: int,
+ num_classes: int,
+ assigner: ConfigType = None):
+        assert assigner is not None, "'assigner' in decode_head.train_cfg " \
+            'cannot be None'
+ assert num_points > 0, 'num_points should be a positive integer.'
+ self.num_points = num_points
+ self.num_queries = num_queries
+ self.num_classes = num_classes
+ self.assigner = TASK_UTILS.build(assigner)
+
+ def get_targets(self, cls_scores: List[Tensor], mask_preds: List[Tensor],
+ batch_gt_instances: List[InstanceData]) -> Tuple:
+ """Compute best mask matches for all images for a decoder layer.
+
+ Args:
+ cls_scores (List[Tensor]): Mask score logits from a single
+ decoder layer for all images. Each with shape (num_queries,
+ cls_out_channels).
+ mask_preds (List[Tensor]): Mask logits from a single decoder
+ layer for all images. Each with shape (num_queries, h, w).
+ batch_gt_instances (List[InstanceData]): each contains
+ ``labels`` and ``masks``.
+
+ Returns:
+ tuple: a tuple containing the following targets.
+
+ - labels (List[Tensor]): Labels of all images.\
+ Each with shape (num_queries, ).
+ - mask_targets (List[Tensor]): Mask targets of\
+ all images. Each with shape (num_queries, h, w).
+ - mask_weights (List[Tensor]): Mask weights of\
+ all images. Each with shape (num_queries, ).
+ - avg_factor (int): Average factor that is used to
+ average the loss. `avg_factor` is usually equal
+ to the number of positive priors.
+ """
+ batch_size = cls_scores.shape[0]
+ results = dict({
+ 'labels': [],
+ 'mask_targets': [],
+ 'mask_weights': [],
+ })
+ for i in range(batch_size):
+ labels, mask_targets, mask_weights\
+ = self._get_targets_single(cls_scores[i],
+ mask_preds[i],
+ batch_gt_instances[i])
+ results['labels'].append(labels)
+ results['mask_targets'].append(mask_targets)
+ results['mask_weights'].append(mask_weights)
+
+ # shape (batch_size, num_queries)
+ labels = torch.stack(results['labels'], dim=0)
+ # shape (batch_size, num_gts, h, w)
+ mask_targets = torch.cat(results['mask_targets'], dim=0)
+ # shape (batch_size, num_queries)
+ mask_weights = torch.stack(results['mask_weights'], dim=0)
+
+ avg_factor = sum(
+ [len(gt_instances.labels) for gt_instances in batch_gt_instances])
+
+ res = (labels, mask_targets, mask_weights, avg_factor)
+
+ return res
+
+ def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
+ gt_instances: InstanceData) \
+ -> Tuple[Tensor, Tensor, Tensor]:
+ """Compute a set of best mask matches for one image.
+
+ Args:
+ cls_score (Tensor): Mask score logits from a single decoder layer
+ for one image. Shape (num_queries, cls_out_channels).
+ mask_pred (Tensor): Mask logits for a single decoder layer for one
+ image. Shape (num_queries, h, w).
+ gt_instances (:obj:`InstanceData`): It contains ``labels`` and
+ ``masks``.
+
+ Returns:
+ tuple[Tensor]: A tuple containing the following for one image.
+
+ - labels (Tensor): Labels of each image. \
+ shape (num_queries, ).
+ - mask_targets (Tensor): Mask targets of each image. \
+ shape (num_queries, h, w).
+ - mask_weights (Tensor): Mask weights of each image. \
+ shape (num_queries, ).
+ """
+ gt_labels = gt_instances.labels
+ gt_masks = gt_instances.masks
+ # when "gt_labels" is empty, classify all queries to background
+ if len(gt_labels) == 0:
+ labels = gt_labels.new_full((self.num_queries, ),
+ self.num_classes,
+ dtype=torch.long)
+ mask_targets = gt_labels
+ mask_weights = gt_labels.new_zeros((self.num_queries, ))
+ return labels, mask_targets, mask_weights
+ # sample points
+ num_queries = cls_score.shape[0]
+ num_gts = gt_labels.shape[0]
+
+ point_coords = torch.rand((1, self.num_points, 2),
+ device=cls_score.device)
+ # shape (num_queries, num_points)
+ mask_points_pred = point_sample(
+ mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
+ 1)).squeeze(1)
+ # shape (num_gts, num_points)
+ gt_points_masks = point_sample(
+ gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
+ 1)).squeeze(1)
+
+ sampled_gt_instances = InstanceData(
+ labels=gt_labels, masks=gt_points_masks)
+ sampled_pred_instances = InstanceData(
+ scores=cls_score, masks=mask_points_pred)
+ # assign and sample
+        matched_query_inds, matched_label_inds = self.assigner.assign(
+            pred_instances=sampled_pred_instances,
+            gt_instances=sampled_gt_instances)
+        labels = gt_labels.new_full((self.num_queries, ),
+                                    self.num_classes,
+                                    dtype=torch.long)
+        labels[matched_query_inds] = gt_labels[matched_label_inds]
+
+        mask_weights = gt_labels.new_zeros((self.num_queries, ))
+        mask_weights[matched_query_inds] = 1
+        mask_targets = gt_masks[matched_label_inds]
+
+ return labels, mask_targets, mask_weights
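`seg_data_to_instance_data` turns a semantic map into one binary mask per class present, dropping `ignore_index`. The core of that conversion in plain NumPy (a stand-in for the Tensor code, not the mmseg function itself):

```python
import numpy as np

def sem_seg_to_instances(sem_seg, ignore_index=255):
    # One binary mask per class id present in the map, skipping the
    # ignored label, mirroring seg_data_to_instance_data above.
    labels = np.unique(sem_seg)
    labels = labels[labels != ignore_index]
    if len(labels) == 0:
        # No annotated pixels: empty (0, H, W) mask stack.
        return labels, np.zeros((0,) + sem_seg.shape, dtype=bool)
    masks = np.stack([sem_seg == c for c in labels])
    return labels, masks

seg = np.array([[0, 0, 255], [1, 1, 0]])
labels, masks = sem_seg_to_instances(seg)
```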
diff --git a/mmseg/utils/misc.py b/mmseg/utils/misc.py
index 0a561732e9a..dfc469e8320 100644
--- a/mmseg/utils/misc.py
+++ b/mmseg/utils/misc.py
@@ -94,18 +94,28 @@ def stack_batch(inputs: List[torch.Tensor],
# pad gt_sem_seg
if data_samples is not None:
data_sample = data_samples[i]
- gt_sem_seg = data_sample.gt_sem_seg.data
- del data_sample.gt_sem_seg.data
- data_sample.gt_sem_seg.data = F.pad(
- gt_sem_seg, padding_size, value=seg_pad_val)
+ pad_shape = None
+ if 'gt_sem_seg' in data_sample:
+ gt_sem_seg = data_sample.gt_sem_seg.data
+ del data_sample.gt_sem_seg.data
+ data_sample.gt_sem_seg.data = F.pad(
+ gt_sem_seg, padding_size, value=seg_pad_val)
+ pad_shape = data_sample.gt_sem_seg.shape
if 'gt_edge_map' in data_sample:
gt_edge_map = data_sample.gt_edge_map.data
del data_sample.gt_edge_map.data
data_sample.gt_edge_map.data = F.pad(
gt_edge_map, padding_size, value=seg_pad_val)
+ pad_shape = data_sample.gt_edge_map.shape
+ if 'gt_depth_map' in data_sample:
+ gt_depth_map = data_sample.gt_depth_map.data
+ del data_sample.gt_depth_map.data
+ data_sample.gt_depth_map.data = F.pad(
+ gt_depth_map, padding_size, value=seg_pad_val)
+ pad_shape = data_sample.gt_depth_map.shape
data_sample.set_metainfo({
'img_shape': tensor.shape[-2:],
- 'pad_shape': data_sample.gt_sem_seg.shape,
+ 'pad_shape': pad_shape,
'padding_size': padding_size
})
padded_samples.append(data_sample)
diff --git a/mmseg/utils/tokenizer.py b/mmseg/utils/tokenizer.py
new file mode 100644
index 00000000000..d56f5fae602
--- /dev/null
+++ b/mmseg/utils/tokenizer.py
@@ -0,0 +1,240 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""CLIP tokenizer.
+
+Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright
+(c) 2021 OpenAI.
+"""
+import gzip
+import html
+import os
+from functools import lru_cache
+from typing import List, Union
+
+import ftfy
+import regex as re
+import torch
+
+os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(
+ os.path.dirname(os.path.abspath(__file__)),
+ 'bpe_simple_vocab_16e6.txt.gz')
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """Returns list of utf-8 byte and a corresponding list of unicode strings.
+
+ The reversible bpe codes work on unicode strings. This means you need a
+ large # of unicode characters in your vocab if you want to avoid UNKs. When
+ you're at something like a 10B token dataset you end up needing around 5K
+ for decent coverage. This is a significant percentage of your normal, say,
+ 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and
+ unicode strings. And avoids mapping to whitespace/control characters the
+ bpe code barfs on.
+ """
+ bs = list(range(ord('!'),
+ ord('~') + 1)) + list(range(
+ ord('¡'),
+ ord('¬') + 1)) + list(range(ord('®'),
+ ord('ÿ') + 1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length
+ strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer:
+
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
+ merges = merges[1:49152 - 256 - 2 + 1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + '</w>' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+        if not special_tokens:
+            special_tokens = ['<start_of_text>', '<end_of_text>']
+        else:
+            special_tokens = ['<start_of_text>', '<end_of_text>'
+                              ] + special_tokens
+ vocab.extend(special_tokens)
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {t: t for t in special_tokens}
+ special = '|'.join(special_tokens)
+ self.pat = re.compile(
+ special +
+ r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+ re.IGNORECASE)
+
+ self.vocab_size = len(self.encoder)
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + '</w>', )
+ pairs = get_pairs(word)
+
+ if not pairs:
+            return token + '</w>'
+
+ while True:
+ bigram = min(
+ pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except: # noqa: E722, E261
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[
+ i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b]
+ for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token]
+ for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode(
+            'utf-8', errors='replace').replace('</w>', ' ')
+ return text
+
+
+_tokenizer = SimpleTokenizer()
+
+
+def decode(output_ids: torch.Tensor):
+ output_ids = output_ids.cpu().numpy()
+ return _tokenizer.decode(output_ids)
+
+
+def tokenize(texts: Union[str, List[str]],
+ context_length: int = 77) -> torch.LongTensor:
+ """Returns the tokenized representation of given input string(s)
+
+ Parameters
+ ----------
+ texts : Union[str, List[str]]
+ An input string or a list of input strings to tokenize
+ context_length : int
+ The context length to use; all CLIP models use 77 as the context length
+
+ Returns
+ -------
+ A two-dimensional tensor containing the resulting tokens,
+ shape = [number of input strings, context_length]
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+    sot_token = _tokenizer.encoder['<start_of_text>']
+    eot_token = _tokenizer.encoder['<end_of_text>']
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
+ for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ tokens = tokens[:context_length] # Truncate
+ tokens[-1] = eot_token
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
+
+
+class HFTokenizer:
+ """HuggingFace tokenizer wrapper."""
+
+ def __init__(self, tokenizer_name: str):
+ from transformers import AutoTokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+ def save_pretrained(self, dest):
+ self.tokenizer.save_pretrained(dest)
+
+ def __call__(self,
+ texts: Union[str, List[str]],
+ context_length: int = 77) -> torch.Tensor:
+ # same cleaning as for default tokenizer, except lowercasing
+ # adding lower (for case-sensitive tokenizers) will make it
+ # more robust but less sensitive to nuance
+ if isinstance(texts, str):
+ texts = [texts]
+ texts = [whitespace_clean(basic_clean(text)) for text in texts]
+ input_ids = self.tokenizer(
+ texts,
+ return_tensors='pt',
+ max_length=context_length,
+ padding='max_length',
+ truncation=True,
+ ).input_ids
+ return input_ids
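`tokenize` frames every prompt as SOT + tokens + EOT, truncates to `context_length` while forcing the final slot to EOT, and zero-pads the rest. A dependency-free sketch of that fixed-width framing (token ids and the `sot=1`, `eot=2` values are placeholders, not real vocab ids):

```python
def frame_tokens(token_ids, context_length=77, sot=1, eot=2):
    # [SOT] + tokens + [EOT], truncated to context_length with the
    # last slot forced to EOT, then zero-padded, as in tokenize() above.
    tokens = [sot] + list(token_ids) + [eot]
    if len(tokens) > context_length:
        tokens = tokens[:context_length]
        tokens[-1] = eot
    return tokens + [0] * (context_length - len(tokens))

row = frame_tokens([10, 11, 12], context_length=8)
```

The forced trailing EOT on truncation matters: downstream CLIP text encoders typically read the embedding at the EOT position as the sentence feature, so it must survive truncation.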
diff --git a/mmseg/version.py b/mmseg/version.py
index ef8e391a299..a654604da72 100644
--- a/mmseg/version.py
+++ b/mmseg/version.py
@@ -1,6 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
-__version__ = '1.0.0rc6'
+__version__ = '1.1.2'
def parse_version_info(version_str):
diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py
index d11ad79c816..3096e3183bd 100644
--- a/mmseg/visualization/local_visualizer.py
+++ b/mmseg/visualization/local_visualizer.py
@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional
+import cv2
import mmcv
import numpy as np
+import torch
from mmengine.dist import master_only
from mmengine.structures import PixelData
from mmengine.visualization import Visualizer
@@ -31,7 +33,7 @@ class SegLocalVisualizer(Visualizer):
`cityscapes` classes by default. Defaults to None.
palette (list, optional): Input palette for result rendering, which is
a list of color palette responding to the classes. Defaults to None.
- dataset_name (str, optional): `Dataset name or alias `_
+ dataset_name (str, optional): `Dataset name or alias `_
visulizer will use the meta information of the dataset i.e. classes
and palette, but the `classes` and `palette` have higher priority.
Defaults to None.
@@ -42,8 +44,8 @@ class SegLocalVisualizer(Visualizer):
>>> import numpy as np
>>> import torch
>>> from mmengine.structures import PixelData
- >>> from mmseg.data import SegDataSample
- >>> from mmseg.engine.visualization import SegLocalVisualizer
+ >>> from mmseg.structures import SegDataSample
+ >>> from mmseg.visualization import SegLocalVisualizer
>>> seg_local_visualizer = SegLocalVisualizer()
>>> image = np.random.randint(0, 256,
@@ -60,7 +62,7 @@ class SegLocalVisualizer(Visualizer):
>>> seg_local_visualizer.add_datasample(
... 'visualizer_example', image,
... gt_seg_data_sample, show=True)
- """ # noqa
+ """ # noqa
def __init__(self,
name: str = 'visualizer',
@@ -76,9 +78,32 @@ def __init__(self,
self.alpha: float = alpha
self.set_dataset_meta(palette, classes, dataset_name)
- def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
+ def _get_center_loc(self, mask: np.ndarray) -> np.ndarray:
+ """Get semantic seg center coordinate.
+
+ Args:
+            mask (np.ndarray): the binary mask of one class, taken from sem_seg
+ """
+ loc = np.argwhere(mask == 1)
+
+ loc_sort = np.array(
+ sorted(loc.tolist(), key=lambda row: (row[0], row[1])))
+ y_list = loc_sort[:, 0]
+ unique, indices, counts = np.unique(
+ y_list, return_index=True, return_counts=True)
+ y_loc = unique[counts.argmax()]
+ y_most_freq_loc = loc[loc_sort[:, 0] == y_loc]
+ center_num = len(y_most_freq_loc) // 2
+ x = y_most_freq_loc[center_num][1]
+ y = y_most_freq_loc[center_num][0]
+ return np.array([x, y])
+
+ def _draw_sem_seg(self,
+ image: np.ndarray,
+ sem_seg: PixelData,
classes: Optional[List],
- palette: Optional[List]) -> np.ndarray:
+ palette: Optional[List],
+ withLabels: Optional[bool] = True) -> np.ndarray:
"""Draw semantic seg of GT or prediction.
Args:
@@ -94,6 +119,8 @@ def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
palette (list, optional): Input palette for result rendering, which
is a list of color palette responding to the classes.
Defaults to None.
+            withLabels (bool, optional): Whether to add semantic labels to the
+                visualization result. Defaults to True.
Returns:
np.ndarray: the drawn image which channel is RGB.
@@ -108,14 +135,90 @@ def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
colors = [palette[label] for label in labels]
- self.set_image(image)
-
- # draw semantic masks
+ mask = np.zeros_like(image, dtype=np.uint8)
for label, color in zip(labels, colors):
- self.draw_binary_masks(
- sem_seg == label, colors=[color], alphas=self.alpha)
+ mask[sem_seg[0] == label, :] = color
+
+ if withLabels:
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ # (0,1] to change the size of the text relative to the image
+ scale = 0.05
+ fontScale = min(image.shape[0], image.shape[1]) / (25 / scale)
+ fontColor = (255, 255, 255)
+ if image.shape[0] < 300 or image.shape[1] < 300:
+ thickness = 1
+ rectangleThickness = 1
+ else:
+ thickness = 2
+ rectangleThickness = 2
+ lineType = 2
+
+ if isinstance(sem_seg[0], torch.Tensor):
+ masks = sem_seg[0].numpy() == labels[:, None, None]
+ else:
+ masks = sem_seg[0] == labels[:, None, None]
+ masks = masks.astype(np.uint8)
+ for mask_num in range(len(labels)):
+ classes_id = labels[mask_num]
+ classes_color = colors[mask_num]
+ loc = self._get_center_loc(masks[mask_num])
+ text = classes[classes_id]
+ (label_width, label_height), baseline = cv2.getTextSize(
+ text, font, fontScale, thickness)
+ mask = cv2.rectangle(mask, loc,
+ (loc[0] + label_width + baseline,
+ loc[1] + label_height + baseline),
+ classes_color, -1)
+ mask = cv2.rectangle(mask, loc,
+ (loc[0] + label_width + baseline,
+ loc[1] + label_height + baseline),
+ (0, 0, 0), rectangleThickness)
+ mask = cv2.putText(mask, text, (loc[0], loc[1] + label_height),
+ font, fontScale, fontColor, thickness,
+ lineType)
+ color_seg = (image * (1 - self.alpha) + mask * self.alpha).astype(
+ np.uint8)
+ self.set_image(color_seg)
+ return color_seg
+
+ def _draw_depth_map(self, image: np.ndarray,
+ depth_map: PixelData) -> np.ndarray:
+ """Draws a depth map on a given image.
+
+ This function takes an image and a depth map as input,
+ renders the depth map, and concatenates it with the original image.
+ Finally, it updates the internal image state of the visualizer with
+ the concatenated result.
+
+ Args:
+ image (np.ndarray): The original image where the depth map will
+ be drawn. The array should be in the format HxWx3 where H is
+ the height, W is the width.
+
+ depth_map (PixelData): Depth map to be drawn. The depth map
+ should be in the form of a PixelData object. It will be
+ converted to a torch tensor if it is a numpy array.
+
+ Returns:
+ np.ndarray: The concatenated image with the depth map drawn.
+
+ Example:
+ >>> depth_map_data = PixelData(data=torch.rand(1, 10, 10))
+ >>> image = np.random.randint(0, 256,
+ >>> size=(10, 10, 3)).astype('uint8')
+ >>> visualizer = SegLocalVisualizer()
+ >>> visualizer._draw_depth_map(image, depth_map_data)
+ """
+ depth_map = depth_map.cpu().data
+ if isinstance(depth_map, np.ndarray):
+ depth_map = torch.from_numpy(depth_map)
+ if depth_map.ndim == 2:
+ depth_map = depth_map[None]
- return self.get_image()
+ depth_map = self.draw_featmap(depth_map, resize_shape=image.shape[:2])
+ out_image = np.concatenate((image, depth_map), axis=0)
+ self.set_image(out_image)
+ return out_image
def set_dataset_meta(self,
classes: Optional[List] = None,
@@ -133,11 +236,11 @@ def set_dataset_meta(self,
palette (list, optional): Input palette for result rendering, which
is a list of color palette responding to the classes.
Defaults to None.
- dataset_name (str, optional): `Dataset name or alias `_
+ dataset_name (str, optional): `Dataset name or alias `_
visulizer will use the meta information of the dataset i.e.
classes and palette, but the `classes` and `palette` have
higher priority. Defaults to None.
- """ # noqa
+ """ # noqa
# Set default value. When calling
# `SegLocalVisualizer().dataset_meta=xxx`,
# it will override the default value.
@@ -161,7 +264,8 @@ def add_datasample(
wait_time: float = 0,
# TODO: Supported in mmengine's Viusalizer.
out_file: Optional[str] = None,
- step: int = 0) -> None:
+ step: int = 0,
+ withLabels: Optional[bool] = True) -> None:
"""Draw datasample and save to all backends.
- If GT and prediction are plotted at the same time, they are
@@ -187,6 +291,8 @@ def add_datasample(
wait_time (float): The interval of show (s). Defaults to 0.
out_file (str): Path to output file. Defaults to None.
step (int): Global step value to record. Defaults to 0.
+            withLabels (bool, optional): Whether to add semantic labels to the
+                visualization result. Defaults to True.
"""
classes = self.dataset_meta.get('classes', None)
palette = self.dataset_meta.get('palette', None)
@@ -194,26 +300,38 @@ def add_datasample(
gt_img_data = None
pred_img_data = None
- if draw_gt and data_sample is not None and 'gt_sem_seg' in data_sample:
- gt_img_data = image
- assert classes is not None, 'class information is ' \
- 'not provided when ' \
- 'visualizing semantic ' \
- 'segmentation results.'
- gt_img_data = self._draw_sem_seg(gt_img_data,
- data_sample.gt_sem_seg, classes,
- palette)
-
- if (draw_pred and data_sample is not None
- and 'pred_sem_seg' in data_sample):
- pred_img_data = image
- assert classes is not None, 'class information is ' \
- 'not provided when ' \
- 'visualizing semantic ' \
- 'segmentation results.'
- pred_img_data = self._draw_sem_seg(pred_img_data,
- data_sample.pred_sem_seg,
- classes, palette)
+ if draw_gt and data_sample is not None:
+ if 'gt_sem_seg' in data_sample:
+ assert classes is not None, 'class information is ' \
+ 'not provided when ' \
+ 'visualizing semantic ' \
+ 'segmentation results.'
+ gt_img_data = self._draw_sem_seg(image, data_sample.gt_sem_seg,
+ classes, palette, withLabels)
+
+ if 'gt_depth_map' in data_sample:
+ gt_img_data = gt_img_data if gt_img_data is not None else image
+ gt_img_data = self._draw_depth_map(gt_img_data,
+ data_sample.gt_depth_map)
+
+ if draw_pred and data_sample is not None:
+
+ if 'pred_sem_seg' in data_sample:
+
+ assert classes is not None, 'class information is ' \
+ 'not provided when ' \
+ 'visualizing semantic ' \
+ 'segmentation results.'
+ pred_img_data = self._draw_sem_seg(image,
+ data_sample.pred_sem_seg,
+ classes, palette,
+ withLabels)
+
+ if 'pred_depth_map' in data_sample:
+ pred_img_data = pred_img_data if pred_img_data is not None \
+ else image
+ pred_img_data = self._draw_depth_map(
+ pred_img_data, data_sample.pred_depth_map)
if gt_img_data is not None and pred_img_data is not None:
drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
@@ -226,6 +344,6 @@ def add_datasample(
self.show(drawn_img, win_name=name, wait_time=wait_time)
if out_file is not None:
- mmcv.imwrite(mmcv.bgr2rgb(drawn_img), out_file)
+ mmcv.imwrite(mmcv.rgb2bgr(drawn_img), out_file)
else:
self.add_image(name, drawn_img, step)
diff --git a/model-index.yml b/model-index.yml
index 5e87c386ddf..4026bb9e6e4 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -8,6 +8,7 @@ Import:
- configs/cgnet/metafile.yaml
- configs/convnext/metafile.yaml
- configs/danet/metafile.yaml
+- configs/ddrnet/metafile.yaml
- configs/deeplabv3/metafile.yaml
- configs/deeplabv3plus/metafile.yaml
- configs/dmnet/metafile.yaml
@@ -37,6 +38,7 @@ Import:
- configs/psanet/metafile.yaml
- configs/pspnet/metafile.yaml
- configs/resnest/metafile.yaml
+- configs/san/metafile.yaml
- configs/segformer/metafile.yaml
- configs/segmenter/metafile.yaml
- configs/segnext/metafile.yaml
@@ -48,3 +50,4 @@ Import:
- configs/unet/metafile.yaml
- configs/upernet/metafile.yaml
- configs/vit/metafile.yaml
+- configs/vpd/metafile.yaml
diff --git a/projects/Adabins/README.md b/projects/Adabins/README.md
new file mode 100644
index 00000000000..8a23e92d74a
--- /dev/null
+++ b/projects/Adabins/README.md
@@ -0,0 +1,46 @@
+# AdaBins: Depth Estimation Using Adaptive Bins
+
+## Reference
+
+> [AdaBins: Depth Estimation Using Adaptive Bins](https://arxiv.org/abs/2011.14141)
+
+## Introduction
+
+Official Repo
+
+Code Snippet
+
+## Abstract
+
+We address the problem of estimating a high quality dense depth map from a single RGB input image. We start out with a baseline encoder-decoder convolutional neural network architecture and pose the question of how the global processing of information can help improve overall depth estimation. To this end, we propose a transformer-based architecture block that divides the depth range into bins whose center value is estimated adaptively per image. The final depth values are estimated as linear combinations of the bin centers. We call our new building block AdaBins. Our results show a decisive improvement over the state-of-the-art on several popular depth datasets across all metrics. We also validate the effectiveness of the proposed block with an ablation study and provide the code and corresponding pre-trained weights of the new state-of-the-art model.
+
+Our main contributions are the following:
+
+- We propose an architecture building block that performs global processing of the scene’s information. We propose to divide the predicted depth range into bins where the bin widths change per image. The final depth estimation is a linear combination of the bin center values.
+- We show a decisive improvement for supervised single image depth estimation across all metrics for the two most popular datasets, NYU and KITTI.
+- We analyze our findings and investigate different modifications on the proposed AdaBins block and study their effect on the accuracy of the depth estimation.
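The adaptive-binning idea in the bullets above can be sketched numerically: the network predicts normalized bin widths per image, the widths are scaled into the depth range and cumulated into bin edges, and the final depth is a probability-weighted combination of the bin centers. A NumPy sketch with made-up numbers (not the trained model):

```python
import numpy as np

min_val, max_val = 0.001, 10.0  # depth range (NYU-style)
widths_normed = np.array([0.1, 0.2, 0.3, 0.4])  # predicted widths, sum to 1

# Scale widths into the depth range, prepend the minimum depth,
# then cumulate to obtain the bin edges.
bin_widths = (max_val - min_val) * widths_normed
bin_edges = np.cumsum(np.concatenate([[min_val], bin_widths]))
centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

# Per-pixel softmax probabilities over bins; the depth estimate is
# their linear combination with the bin centers.
probs = np.array([0.7, 0.2, 0.05, 0.05])
depth = float(np.sum(probs * centers))
print(centers, depth)
```

Because the widths are predicted per image, the bin resolution concentrates wherever that image's depth distribution has most of its mass.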
+
+
+
+
+
+## Performance
+
+### NYU and KITTI
+
+| Model | Encoder | Training epoch | Batchsize | Train Resolution | δ1 | δ2 | δ3 | REL | RMS | RMS log | params(M) | Links |
+| ------------- | --------------- | -------------- | --------- | ---------------- | ----- | ----- | ----- | ----- | ----- | ------- | --------- | ----------------------------------------------------------------------------------------------------------------------- |
+| AdaBins_nyu | EfficientNet-B5 | 25 | 16 | 416x544 | 0.903 | 0.984 | 0.997 | 0.103 | 0.364 | 0.044 | 78 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/adabins/adabins_efficient_b5_nyu_third-party-f68d6bd3.pth) |
+| AdaBins_kitti | EfficientNet-B5 | 25 | 16 | 352x764 | 0.964 | 0.995 | 0.999 | 0.058 | 2.360 | 0.088 | 78 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/adabins/adabins_efficient-b5_kitty_third-party-a1aa6f36.pth) |
+
+## Citation
+
+```bibtex
+@article{10.1109/cvpr46437.2021.00400,
+  author = {Bhat, S. F. and Alhashim, I. and Wonka, P.},
+ title = {Adabins: depth estimation using adaptive bins},
+ journal = {2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ year = {2021},
+ doi = {10.1109/cvpr46437.2021.00400}
+}
+```
diff --git a/projects/Adabins/backbones/__init__.py b/projects/Adabins/backbones/__init__.py
new file mode 100644
index 00000000000..04ae180be5f
--- /dev/null
+++ b/projects/Adabins/backbones/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .adabins_backbone import AdabinsBackbone
+
+__all__ = ['AdabinsBackbone']
diff --git a/projects/Adabins/backbones/adabins_backbone.py b/projects/Adabins/backbones/adabins_backbone.py
new file mode 100644
index 00000000000..07d73809e39
--- /dev/null
+++ b/projects/Adabins/backbones/adabins_backbone.py
@@ -0,0 +1,141 @@
+import timm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, build_conv_layer
+from mmengine.model import BaseModule
+
+from mmseg.registry import MODELS
+
+
+class UpSampleBN(nn.Module):
+    """Upsample module that fuses a feature map with a skip connection.
+    Args:
+        skip_input (int): the number of input channels after concatenating
+            the upsampled feature map with the skip connection
+        output_features (int): the number of output channels
+        norm_cfg (dict, optional): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict, optional): Config dict for activation layer.
+            Default: dict(type='LeakyReLU').
+ """
+
+ def __init__(self,
+ skip_input,
+ output_features,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='LeakyReLU')):
+ super().__init__()
+
+ self._net = nn.Sequential(
+ ConvModule(
+ in_channels=skip_input,
+ out_channels=output_features,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ ),
+ ConvModule(
+ in_channels=output_features,
+ out_channels=output_features,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ ))
+
+ def forward(self, x, concat_with):
+ up_x = F.interpolate(
+ x,
+ size=[concat_with.size(2),
+ concat_with.size(3)],
+ mode='bilinear',
+ align_corners=True)
+ f = torch.cat([up_x, concat_with], dim=1)
+ return self._net(f)
+
+
+class Encoder(nn.Module):
+    """Feature encoder wrapping a timm model (EfficientNet-B5 by default).
+    Args:
+        basemodel_name (str): the name of the timm base model
+ """
+
+ def __init__(self, basemodel_name):
+ super().__init__()
+ self.original_model = timm.create_model(
+ basemodel_name, pretrained=True)
+ # Remove last layer
+ self.original_model.global_pool = nn.Identity()
+ self.original_model.classifier = nn.Identity()
+
+ def forward(self, x):
+ features = [x]
+ for k, v in self.original_model._modules.items():
+ if k == 'blocks':
+ for ki, vi in v._modules.items():
+ features.append(vi(features[-1]))
+ else:
+ features.append(v(features[-1]))
+ return features
+
+
+@MODELS.register_module()
+class AdabinsBackbone(BaseModule):
+    """The backbone of AdaBins: a timm encoder with a UNet-style decoder.
+    Args:
+        basemodel_name (str): the name of the timm base model
+        num_features (int): the number of channels fed to the decoder
+        num_classes (int): the number of output channels
+        bottleneck_features (int): the number of bottleneck channels
+ conv_cfg (dict): Config dict for convolution layer.
+ """
+
+ def __init__(self,
+ basemodel_name,
+ num_features=2048,
+ num_classes=128,
+ bottleneck_features=2048,
+ conv_cfg=dict(type='Conv')):
+ super().__init__()
+ self.encoder = Encoder(basemodel_name)
+ features = int(num_features)
+ self.conv2 = build_conv_layer(
+ conv_cfg,
+ bottleneck_features,
+ features,
+ kernel_size=1,
+ stride=1,
+ padding=1)
+ self.up1 = UpSampleBN(
+ skip_input=features // 1 + 112 + 64, output_features=features // 2)
+ self.up2 = UpSampleBN(
+ skip_input=features // 2 + 40 + 24, output_features=features // 4)
+ self.up3 = UpSampleBN(
+ skip_input=features // 4 + 24 + 16, output_features=features // 8)
+ self.up4 = UpSampleBN(
+ skip_input=features // 8 + 16 + 8, output_features=features // 16)
+
+ self.conv3 = build_conv_layer(
+ conv_cfg,
+ features // 16,
+ num_classes,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ features = self.encoder(x)
+ x_block0, x_block1, x_block2, x_block3, x_block4 = features[
+ 3], features[4], features[5], features[7], features[10]
+ x_d0 = self.conv2(x_block4)
+ x_d1 = self.up1(x_d0, x_block3)
+ x_d2 = self.up2(x_d1, x_block2)
+ x_d3 = self.up3(x_d2, x_block1)
+ x_d4 = self.up4(x_d3, x_block0)
+ out = self.conv3(x_d4)
+ return out
diff --git a/projects/Adabins/configs/_base_/datasets/nyu.py b/projects/Adabins/configs/_base_/datasets/nyu.py
new file mode 100644
index 00000000000..1b49ec7e8de
--- /dev/null
+++ b/projects/Adabins/configs/_base_/datasets/nyu.py
@@ -0,0 +1,32 @@
+dataset_type = 'NYUDataset'
+data_root = 'data/nyu'
+
+test_pipeline = [
+    dict(type='LoadImageFromFile', to_float32=True),
+    dict(type='LoadDepthAnnotation', depth_rescale_factor=1e-3),
+ dict(
+ type='PackSegInputs',
+ meta_keys=('img_path', 'depth_map_path', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'category_id'))
+]
+
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ test_mode=True,
+ data_prefix=dict(
+ img_path='images/test', depth_map_path='annotations/test'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(
+ type='DepthMetric', max_depth_eval=10.0, crop_type='nyu_crop')
+test_evaluator = val_evaluator
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
diff --git a/projects/Adabins/configs/_base_/default_runtime.py b/projects/Adabins/configs/_base_/default_runtime.py
new file mode 100644
index 00000000000..272b4d24679
--- /dev/null
+++ b/projects/Adabins/configs/_base_/default_runtime.py
@@ -0,0 +1,15 @@
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+vis_backends = [dict(type='LocalVisBackend')]
+visualizer = dict(
+ type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+log_processor = dict(by_epoch=False)
+log_level = 'INFO'
+load_from = None
+resume = False
+
+tta_model = dict(type='SegTTAModel')
diff --git a/projects/Adabins/configs/_base_/models/Adabins.py b/projects/Adabins/configs/_base_/models/Adabins.py
new file mode 100644
index 00000000000..35cbd8c5777
--- /dev/null
+++ b/projects/Adabins/configs/_base_/models/Adabins.py
@@ -0,0 +1,35 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255)
+model = dict(
+ type='DepthEstimator',
+ data_preprocessor=data_preprocessor,
+ # pretrained='open-mmlab://resnet50_v1c',
+ backbone=dict(
+ type='AdabinsBackbone',
+ basemodel_name='tf_efficientnet_b5_ap',
+ num_features=2048,
+ num_classes=128,
+ bottleneck_features=2048,
+ ),
+ decode_head=dict(
+ type='AdabinsHead',
+ in_channels=128,
+ n_query_channels=128,
+ patch_size=16,
+ embedding_dim=128,
+ num_heads=4,
+ n_bins=256,
+ min_val=0.001,
+ max_val=10,
+ norm='linear'),
+
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/projects/Adabins/configs/adabins/adabins_efficient_b5_4x16_25e_NYU_416x544.py b/projects/Adabins/configs/adabins/adabins_efficient_b5_4x16_25e_NYU_416x544.py
new file mode 100644
index 00000000000..5c00ea152bf
--- /dev/null
+++ b/projects/Adabins/configs/adabins/adabins_efficient_b5_4x16_25e_NYU_416x544.py
@@ -0,0 +1,15 @@
+_base_ = [
+ '../_base_/models/Adabins.py', '../_base_/datasets/nyu.py',
+ '../_base_/default_runtime.py'
+]
+custom_imports = dict(
+ imports=['projects.Adabins.backbones', 'projects.Adabins.decode_head'],
+ allow_failed_imports=False)
+crop_size = (416, 544)
+data_preprocessor = dict(size=crop_size)
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ backbone=dict(),
+ decode_head=dict(),
+)
diff --git a/projects/Adabins/configs/adabins/adabins_efficient_b5_4x16_25e_kitti_352x704.py b/projects/Adabins/configs/adabins/adabins_efficient_b5_4x16_25e_kitti_352x704.py
new file mode 100644
index 00000000000..330cdf41a5b
--- /dev/null
+++ b/projects/Adabins/configs/adabins/adabins_efficient_b5_4x16_25e_kitti_352x704.py
@@ -0,0 +1,12 @@
+_base_ = ['../_base_/models/Adabins.py']
+custom_imports = dict(
+ imports=['projects.Adabins.backbones', 'projects.Adabins.decode_head'],
+ allow_failed_imports=False)
+crop_size = (352, 704)
+data_preprocessor = dict(size=crop_size)
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ backbone=dict(),
+ decode_head=dict(min_val=0.001, max_val=80),
+)
diff --git a/projects/Adabins/decode_head/__init__.py b/projects/Adabins/decode_head/__init__.py
new file mode 100644
index 00000000000..c7d62df12bb
--- /dev/null
+++ b/projects/Adabins/decode_head/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .adabins_head import AdabinsHead
+
+__all__ = ['AdabinsHead']
diff --git a/projects/Adabins/decode_head/adabins_head.py b/projects/Adabins/decode_head/adabins_head.py
new file mode 100644
index 00000000000..ee043172ab9
--- /dev/null
+++ b/projects/Adabins/decode_head/adabins_head.py
@@ -0,0 +1,179 @@
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_conv_layer
+from torch import Tensor
+
+from mmseg.registry import MODELS
+
+
+class PatchTransformerEncoder(nn.Module):
+ """the Patch Transformer Encoder.
+
+ Args:
+        in_channels (int): the number of input channels
+        patch_size (int): the patch size
+        embedding_dim (int): The feature dimension.
+        num_heads (int): the number of attention heads
+ conv_cfg (dict): Config dict for convolution layer.
+ """
+
+ def __init__(self,
+ in_channels,
+ patch_size=10,
+ embedding_dim=128,
+ num_heads=4,
+ conv_cfg=dict(type='Conv')):
+ super().__init__()
+ encoder_layers = nn.TransformerEncoderLayer(
+ embedding_dim, num_heads, dim_feedforward=1024)
+ self.transformer_encoder = nn.TransformerEncoder(
+ encoder_layers, num_layers=4) # takes shape S,N,E
+
+ self.embedding_convPxP = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ embedding_dim,
+ kernel_size=patch_size,
+ stride=patch_size)
+ self.positional_encodings = nn.Parameter(
+ torch.rand(500, embedding_dim), requires_grad=True)
+
+ def forward(self, x):
+ embeddings = self.embedding_convPxP(x).flatten(
+ 2) # .shape = n,c,s = n, embedding_dim, s
+ embeddings = embeddings + self.positional_encodings[:embeddings.shape[
+ 2], :].T.unsqueeze(0)
+
+ # change to S,N,E format required by transformer
+ embeddings = embeddings.permute(2, 0, 1)
+ x = self.transformer_encoder(embeddings) # .shape = S, N, E
+ return x
+
+
+class PixelWiseDotProduct(nn.Module):
+    """Pixel-wise dot product between a feature map and query embeddings."""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, K):
+ n, c, h, w = x.size()
+ _, cout, ck = K.size()
+ assert c == ck, 'Number of channels in x and Embedding dimension ' \
+ '(at dim 2) of K matrix must match'
+ y = torch.matmul(
+ x.view(n, c, h * w).permute(0, 2, 1),
+ K.permute(0, 2, 1)) # .shape = n, hw, cout
+ return y.permute(0, 2, 1).view(n, cout, h, w)
+
+
+@MODELS.register_module()
+class AdabinsHead(nn.Module):
+    """The head of AdaBins, including the mViT module.
+
+    Args:
+        in_channels (int): the number of input channels
+        n_query_channels (int): the number of query channels
+        patch_size (int): the patch size
+        embedding_dim (int): The feature dimension.
+        num_heads (int): the number of attention heads
+        n_bins (int): the number of bins
+        min_val (float): the minimum predictable depth value
+        max_val (float): the maximum predictable depth value
+        conv_cfg (dict): Config dict for convolution layer.
+        norm (str): normalization applied to the predicted bin widths,
+            one of 'linear', 'softmax' or 'sigmoid'
+ align_corners (bool, optional): Geometrically, we consider the pixels
+ of the input and output as squares rather than points.
+ """
+
+ def __init__(self,
+ in_channels,
+ n_query_channels=128,
+ patch_size=16,
+ embedding_dim=128,
+ num_heads=4,
+ n_bins=100,
+ min_val=0.1,
+ max_val=10,
+ conv_cfg=dict(type='Conv'),
+ norm='linear',
+ align_corners=False,
+ threshold=0):
+ super().__init__()
+ self.out_channels = n_bins
+ self.align_corners = align_corners
+ self.norm = norm
+ self.num_classes = n_bins
+ self.min_val = min_val
+ self.max_val = max_val
+ self.n_query_channels = n_query_channels
+ self.patch_transformer = PatchTransformerEncoder(
+ in_channels, patch_size, embedding_dim, num_heads)
+ self.dot_product_layer = PixelWiseDotProduct()
+ self.threshold = threshold
+ self.conv3x3 = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ embedding_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.regressor = nn.Sequential(
+ nn.Linear(embedding_dim, 256), nn.LeakyReLU(), nn.Linear(256, 256),
+ nn.LeakyReLU(), nn.Linear(256, n_bins))
+ self.conv_out = nn.Sequential(
+ build_conv_layer(conv_cfg, in_channels, n_bins, kernel_size=1),
+ nn.Softmax(dim=1))
+
+ def forward(self, x):
+ # n, c, h, w = x.size()
+ tgt = self.patch_transformer(x.clone()) # .shape = S, N, E
+
+ x = self.conv3x3(x)
+
+ regression_head, queries = tgt[0,
+ ...], tgt[1:self.n_query_channels + 1,
+ ...]
+
+ # Change from S, N, E to N, S, E
+ queries = queries.permute(1, 0, 2)
+ range_attention_maps = self.dot_product_layer(
+ x, queries) # .shape = n, n_query_channels, h, w
+
+ y = self.regressor(regression_head) # .shape = N, dim_out
+ if self.norm == 'linear':
+ y = torch.relu(y)
+ eps = 0.1
+ y = y + eps
+ elif self.norm == 'softmax':
+ return torch.softmax(y, dim=1), range_attention_maps
+ else:
+ y = torch.sigmoid(y)
+ bin_widths_normed = y / y.sum(dim=1, keepdim=True)
+ out = self.conv_out(range_attention_maps)
+
+ bin_widths = (self.max_val -
+ self.min_val) * bin_widths_normed # .shape = N, dim_out
+ bin_widths = F.pad(
+ bin_widths, (1, 0), mode='constant', value=self.min_val)
+ bin_edges = torch.cumsum(bin_widths, dim=1)
+
+ centers = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])
+ n, dim_out = centers.size()
+ centers = centers.view(n, dim_out, 1, 1)
+
+ pred = torch.sum(out * centers, dim=1, keepdim=True)
+ return bin_edges, pred
+
+ def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
+ test_cfg, **kwargs) -> Tensor:
+        """Forward function for testing; returns the clamped depth map."""
+ pred = self.forward(inputs)[-1]
+ final = torch.clamp(pred, self.min_val, self.max_val)
+
+ final[torch.isinf(final)] = self.max_val
+ final[torch.isnan(final)] = self.min_val
+ return final
diff --git a/projects/CAT-Seg/README.md b/projects/CAT-Seg/README.md
new file mode 100644
index 00000000000..890e461ce4c
--- /dev/null
+++ b/projects/CAT-Seg/README.md
@@ -0,0 +1,92 @@
+# CAT-Seg
+
+> [CAT-Seg: Cost Aggregation for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2303.11797)
+
+## Introduction
+
+
+
+Official Repo
+
+Code Snippet
+
+## Abstract
+
+
+
+Existing works on open-vocabulary semantic segmentation have utilized large-scale vision-language models, such as CLIP, to leverage their exceptional open-vocabulary recognition capabilities. However, the problem of transferring these capabilities learned from image-level supervision to the pixel-level task of segmentation and addressing arbitrary unseen categories at inference makes this task challenging. To address these issues, we aim to attentively relate objects within an image to given categories by leveraging relational information among class categories and visual semantics through aggregation, while also adapting the CLIP representations to the pixel-level task. However, we observe that direct optimization of the CLIP embeddings can harm its open-vocabulary capabilities. In this regard, we propose an alternative approach to optimize the image-text similarity map, i.e. the cost map, using a novel cost aggregation-based method. Our framework, namely CAT-Seg, achieves state-of-the-art performance across all benchmarks. We provide extensive ablation studies to validate our choices. [Project page](https://ku-cvlab.github.io/CAT-Seg).
+
+
+
+
+
+CAT-Seg model structure
+
+
+## Usage
+
+CAT-Seg model training requires a pretrained `CLIP` model. We have implemented `ViT-B` and `ViT-L` based `CLIP` models. To use the `ViT-bigG` or `ViT-H` variants, you need additional dependencies; please install [open_clip](https://github.com/mlfoundations/open_clip) first. The pretrained `CLIP` model state dicts are loaded from [Huggingface-OpenCLIP](https://huggingface.co/models?library=open_clip). **If you run into a `ConnectionError` when downloading CLIP weights**, you can manually download them from the given repo and set `custom_clip_weights='/path/to/your/folder'` in the backbone config. The related tools are listed in [requirements/optional.txt](requirements/optional.txt):
+
+```shell
+pip install ftfy==6.0.1
+pip install huggingface-hub
+pip install regex
+```
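For example, after downloading the weights manually, the backbone section of a config could point at the local folder. A hedged sketch (only `custom_clip_weights` is taken from the text above; the remaining backbone fields depend on the chosen CAT-Seg config):

```python
# Hypothetical override in a CAT-Seg config file: load CLIP weights
# from a local folder instead of downloading from Huggingface.
model = dict(
    backbone=dict(
        custom_clip_weights='/path/to/your/folder',  # downloaded manually
    ))
```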
+
+In addition to the necessary [data preparation](https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md), you also need class texts for clip text encoder. Please download the class text json file first [cls_texts](https://github.com/open-mmlab/mmsegmentation/files/11714914/cls_texts.zip) and arrange the folder as follows:
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│ ├── VOCdevkit
+│ │ ├── VOC2012
+│ │ ├── VOC2010
+│ │ ├── VOCaug
+│ ├── ade
+│ ├── coco_stuff164k
+│ ├── coco.json
+│ ├── pc59.json
+│ ├── pc459.json
+│ ├── ade150.json
+│ ├── ade847.json
+│ ├── voc20b.json
+│ ├── voc20.json
+```
+
+```shell
+# setup PYTHONPATH
+export PYTHONPATH=`pwd`:$PYTHONPATH
+# run evaluation
+mim test mmsegmentation ${CONFIG} --checkpoint ${CHECKPOINT} --launcher pytorch --gpus=8
+```
+
+## Results and models
+
+### ADE20K-150-ZeroShot
+
+| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | mIoU | mIoU(ms+flip) | config | download |
+| ------- | ------------- | --------- | ------- | -------: | -------------- | ------- | ---- | ------------: | ------------------------------------------------------------------------------------------: | --------------------------------------------------------------------------------------------------------------------------------------------- |
+| CAT-Seg | R-101 & ViT-B | 384x384 | 80000 | - | - | RTX3090 | 27.2 | - | [config](./configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_ade20k-384x384.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_ade20k-384x384-54194d72.pth) |
+
+Note:
+
+- All experiments of CAT-Seg are implemented with 4 RTX3090 GPUs, except the last one, which uses the pretrained ViT-bigG CLIP model (its GPU memory requirement is higher; you may need an A100).
+- Due to the feature size bottleneck of the CLIP image encoder, inference and testing can only be done in `slide` mode, and inference takes longer since the test size is much bigger than the training size of `(384, 384)`.
+- The ResNet backbones utilized in CAT-Seg models are standard `ResNet` rather than `ResNetV1c`.
+- The zero-shot segmentation results on PASCAL VOC and ADE20K are from the original paper. Our results are coming soon. We appreciate your contribution!
+- In addition to the zero-shot segmentation results, we also provide the evaluation results on the `val2017` set of **COCO-stuff164k** for reference, which is the training dataset of CAT-Seg. The testing was done **without TTA**.
+- The number behind the dataset name is the category number for segmentation evaluation (except training data **COCO-stuff 164k**). **PASCAL VOC-20b** defines the "background" as classes present in **PASCAL-Context-59** but not in **PASCAL VOC-20**.
+
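+As a rough sketch of the `slide` test setting mentioned above (the crop and stride values here are illustrative assumptions, not the released CAT-Seg settings), a config override could look like:
+
+```python
+# Illustrative only: slide-mode testing in the usual mmseg config style.
+# The crop_size/stride values are assumptions for demonstration.
+model = dict(
+    test_cfg=dict(mode='slide', crop_size=(384, 384), stride=(256, 256)))
+```
+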
+## Citation
+
+```bibtex
+@inproceedings{cho2023catseg,
+  title={CAT-Seg: Cost Aggregation for Open-Vocabulary Semantic Segmentation},
+  author={Seokju Cho and Heeseong Shin and Sunghwan Hong and Seungjun An and Seungjun Lee and Anurag Arnab and Paul Hongsuck Seo and Seungryong Kim},
+  booktitle={CVPR},
+ year={2023}
+}
+```
diff --git a/projects/CAT-Seg/cat_seg/__init__.py b/projects/CAT-Seg/cat_seg/__init__.py
new file mode 100644
index 00000000000..2c51fbaa2e3
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/__init__.py
@@ -0,0 +1,2 @@
+from .models import * # noqa: F401,F403
+from .utils import * # noqa: F401,F403
diff --git a/projects/CAT-Seg/cat_seg/models/__init__.py b/projects/CAT-Seg/cat_seg/models/__init__.py
new file mode 100644
index 00000000000..cd0e15d3ec9
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/models/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .cat_aggregator import (AggregatorLayer, CATSegAggregator,
+ ClassAggregateLayer, SpatialAggregateLayer)
+from .cat_head import CATSegHead
+from .clip_ovseg import CLIPOVCATSeg
+
+__all__ = [
+ 'AggregatorLayer', 'CATSegAggregator', 'ClassAggregateLayer',
+ 'SpatialAggregateLayer', 'CATSegHead', 'CLIPOVCATSeg'
+]
diff --git a/projects/CAT-Seg/cat_seg/models/cat_aggregator.py b/projects/CAT-Seg/cat_seg/models/cat_aggregator.py
new file mode 100644
index 00000000000..a0483fe505b
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/models/cat_aggregator.py
@@ -0,0 +1,763 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.transformer import FFN, build_dropout
+from mmengine.model import BaseModule
+from mmengine.utils import to_2tuple
+
+from mmseg.registry import MODELS
+from ..utils import FullAttention, LinearAttention
+
+
+class AGWindowMSA(BaseModule):
+ """Appearance Guidance Window based multi-head self-attention (W-MSA)
+ module with relative position bias.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ appearance_dims (int): Number of appearance guidance feature channels.
+ num_heads (int): Number of attention heads.
+ window_size (tuple[int]): The height and width of the window.
+ qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
+ Default: True.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ attn_drop_rate (float, optional): Dropout ratio of attention weight.
+ Default: 0.0
+ proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
+ init_cfg (dict | None, optional): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ appearance_dims,
+ num_heads,
+ window_size,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop_rate=0.,
+ proj_drop_rate=0.,
+ init_cfg=None):
+
+ super().__init__(init_cfg=init_cfg)
+ self.embed_dims = embed_dims
+ self.appearance_dims = appearance_dims
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_embed_dims = embed_dims // num_heads
+ self.scale = qk_scale or head_embed_dims**-0.5
+
+ # About 2x faster than original impl
+ Wh, Ww = self.window_size
+ rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
+ rel_position_index = rel_index_coords + rel_index_coords.T
+ rel_position_index = rel_position_index.flip(1).contiguous()
+ self.register_buffer('relative_position_index', rel_position_index)
+
+ self.qk = nn.Linear(
+ embed_dims + appearance_dims, embed_dims * 2, bias=qkv_bias)
+ self.v = nn.Linear(embed_dims, embed_dims, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop_rate)
+ self.proj = nn.Linear(embed_dims, embed_dims)
+ self.proj_drop = nn.Dropout(proj_drop_rate)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x (tensor): input features with shape of (num_windows*B, N, C),
+ C = embed_dims + appearance_dims.
+ mask (tensor | None, Optional): mask with shape of (num_windows,
+ Wh*Ww, Wh*Ww), value should be between (-inf, 0].
+ """
+ B, N, _ = x.shape
+ qk = self.qk(x).reshape(B, N, 2, self.num_heads,
+ self.embed_dims // self.num_heads).permute(
+ 2, 0, 3, 1,
+ 4) # 2 B NUM_HEADS N embed_dims//NUM_HEADS
+ v = self.v(x[:, :, :self.embed_dims]).reshape(
+ B, N, self.num_heads, self.embed_dims // self.num_heads).permute(
+ 0, 2, 1, 3) # B NUM_HEADS N embed_dims//NUM_HEADS
+ # make torchscript happy (cannot use tensor as tuple)
+ q, k = qk[0], qk[1]
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B // nW, nW, self.num_heads, N,
+ N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, self.embed_dims)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ @staticmethod
+ def double_step_seq(step1, len1, step2, len2):
+ """Double step sequence."""
+ seq1 = torch.arange(0, step1 * len1, step1)
+ seq2 = torch.arange(0, step2 * len2, step2)
+ return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
+
+
+class AGShiftWindowMSA(BaseModule):
+ """Appearance Guidance Shifted Window Multihead Self-Attention Module.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ appearance_dims (int): Number of appearance guidance channels
+ num_heads (int): Number of attention heads.
+ window_size (int): The height and width of the window.
+ shift_size (int, optional): The shift step of each window towards
+ right-bottom. If zero, act as regular window-msa. Defaults to 0.
+ qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
+ Default: True
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Defaults: None.
+ attn_drop_rate (float, optional): Dropout ratio of attention weight.
+ Defaults: 0.
+ proj_drop_rate (float, optional): Dropout ratio of output.
+ Defaults: 0.
+ dropout_layer (dict, optional): The dropout_layer used before output.
+ Defaults: dict(type='DropPath', drop_prob=0.).
+ init_cfg (dict, optional): The extra config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ appearance_dims,
+ num_heads,
+ window_size,
+ shift_size=0,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop_rate=0,
+ proj_drop_rate=0,
+ dropout_layer=dict(type='DropPath', drop_prob=0.),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ self.window_size = window_size
+ self.shift_size = shift_size
+ assert 0 <= self.shift_size < self.window_size
+
+ self.w_msa = AGWindowMSA(
+ embed_dims=embed_dims,
+ appearance_dims=appearance_dims,
+ num_heads=num_heads,
+ window_size=to_2tuple(window_size),
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop_rate=attn_drop_rate,
+ proj_drop_rate=proj_drop_rate,
+ init_cfg=None)
+
+ self.drop = build_dropout(dropout_layer)
+
+ def forward(self, query, hw_shape):
+ """
+ Args:
+ query: The input query.
+ hw_shape: The shape of the feature height and width.
+ """
+ B, L, C = query.shape
+ H, W = hw_shape
+ assert L == H * W, 'input feature has wrong size'
+ query = query.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
+ H_pad, W_pad = query.shape[1], query.shape[2]
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_query = torch.roll(
+ query,
+ shifts=(-self.shift_size, -self.shift_size),
+ dims=(1, 2))
+
+ # calculate attention mask for SW-MSA
+ img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ # nW, window_size, window_size, 1
+ mask_windows = self.window_partition(img_mask)
+ mask_windows = mask_windows.view(
+ -1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
+ float(-100.0)).masked_fill(
+ attn_mask == 0, float(0.0))
+ else:
+ shifted_query = query
+ attn_mask = None
+
+ # nW*B, window_size, window_size, C
+ query_windows = self.window_partition(shifted_query)
+ # nW*B, window_size*window_size, C
+ query_windows = query_windows.view(-1, self.window_size**2, C)
+
+ # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
+ attn_windows = self.w_msa(query_windows, mask=attn_mask)
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size,
+ self.window_size,
+ self.w_msa.embed_dims)
+
+ # B H' W' self.w_msa.embed_dims
+ shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(
+ shifted_x,
+ shifts=(self.shift_size, self.shift_size),
+ dims=(1, 2))
+ else:
+ x = shifted_x
+
+        if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, self.w_msa.embed_dims)
+
+ x = self.drop(x)
+ return x
+
+ def window_reverse(self, windows, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ window_size = self.window_size
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size,
+ window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+ def window_partition(self, x):
+ """
+ Args:
+ x: (B, H, W, C)
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ window_size = self.window_size
+ x = x.view(B, H // window_size, window_size, W // window_size,
+ window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
+ windows = windows.view(-1, window_size, window_size, C)
+ return windows
+
+
+class AGSwinBlock(BaseModule):
+ """Appearance Guidance Swin Transformer Block.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ appearance_dims (int): The appearance guidance dimension.
+ num_heads (int): Parallel attention heads.
+ mlp_ratios (int): The hidden dimension ratio w.r.t. embed_dims
+ for FFNs.
+ window_size (int, optional): The local window scale.
+ Default: 7.
+ shift (bool, optional): whether to shift window or not.
+ Default False.
+ qkv_bias (bool, optional): enable bias for qkv if True.
+ Default: True.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ drop_rate (float, optional): Dropout rate. Default: 0.
+ attn_drop_rate (float, optional): Attention dropout rate.
+ Default: 0.
+ drop_path_rate (float, optional): Stochastic depth rate.
+ Default: 0.
+ act_cfg (dict, optional): The config dict of activation function.
+ Default: dict(type='GELU').
+ norm_cfg (dict, optional): The config dict of normalization.
+ Default: dict(type='LN').
+ init_cfg (dict | list | None, optional): The init config.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ appearance_dims,
+ num_heads,
+ mlp_ratios=4,
+ window_size=7,
+ shift=False,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
+ self.attn = AGShiftWindowMSA(
+ embed_dims=embed_dims,
+ appearance_dims=appearance_dims,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=window_size // 2 if shift else 0,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop_rate=attn_drop_rate,
+ proj_drop_rate=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ init_cfg=None)
+
+ self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
+ self.ffn = FFN(
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * mlp_ratios,
+ num_fcs=2,
+ ffn_drop=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ act_cfg=act_cfg,
+ add_identity=True,
+ init_cfg=None)
+
+ def forward(self, inputs, hw_shape):
+ """
+ Args:
+ inputs (list[Tensor]): appearance_guidance (B, H, W, C);
+ x (B, L, C)
+ hw_shape (tuple[int]): shape of feature.
+ """
+ x, appearance_guidance = inputs
+ B, L, C = x.shape
+ H, W = hw_shape
+ assert L == H * W, 'input feature has wrong size'
+
+ identity = x
+ x = self.norm1(x)
+
+ # appearance guidance
+ x = x.view(B, H, W, C)
+ if appearance_guidance is not None:
+ x = torch.cat([x, appearance_guidance], dim=-1).flatten(1, 2)
+
+ x = self.attn(x, hw_shape)
+
+ x = x + identity
+
+ identity = x
+ x = self.norm2(x)
+ x = self.ffn(x, identity=identity)
+
+ return x
+
+
+@MODELS.register_module()
+class SpatialAggregateLayer(BaseModule):
+ """Spatial aggregation layer of CAT-Seg.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ appearance_dims (int): The appearance guidance dimension.
+ num_heads (int): Parallel attention heads.
+ mlp_ratios (int): The hidden dimension ratio w.r.t. embed_dims
+ for FFNs.
+ window_size (int, optional): The local window scale. Default: 7.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ init_cfg (dict | list | None, optional): The init config.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ appearance_dims,
+ num_heads,
+ mlp_ratios,
+ window_size=7,
+ qk_scale=None,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.block_1 = AGSwinBlock(
+ embed_dims,
+ appearance_dims,
+ num_heads,
+ mlp_ratios,
+ window_size=window_size,
+ shift=False,
+ qk_scale=qk_scale)
+ self.block_2 = AGSwinBlock(
+ embed_dims,
+ appearance_dims,
+ num_heads,
+ mlp_ratios,
+ window_size=window_size,
+ shift=True,
+ qk_scale=qk_scale)
+ self.guidance_norm = nn.LayerNorm(
+ appearance_dims) if appearance_dims > 0 else None
+
+ def forward(self, x, appearance_guidance):
+ """
+ Args:
+ x (torch.Tensor): B C T H W.
+ appearance_guidance (torch.Tensor): B C H W.
+ """
+ B, C, T, H, W = x.shape
+ x = x.permute(0, 2, 3, 4, 1).flatten(0, 1).flatten(1, 2) # BT, HW, C
+ if appearance_guidance is not None:
+ appearance_guidance = appearance_guidance.repeat(
+ T, 1, 1, 1).permute(0, 2, 3, 1) # BT, HW, C
+ appearance_guidance = self.guidance_norm(appearance_guidance)
+ else:
+            assert self.guidance_norm is None
+ x = self.block_1((x, appearance_guidance), (H, W))
+ x = self.block_2((x, appearance_guidance), (H, W))
+ x = x.transpose(1, 2).reshape(B, T, C, -1)
+ x = x.transpose(1, 2).reshape(B, C, T, H, W)
+ return x
+
+
+class AttentionLayer(nn.Module):
+    """Attention layer for ClassAggregateLayer of CAT-Seg.
+
+ Source: https://github.com/KU-CVLAB/CAT-Seg/blob/main/cat_seg/modeling/transformer/model.py#L310 # noqa
+ """
+
+ def __init__(self,
+ hidden_dim,
+ guidance_dim,
+ nheads=8,
+ attention_type='linear'):
+ super().__init__()
+ self.nheads = nheads
+ self.q = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
+ self.k = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
+ self.v = nn.Linear(hidden_dim, hidden_dim)
+
+ if attention_type == 'linear':
+ self.attention = LinearAttention()
+ elif attention_type == 'full':
+ self.attention = FullAttention()
+ else:
+ raise NotImplementedError
+
+ def forward(self, x, guidance=None):
+ """
+ Args:
+ x: B*H_p*W_p, T, C
+ guidance: B*H_p*W_p, T, C
+ """
+ B, L, _ = x.shape
+ q = self.q(torch.cat([x, guidance],
+ dim=-1)) if guidance is not None else self.q(x)
+ k = self.k(torch.cat([x, guidance],
+ dim=-1)) if guidance is not None else self.k(x)
+ v = self.v(x)
+
+ q = q.reshape(B, L, self.nheads, -1)
+ k = k.reshape(B, L, self.nheads, -1)
+ v = v.reshape(B, L, self.nheads, -1)
+
+ out = self.attention(q, k, v)
+ out = out.reshape(B, L, -1)
+ return out
+
+
+@MODELS.register_module()
+class ClassAggregateLayer(BaseModule):
+ """Class aggregation layer of CAT-Seg.
+
+ Args:
+ hidden_dims (int): The feature dimension.
+ guidance_dims (int): The appearance guidance dimension.
+ num_heads (int): Parallel attention heads.
+ attention_type (str): Type of attention layer. Default: 'linear'.
+ pooling_size (tuple[int] | list[int]): Pooling size.
+ init_cfg (dict | list | None, optional): The init config.
+ Default: None.
+ """
+
+ def __init__(
+ self,
+ hidden_dims=64,
+ guidance_dims=64,
+ num_heads=8,
+ attention_type='linear',
+ pooling_size=(4, 4),
+ init_cfg=None,
+ ):
+ super().__init__(init_cfg=init_cfg)
+ self.pool = nn.AvgPool2d(pooling_size)
+ self.attention = AttentionLayer(
+ hidden_dims,
+ guidance_dims,
+ nheads=num_heads,
+ attention_type=attention_type)
+ self.MLP = FFN(
+ embed_dims=hidden_dims,
+ feedforward_channels=hidden_dims * 4,
+ num_fcs=2)
+ self.norm1 = nn.LayerNorm(hidden_dims)
+ self.norm2 = nn.LayerNorm(hidden_dims)
+
+ def pool_features(self, x):
+ """Intermediate pooling layer for computational efficiency.
+
+ Args:
+ x: B, C, T, H, W
+ """
+ B, C, T, H, W = x.shape
+ x = x.transpose(1, 2).reshape(-1, C, H, W)
+ x = self.pool(x)
+ *_, H_, W_ = x.shape
+ x = x.reshape(B, T, C, H_, W_).transpose(1, 2)
+ return x
+
+ def forward(self, x, guidance):
+ """
+ Args:
+ x: B, C, T, H, W
+ guidance: B, T, C
+ """
+ B, C, T, H, W = x.size()
+ x_pool = self.pool_features(x)
+ *_, H_pool, W_pool = x_pool.size()
+
+ x_pool = x_pool.permute(0, 3, 4, 2, 1).reshape(-1, T, C)
+ # B*H_p*W_p T C
+ if guidance is not None:
+ guidance = guidance.repeat(H_pool * W_pool, 1, 1)
+
+ x_pool = x_pool + self.attention(self.norm1(x_pool),
+ guidance) # Attention
+ x_pool = x_pool + self.MLP(self.norm2(x_pool)) # MLP
+
+ x_pool = x_pool.reshape(B, H_pool * W_pool, T,
+ C).permute(0, 2, 3, 1).reshape(
+ B, T, C, H_pool,
+ W_pool).flatten(0, 1) # BT C H_p W_p
+ x_pool = F.interpolate(
+ x_pool, size=(H, W), mode='bilinear', align_corners=True)
+ x_pool = x_pool.reshape(B, T, C, H, W).transpose(1, 2) # B C T H W
+ x = x + x_pool # Residual
+
+ return x
+
+
+@MODELS.register_module()
+class AggregatorLayer(BaseModule):
+ """Single Aggregator Layer of CAT-Seg."""
+
+ def __init__(self,
+ embed_dims=64,
+ text_guidance_dims=512,
+ appearance_guidance_dims=512,
+ num_heads=4,
+ mlp_ratios=4,
+ window_size=7,
+ attention_type='linear',
+ pooling_size=(2, 2),
+ init_cfg=None) -> None:
+ super().__init__(init_cfg=init_cfg)
+ self.spatial_agg = SpatialAggregateLayer(
+ embed_dims,
+ appearance_guidance_dims,
+ num_heads=num_heads,
+ mlp_ratios=mlp_ratios,
+ window_size=window_size)
+ self.class_agg = ClassAggregateLayer(
+ embed_dims,
+ text_guidance_dims,
+ num_heads=num_heads,
+ attention_type=attention_type,
+ pooling_size=pooling_size)
+
+ def forward(self, x, appearance_guidance, text_guidance):
+ """
+ Args:
+ x: B C T H W
+ """
+ x = self.spatial_agg(x, appearance_guidance)
+ x = self.class_agg(x, text_guidance)
+ return x
+
+
+@MODELS.register_module()
+class CATSegAggregator(BaseModule):
+ """CATSeg Aggregator.
+
+ This Aggregator is the mmseg implementation of
+    `CAT-Seg <https://arxiv.org/abs/2303.11797>`_.
+
+ Args:
+ text_guidance_dim (int): Text guidance dimensions. Default: 512.
+ text_guidance_proj_dim (int): Text guidance projection dimensions.
+ Default: 128.
+ appearance_guidance_dim (int): Appearance guidance dimensions.
+ Default: 512.
+ appearance_guidance_proj_dim (int): Appearance guidance projection
+ dimensions. Default: 128.
+ num_layers (int): Aggregator layer number. Default: 4.
+ num_heads (int): Attention layer head number. Default: 4.
+ embed_dims (int): Input feature dimensions. Default: 128.
+ pooling_size (tuple | list): Pooling size of the class aggregator
+ layer. Default: (6, 6).
+ mlp_ratios (int): The hidden dimension ratio w.r.t. input dimension.
+ Default: 4.
+        window_size (int): Swin block window size. Default: 12.
+        attention_type (str): Attention type of class aggregator layer.
+            Default: 'linear'.
+ prompt_channel (int): Prompt channels. Default: 80.
+ """
+
+ def __init__(self,
+ text_guidance_dim=512,
+ text_guidance_proj_dim=128,
+ appearance_guidance_dim=512,
+ appearance_guidance_proj_dim=128,
+ num_layers=4,
+ num_heads=4,
+ embed_dims=128,
+ pooling_size=(6, 6),
+ mlp_ratios=4,
+ window_size=12,
+ attention_type='linear',
+ prompt_channel=80,
+ **kwargs):
+ super().__init__(**kwargs)
+ self.num_layers = num_layers
+ self.embed_dims = embed_dims
+
+ self.layers = nn.ModuleList([
+ AggregatorLayer(
+ embed_dims=embed_dims,
+ text_guidance_dims=text_guidance_proj_dim,
+ appearance_guidance_dims=appearance_guidance_proj_dim,
+ num_heads=num_heads,
+ mlp_ratios=mlp_ratios,
+ window_size=window_size,
+ attention_type=attention_type,
+ pooling_size=pooling_size) for _ in range(num_layers)
+ ])
+
+ self.conv1 = nn.Conv2d(
+ prompt_channel, embed_dims, kernel_size=7, stride=1, padding=3)
+
+ self.guidance_projection = nn.Sequential(
+ nn.Conv2d(
+ appearance_guidance_dim,
+ appearance_guidance_proj_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1),
+ nn.ReLU(),
+ ) if appearance_guidance_dim > 0 else None
+
+ self.text_guidance_projection = nn.Sequential(
+ nn.Linear(text_guidance_dim, text_guidance_proj_dim),
+ nn.ReLU(),
+ ) if text_guidance_dim > 0 else None
+
+ def feature_map(self, img_feats, text_feats):
+ """Concatenation type cost volume.
+
+ For ablation study of cost volume type.
+ """
+ img_feats = F.normalize(img_feats, dim=1) # B C H W
+ img_feats = img_feats.unsqueeze(2).repeat(1, 1, text_feats.shape[1], 1,
+ 1)
+ text_feats = F.normalize(text_feats, dim=-1) # B T P C
+ text_feats = text_feats.mean(dim=-2)
+ text_feats = F.normalize(text_feats, dim=-1) # B T C
+ text_feats = text_feats.unsqueeze(-1).unsqueeze(-1).repeat(
+ 1, 1, 1, img_feats.shape[-2], img_feats.shape[-1]).transpose(1, 2)
+ return torch.cat((img_feats, text_feats), dim=1) # B 2C T H W
+
+ def correlation(self, img_feats, text_feats):
+ """Correlation of image features and text features."""
+ img_feats = F.normalize(img_feats, dim=1) # B C H W
+ text_feats = F.normalize(text_feats, dim=-1) # B T P C
+ corr = torch.einsum('bchw, btpc -> bpthw', img_feats, text_feats)
+ return corr
+
+ def corr_embed(self, x):
+ """Correlation embeddings encoding."""
+ B = x.shape[0]
+ corr_embed = x.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ corr_embed = self.conv1(corr_embed)
+ corr_embed = corr_embed.reshape(B, -1, self.embed_dims, x.shape[-2],
+ x.shape[-1]).transpose(1, 2)
+ return corr_embed
+
+ def forward(self, inputs):
+ """
+ Args:
+ inputs (dict): including the following keys,
+ 'appearance_feat': list[torch.Tensor], w.r.t. out_indices of
+ `self.feature_extractor`.
+ 'clip_text_feat': the text feature extracted by clip text
+ encoder.
+ 'clip_text_feat_test': the text feature extracted by clip text
+ encoder for testing.
+ 'clip_img_feat': the image feature extracted clip image
+ encoder.
+ """
+ img_feats = inputs['clip_img_feat']
+ B = img_feats.size(0)
+ appearance_guidance = inputs[
+ 'appearance_feat'][::-1] # order (out_indices) 2, 1, 0
+ text_feats = inputs['clip_text_feat'] if self.training else inputs[
+ 'clip_text_feat_test']
+ text_feats = text_feats.repeat(B, 1, 1, 1)
+
+ corr = self.correlation(img_feats, text_feats)
+ # corr = self.feature_map(img_feats, text_feats)
+ corr_embed = self.corr_embed(corr)
+
+ projected_guidance, projected_text_guidance = None, None
+
+ if self.guidance_projection is not None:
+ projected_guidance = self.guidance_projection(
+ appearance_guidance[0])
+
+ if self.text_guidance_projection is not None:
+ text_feats = text_feats.mean(dim=-2)
+ text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
+ projected_text_guidance = self.text_guidance_projection(text_feats)
+
+ for layer in self.layers:
+ corr_embed = layer(corr_embed, projected_guidance,
+ projected_text_guidance)
+
+ return dict(
+ corr_embed=corr_embed, appearance_feats=appearance_guidance[1:])
diff --git a/projects/CAT-Seg/cat_seg/models/cat_head.py b/projects/CAT-Seg/cat_seg/models/cat_head.py
new file mode 100644
index 00000000000..36bb1c56179
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/models/cat_head.py
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+
+from mmseg.models.decode_heads.decode_head import BaseDecodeHead
+from mmseg.registry import MODELS
+
+
+class UpBlock(nn.Module):
+ """Upsample Block with two consecutive convolution layers."""
+
+ def __init__(self, in_channels, out_channels, guidance_channels):
+ super().__init__()
+ self.up = nn.ConvTranspose2d(
+ in_channels,
+ in_channels - guidance_channels,
+ kernel_size=2,
+ stride=2)
+ self.conv1 = ConvModule(
+ in_channels,
+ out_channels,
+ 3,
+ padding=1,
+ bias=False,
+ norm_cfg=dict(type='GN', num_groups=out_channels // 16))
+ self.conv2 = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ bias=False,
+ norm_cfg=dict(type='GN', num_groups=out_channels // 16))
+
+ def forward(self, x, guidance=None):
+ """Forward function with visual guidance."""
+ x = self.up(x)
+ if guidance is not None:
+ T = x.size(0) // guidance.size(0)
+ # guidance = repeat(guidance, "B C H W -> (B T) C H W", T=T)
+ guidance = guidance.repeat(T, 1, 1, 1)
+ x = torch.cat([x, guidance], dim=1)
+ x = self.conv1(x)
+
+ return self.conv2(x)
+
+
+@MODELS.register_module()
+class CATSegHead(BaseDecodeHead):
+ """CATSeg Head.
+
+ This segmentation head is the mmseg implementation of
+    `CAT-Seg <https://arxiv.org/abs/2303.11797>`_.
+
+ Args:
+ embed_dims (int): The number of input dimensions.
+        decoder_dims (tuple[int]): The number of decoder dimensions.
+        decoder_guidance_dims (tuple[int]): The number of appearance
+            guidance dimensions of the decoder.
+        decoder_guidance_proj_dims (tuple[int]): The number of projected
+            appearance guidance dimensions of the decoder.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
+    """
+
+ def __init__(self,
+ embed_dims=128,
+ decoder_dims=(64, 32),
+ decoder_guidance_dims=(256, 128),
+ decoder_guidance_proj_dims=(32, 16),
+ **kwargs):
+ super().__init__(**kwargs)
+ self.decoder_guidance_projection = nn.ModuleList([
+ nn.Sequential(
+ nn.Conv2d(
+ dec_dims,
+ dec_dims_proj,
+ kernel_size=3,
+ stride=1,
+ padding=1),
+ nn.ReLU(),
+ ) for dec_dims, dec_dims_proj in zip(decoder_guidance_dims,
+ decoder_guidance_proj_dims)
+ ]) if decoder_guidance_dims[0] > 0 else None
+
+ self.decoder1 = UpBlock(embed_dims, decoder_dims[0],
+ decoder_guidance_proj_dims[0])
+ self.decoder2 = UpBlock(decoder_dims[0], decoder_dims[1],
+ decoder_guidance_proj_dims[1])
+ self.conv_seg = nn.Conv2d(
+ decoder_dims[1], 1, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, inputs):
+ """Forward function.
+
+ Args:
+ inputs (dict): Input features including the following features,
+ corr_embed: aggregated correlation embeddings.
+ appearance_feats: decoder appearance feature guidance.
+ """
+        # decoder guidance projection
+        projected_decoder_guidance = [None, None]
+        if self.decoder_guidance_projection is not None:
+ projected_decoder_guidance = [
+ proj(g) for proj, g in zip(self.decoder_guidance_projection,
+ inputs['appearance_feats'])
+ ]
+
+ # decoder layers
+ B = inputs['corr_embed'].size(0)
+ corr_embed = inputs['corr_embed'].transpose(1, 2).flatten(0, 1)
+ corr_embed = self.decoder1(corr_embed, projected_decoder_guidance[0])
+ corr_embed = self.decoder2(corr_embed, projected_decoder_guidance[1])
+
+ output = self.cls_seg(corr_embed)
+
+ # rearrange the output to (B, T, H, W)
+ H_ori, W_ori = output.shape[-2:]
+ output = output.reshape(B, -1, H_ori, W_ori)
+ return output
diff --git a/projects/CAT-Seg/cat_seg/models/clip_ovseg.py b/projects/CAT-Seg/cat_seg/models/clip_ovseg.py
new file mode 100644
index 00000000000..cb67744e34a
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/models/clip_ovseg.py
@@ -0,0 +1,293 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from huggingface_hub.utils._errors import LocalEntryNotFoundError
+from mmengine.model import BaseModule
+
+from mmseg.registry import MODELS
+from mmseg.utils import ConfigType
+from ..utils import clip_wrapper
+from ..utils.clip_templates import (IMAGENET_TEMPLATES,
+ IMAGENET_TEMPLATES_SELECT)
+
+
+@MODELS.register_module()
+class CLIPOVCATSeg(BaseModule):
+ """CLIP based Open Vocabulary CAT-Seg model backbone.
+
+    This backbone is the modified implementation of `CAT-Seg Backbone
+    <https://arxiv.org/abs/2303.11797>`_. It combines the CLIP model and
+ another feature extractor, a.k.a the appearance guidance extractor
+ in the original `CAT-Seg`.
+
+ Args:
+ feature_extractor (ConfigType): Appearance guidance extractor
+ config dict.
+ train_class_json (str): The training class json file.
+ test_class_json (str): The path to test class json file.
+ clip_pretrained (str): The pre-trained clip type.
+ clip_finetune (str): The finetuning settings of clip model.
+        custom_clip_weights (str): The customized clip weights directory. When
+ encountering huggingface model download errors, you can manually
+ download the pretrained weights.
+ backbone_multiplier (float): The learning rate multiplier.
+ Default: 0.01.
+ prompt_depth (int): The prompt depth. Default: 0.
+ prompt_length (int): The prompt length. Default: 0.
+ prompt_ensemble_type (str): The prompt ensemble type.
+ Default: "imagenet".
+ pixel_mean (List[float]): The pixel mean for feature extractor.
+        pixel_std (List[float]): The pixel std for feature extractor.
+        clip_pixel_mean (List[float]): The pixel mean for clip model.
+        clip_pixel_std (List[float]): The pixel std for clip model.
+        clip_img_feat_size (List[int]): Clip image embedding size from
+            image encoder.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ """
+
+ def __init__(
+ self,
+ feature_extractor: ConfigType,
+ train_class_json: str,
+ test_class_json: str,
+ clip_pretrained: str,
+ clip_finetune: str,
+ custom_clip_weights: str = None,
+ backbone_multiplier=0.01,
+ prompt_depth: int = 0,
+ prompt_length: int = 0,
+ prompt_ensemble_type: str = 'imagenet',
+ pixel_mean: List[float] = [123.675, 116.280, 103.530],
+ pixel_std: List[float] = [58.395, 57.120, 57.375],
+ clip_pixel_mean: List[float] = [
+ 122.7709383, 116.7460125, 104.09373615
+ ],
+ clip_pixel_std: List[float] = [68.5005327, 66.6321579, 70.3231630],
+ clip_img_feat_size: List[int] = [24, 24],
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ # normalization parameters
+ self.register_buffer('pixel_mean',
+ torch.Tensor(pixel_mean).view(1, -1, 1, 1), False)
+ self.register_buffer('pixel_std',
+ torch.Tensor(pixel_std).view(1, -1, 1, 1), False)
+ self.register_buffer('clip_pixel_mean',
+ torch.Tensor(clip_pixel_mean).view(1, -1, 1, 1),
+ False)
+ self.register_buffer('clip_pixel_std',
+ torch.Tensor(clip_pixel_std).view(1, -1, 1, 1),
+ False)
+ self.clip_resolution = (
+ 384, 384) if clip_pretrained == 'ViT-B/16' else (336, 336)
+ # modified clip image encoder with fixed size dense output
+ self.clip_img_feat_size = clip_img_feat_size
+
+ # prepare clip templates
+ self.prompt_ensemble_type = prompt_ensemble_type
+ if self.prompt_ensemble_type == 'imagenet_select':
+ prompt_templates = IMAGENET_TEMPLATES_SELECT
+ elif self.prompt_ensemble_type == 'imagenet':
+ prompt_templates = IMAGENET_TEMPLATES
+ elif self.prompt_ensemble_type == 'single':
+ prompt_templates = [
+ 'A photo of a {} in the scene',
+ ]
+ else:
+ raise NotImplementedError
+ self.prompt_templates = prompt_templates
+
+ # build the feature extractor
+ self.feature_extractor = MODELS.build(feature_extractor)
+
+ # build CLIP model
+ with open(train_class_json) as f_in:
+ self.class_texts = json.load(f_in)
+ with open(test_class_json) as f_in:
+ self.test_class_texts = json.load(f_in)
+ assert self.class_texts is not None
+ if self.test_class_texts is None:
+ self.test_class_texts = self.class_texts
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ self.tokenizer = None
+ if clip_pretrained == 'ViT-G' or clip_pretrained == 'ViT-H':
+ # for OpenCLIP models
+ import open_clip
+ name, pretrain = (
+ 'ViT-H-14',
+ 'laion2b_s32b_b79k') if clip_pretrained == 'ViT-H' else (
+ 'ViT-bigG-14', 'laion2b_s39b_b160k')
+ try:
+ open_clip_model = open_clip.create_model_and_transforms(
+ name,
+ pretrained=pretrain,
+ device=device,
+ force_image_size=336,
+ )
+ clip_model, _, clip_preprocess = open_clip_model
+            except (ConnectionError, LocalEntryNotFoundError) as e:
+                print(f'Got {e} when loading weights from huggingface!')
+ print(
+ f'Will load {pretrain} weights from {custom_clip_weights}.'
+ )
+ assert custom_clip_weights is not None, 'Please specify custom weights directory.' # noqa
+ assert os.path.exists(
+ os.path.join(custom_clip_weights,
+ 'open_clip_pytorch_model.bin')
+ ), 'Please provide a valid directory for manually downloaded model.' # noqa
+ open_clip_model = open_clip.create_model_and_transforms(
+ name,
+ pretrained=None,
+ device='cpu',
+ force_image_size=336,
+ )
+ clip_model, _, clip_preprocess = open_clip_model
+
+ open_clip.load_checkpoint(
+ clip_model,
+ os.path.expanduser(
+ os.path.join(custom_clip_weights,
+ 'open_clip_pytorch_model.bin')))
+ clip_model.to(torch.device(device))
+
+ self.tokenizer = open_clip.get_tokenizer(name)
+ else:
+ # for OpenAI models
+ clip_model, clip_preprocess = clip_wrapper.load(
+ clip_pretrained,
+ device=device,
+ jit=False,
+ prompt_depth=prompt_depth,
+ prompt_length=prompt_length)
+
+ # pre-encode classes text prompts
+ text_features = self.class_embeddings(self.class_texts,
+ prompt_templates, clip_model,
+ device).permute(1, 0, 2).float()
+ text_features_test = self.class_embeddings(self.test_class_texts,
+ prompt_templates,
+ clip_model,
+ device).permute(1, 0,
+ 2).float()
+ self.register_buffer('text_features', text_features, False)
+ self.register_buffer('text_features_test', text_features_test, False)
+
+ # prepare CLIP model finetune
+ self.clip_finetune = clip_finetune
+ self.clip_model = clip_model.float()
+ self.clip_preprocess = clip_preprocess
+
+ for name, params in self.clip_model.named_parameters():
+ if 'visual' in name:
+ if clip_finetune == 'prompt':
+ params.requires_grad = True if 'prompt' in name else False
+ elif clip_finetune == 'attention':
+ if 'attn' in name or 'position' in name:
+ params.requires_grad = True
+ else:
+ params.requires_grad = False
+ elif clip_finetune == 'full':
+ params.requires_grad = True
+ else:
+ params.requires_grad = False
+ else:
+ params.requires_grad = False
+
+ finetune_backbone = backbone_multiplier > 0.
+ for name, params in self.feature_extractor.named_parameters():
+ if 'norm0' in name:
+ params.requires_grad = False
+ else:
+ params.requires_grad = finetune_backbone
+
+ @torch.no_grad()
+ def class_embeddings(self,
+ classnames,
+ templates,
+ clip_model,
+ device='cpu'):
+ """Convert class names to text embeddings by clip model.
+
+ Args:
+ classnames (list): loaded from json file.
+            templates (list): text templates.
+ clip_model (nn.Module): prepared clip model.
+ device (str | torch.device): loading device of text
+ encoder results.
+ """
+ zeroshot_weights = []
+ for classname in classnames:
+ if ', ' in classname:
+ classname_splits = classname.split(', ')
+ texts = []
+ for template in templates:
+ for cls_split in classname_splits:
+ texts.append(template.format(cls_split))
+ else:
+ texts = [template.format(classname)
+ for template in templates] # format with class
+ if self.tokenizer is not None:
+ texts = self.tokenizer(texts).to(device)
+ else:
+ texts = clip_wrapper.tokenize(texts).to(device)
+ class_embeddings = clip_model.encode_text(texts)
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+ if len(templates) != class_embeddings.shape[0]:
+ class_embeddings = class_embeddings.reshape(
+ len(templates), -1, class_embeddings.shape[-1]).mean(dim=1)
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+ class_embedding = class_embeddings
+ zeroshot_weights.append(class_embedding)
+ zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
+ return zeroshot_weights
+
+ def custom_normalize(self, inputs):
+ """Input normalization for clip model and feature extractor
+ respectively.
+
+ Args:
+ inputs: batched input images.
+ """
+ # clip images
+ batched_clip = (inputs - self.clip_pixel_mean) / self.clip_pixel_std
+ batched_clip = F.interpolate(
+ batched_clip,
+ size=self.clip_resolution,
+ mode='bilinear',
+ align_corners=False)
+ # feature extractor images
+ batched = (inputs - self.pixel_mean) / self.pixel_std
+ return batched, batched_clip
+
+ def forward(self, inputs):
+ """
+ Args:
+ inputs: minibatch image. (B, 3, H, W)
+ Returns:
+ outputs (dict):
+ 'appearance_feat': list[torch.Tensor], w.r.t. out_indices of
+ `self.feature_extractor`.
+ 'clip_text_feat': the text feature extracted by clip text encoder.
+ 'clip_text_feat_test': the text feature extracted by clip text
+ encoder for testing.
+            'clip_img_feat': the image feature extracted by the clip
+                image encoder.
+ """
+ inputs, clip_inputs = self.custom_normalize(inputs)
+ outputs = dict()
+ # extract appearance guidance feature
+ outputs['appearance_feat'] = self.feature_extractor(inputs)
+
+ # extract clip features
+ outputs['clip_text_feat'] = self.text_features
+ outputs['clip_text_feat_test'] = self.text_features_test
+ clip_features = self.clip_model.encode_image(
+ clip_inputs, dense=True) # B, 577(24x24+1), C
+ B = clip_features.size(0)
+ outputs['clip_img_feat'] = clip_features[:, 1:, :].permute(
+ 0, 2, 1).reshape(B, -1, *self.clip_img_feat_size)
+
+ return outputs
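The prompt-ensemble step in `class_embeddings` above can be sketched without CLIP itself: each class name is formatted into every template, the per-template embeddings are L2-normalized, averaged over templates, and the mean is renormalized. A minimal sketch, where `toy_encode` is a hypothetical stand-in for `clip_model.encode_text`:

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def toy_encode(text):
    """Hypothetical stand-in for clip_model.encode_text (4-dim output)."""
    return [float((i + 1) * (len(text) % 7 + 1)) for i in range(4)]


def class_embedding(classname, templates):
    """Average normalized per-template embeddings, then renormalize,
    mirroring the ensemble step in `class_embeddings`."""
    embs = [
        l2_normalize(toy_encode(t.format(classname))) for t in templates
    ]
    mean = [sum(col) / len(embs) for col in zip(*embs)]
    return l2_normalize(mean)


templates = ['a photo of a {}.', 'a bad photo of the {}.']
emb = class_embedding('cat', templates)
print(len(emb), round(sum(x * x for x in emb), 6))  # 4 1.0
```

The renormalization after averaging matters: the mean of unit vectors is generally shorter than unit length, and CLIP similarity scores assume unit-norm text features.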
diff --git a/projects/CAT-Seg/cat_seg/utils/__init__.py b/projects/CAT-Seg/cat_seg/utils/__init__.py
new file mode 100644
index 00000000000..88746b2cba6
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/utils/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .clip_templates import (IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT,
+ IMAGENET_TEMPLATES_SELECT_CLIP, ViLD_templates)
+from .self_attention_block import FullAttention, LinearAttention
+
+__all__ = [
+ 'FullAttention', 'LinearAttention', 'IMAGENET_TEMPLATES',
+ 'IMAGENET_TEMPLATES_SELECT', 'IMAGENET_TEMPLATES_SELECT_CLIP',
+ 'ViLD_templates'
+]
diff --git a/projects/CAT-Seg/cat_seg/utils/bpe_vocab/bpe_simple_vocab_16e6.txt.gz b/projects/CAT-Seg/cat_seg/utils/bpe_vocab/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 00000000000..7b5088a527f
Binary files /dev/null and b/projects/CAT-Seg/cat_seg/utils/bpe_vocab/bpe_simple_vocab_16e6.txt.gz differ
diff --git a/projects/CAT-Seg/cat_seg/utils/clip_model.py b/projects/CAT-Seg/cat_seg/utils/clip_model.py
new file mode 100644
index 00000000000..977444f5b52
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/utils/clip_model.py
@@ -0,0 +1,651 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Bottleneck(nn.Module):
+ """Custom implementation of Bottleneck in ResNet."""
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+ # all conv layers have stride 1.
+ # an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool,
+ # and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(
+ OrderedDict([('-1', nn.AvgPool2d(stride)),
+ ('0',
+ nn.Conv2d(
+ inplanes,
+ planes * self.expansion,
+ 1,
+ stride=1,
+ bias=False)),
+ ('1', nn.BatchNorm2d(planes * self.expansion))]))
+
+ def forward(self, x: torch.Tensor):
+ """
+ Args:
+ x (torch.Tensor): the input feature.
+ """
+ identity = x
+
+ out = self.relu(self.bn1(self.conv1(x)))
+ out = self.relu(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ """Attention Pool2d."""
+
+ def __init__(self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads: int,
+ output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ """
+ Args:
+ x (torch.Tensor): the input feature.
+ """
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x[:1],
+ key=x,
+ value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat(
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False)
+ return x.squeeze(0)
+
+
+class ModifiedResNet(nn.Module):
+ """A ResNet class that is similar to torchvision's but contains the
+ following changes:
+
+ - There are now 3 "stem" convolutions as opposed to 1, with an average
+ pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is
+ prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self,
+ layers,
+ output_dim,
+ heads,
+ input_resolution=224,
+ width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.input_resolution = input_resolution
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(
+ 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.relu1 = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(
+ width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(
+ width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.relu3 = nn.ReLU(inplace=True)
+ self.avgpool = nn.AvgPool2d(2)
+
+ # residual layers
+ # this is a *mutable* variable used during construction
+ self._inplanes = width
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
+ heads, output_dim)
+
+ def _make_layer(self, planes, blocks, stride=1):
+ """Build resnet layers."""
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ """
+ Args:
+ x (torch.Tensor): the input mini-batch images.
+ """
+
+ def stem(x):
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.avgpool(x)
+ return x
+
+ x = x.type(self.conv1.weight.dtype)
+ x = stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ """
+ Args:
+ x (torch.Tensor): the input feature.
+ """
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+ """Wrapper of GELU activation layer."""
+
+ def forward(self, x: torch.Tensor):
+ """
+ Args:
+ x (torch.Tensor): the input feature.
+ """
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ """Attention block with residual connection."""
+
+ def __init__(self,
+ d_model: int,
+ n_head: int,
+ attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(
+ OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
+ ('gelu', QuickGELU()),
+ ('c_proj', nn.Linear(d_model * 4, d_model))]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+ self.mask_pre_mlp = True
+
+ def attention(self, x: torch.Tensor):
+        """Calculate masked multi-head attention."""
+ self.attn_mask = self.attn_mask.to(
+ dtype=x.dtype,
+ device=x.device) if self.attn_mask is not None else None
+ return self.attn(
+ x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ """
+ Args:
+ x (torch.Tensor): the input feature.
+ """
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+ def forward_dense(self, x: torch.Tensor):
+        """Reimplementation of the forward function for dense prediction
+        of the image encoder in the CLIP model.
+
+ Args:
+ x (torch.Tensor): the input feature.
+ """
+ y = self.ln_1(x)
+ y = F.linear(y, self.attn.in_proj_weight, self.attn.in_proj_bias)
+        L, N, D = y.shape  # L, N, 3 * C: packed q, k, v
+
+ y = y.reshape(L, N, 3, D // 3).permute(2, 1, 0,
+ 3).reshape(3 * N, L, D // 3)
+ y = F.linear(y, self.attn.out_proj.weight, self.attn.out_proj.bias)
+
+ q, k, v = y.tensor_split(3, dim=0)
+ v = v.transpose(1, 0) + x # L N D
+
+ v = v + self.mlp(self.ln_2(v))
+ return v
+
+
+class Transformer(nn.Module):
+ """General Transformer Architecture for both image and text encoder."""
+
+ def __init__(self,
+ width: int,
+ layers: int,
+ heads: int,
+ attn_mask: torch.Tensor = None,
+ prompt_length=0,
+ prompt_depth=0):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.Sequential(*[
+ ResidualAttentionBlock(width, heads, attn_mask)
+ for _ in range(layers)
+ ])
+
+ self.prompt_length = prompt_length
+ self.prompt_depth = prompt_depth
+ self.prompt_tokens = nn.Parameter(
+ torch.zeros(prompt_depth, prompt_length,
+ width)) if prompt_length > 0 else None
+ if self.prompt_tokens is not None:
+ nn.init.xavier_uniform_(self.prompt_tokens)
+
+ def forward(self, x: torch.Tensor, dense=False):
+ """
+ Args:
+ x (torch.Tensor): input features.
+ dense (bool): whether use reimplemented dense forward
+ function in the last layer.
+ """
+ for i, resblock in enumerate(self.resblocks):
+ if self.prompt_length > 0 and i < self.prompt_depth:
+ length = self.prompt_length + 1 if i > 0 else 1
+ x = torch.cat((x[0:1, :, :], self.prompt_tokens[i].repeat(
+ x.shape[1], 1, 1).permute(1, 0, 2), x[length:, :, :]))
+
+ if i == self.layers - 1 and dense:
+ x = resblock.forward_dense(x)
+ x = torch.cat((x[0:1, :, :], x[self.prompt_length + 1::, :]),
+ dim=0)
+ else:
+ x = resblock(x)
+
+ return x
+
+
+class VisualTransformer(nn.Module):
+ """Visual encoder for CLIP model."""
+
+ def __init__(self, input_resolution: int, patch_size: int, width: int,
+ layers: int, heads: int, output_dim: int, prompt_depth: int,
+ prompt_length: int):
+ super().__init__()
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(
+ in_channels=3,
+ out_channels=width,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=False)
+
+ scale = width**-0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn(
+ (input_resolution // patch_size)**2 + 1, width))
+ self.ln_pre = LayerNorm(width)
+
+ self.transformer = Transformer(
+ width,
+ layers,
+ heads,
+ prompt_depth=prompt_depth,
+ prompt_length=prompt_length)
+
+ self.ln_post = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ self.patch_size = patch_size
+ self.input_resolution = input_resolution
+
+ def forward(self, x: torch.Tensor, dense=False):
+ """
+ Args:
+ x (torch.Tensor): input features.
+ dense (bool): whether use reimplemented dense forward
+ function in the last layer.
+ """
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1],
+ -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat([
+ self.class_embedding.to(x.dtype) + torch.zeros(
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
+ ],
+ dim=1) # shape = [*, grid ** 2 + 1, width]
+
+ if dense and (x.shape[1] != self.positional_embedding.shape[0]):
+ x = x + self.resized_pos_embed(self.input_resolution,
+ x.shape[1]).to(x.dtype)
+ else:
+ x = x + self.positional_embedding.to(x.dtype)
+
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x, dense)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ if dense:
+ x = self.ln_post(x[:, :, :])
+ else:
+ x = self.ln_post(x[:, 0, :])
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+ def resized_pos_embed(self, in_res, tgt_res, mode='bicubic'):
+ """Resize the position embedding."""
+ # assert L == (input_resolution // self.patch_size) ** 2 + 1
+ L, D = self.positional_embedding.shape
+
+ in_side = in_res // self.patch_size
+ # tgt_side = tgt_res // self.patch_size
+ tgt_side = int((tgt_res - 1)**0.5)
+
+ cls_pos = self.positional_embedding[0].unsqueeze(0) # 1 D
+ pos_embed = self.positional_embedding[1:].reshape(
+ 1, in_side, in_side, D).permute(0, 3, 1, 2) # L-1 D -> 1 D S S
+ resized_pos_embed = F.interpolate(
+ pos_embed,
+ size=(tgt_side, tgt_side),
+ mode=mode,
+ align_corners=False,
+ ) # 1 D S S -> 1 D S' S'
+ resized_pos_embed = resized_pos_embed.squeeze(0).reshape(
+ D, -1).T # L'-1 D
+
+ return torch.cat((cls_pos, resized_pos_embed), dim=0)
+
+
+class CLIP(nn.Module):
+ """Custom implementation of CLIP model.
+
+ Refer to: https://github.com/openai/CLIP
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ # vision
+ image_resolution: int,
+ vision_layers: Union[Tuple[int, int, int, int], int],
+ vision_width: int,
+ vision_patch_size: int,
+ # text
+ context_length: int,
+ vocab_size: int,
+ transformer_width: int,
+ transformer_heads: int,
+ transformer_layers: int,
+ # prompt
+ prompt_depth: int = 0,
+ prompt_length: int = 0,
+ ):
+ super().__init__()
+
+ self.context_length = context_length
+
+ self.image_resolution = image_resolution
+
+ if isinstance(vision_layers, (tuple, list)):
+ assert prompt_length == 0 and prompt_depth == 0
+ vision_heads = vision_width * 32 // 64
+ self.visual = ModifiedResNet(
+ layers=vision_layers,
+ output_dim=embed_dim,
+ heads=vision_heads,
+ input_resolution=image_resolution,
+ width=vision_width)
+ else:
+ vision_heads = vision_width // 64
+ self.visual = VisualTransformer(
+ input_resolution=image_resolution,
+ patch_size=vision_patch_size,
+ width=vision_width,
+ layers=vision_layers,
+ heads=vision_heads,
+ output_dim=embed_dim,
+ prompt_depth=prompt_depth,
+ prompt_length=prompt_length,
+ )
+
+ self.transformer = Transformer(
+ width=transformer_width,
+ layers=transformer_layers,
+ heads=transformer_heads,
+ attn_mask=self.build_attention_mask())
+
+ self.vocab_size = vocab_size
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+ self.positional_embedding = nn.Parameter(
+ torch.empty(self.context_length, transformer_width))
+ self.ln_final = LayerNorm(transformer_width)
+
+ self.text_projection = nn.Parameter(
+ torch.empty(transformer_width, embed_dim))
+ self.logit_scale = nn.Parameter(torch.ones([]))
+
+ def build_attention_mask(self):
+ """Create causal attention mask."""
+        # lazily create causal attention mask, with full attention between
+        # the vision tokens. PyTorch uses an additive attention mask;
+        # fill with -inf.
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float('-inf'))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ @property
+ def dtype(self):
+ """Return the dtype of the model."""
+ return self.visual.conv1.weight.dtype
+
+ def encode_image(self, image, masks=None, pool_mask=None, dense=False):
+ """Image encoding."""
+ if pool_mask is not None:
+ return self.visual(
+ image.type(self.dtype), mask=pool_mask, dense=dense)
+ if masks is None:
+ return self.visual(image.type(self.dtype), dense=dense)
+ else:
+ return self.visual(image.type(self.dtype), masks.type(self.dtype))
+
+ def encode_text(self, text):
+        """Text encoding."""
+ x = self.token_embedding(text).type(
+ self.dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x).type(self.dtype)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number
+ # in each sequence)
+ x = x[torch.arange(x.shape[0]),
+ text.argmax(dim=-1)] @ self.text_projection
+
+ return x
+
+ def forward(self, image, text):
+ """
+ Args:
+ image (torch.Tensor): input images.
+ text (torch.Tensor): input text.
+ """
+ image_features = self.encode_image(image)
+ text_features = self.encode_text(text)
+        # normalized features
+ image_features = image_features / image_features.norm(
+ dim=-1, keepdim=True)
+ text_features = text_features / text_features.norm(
+ dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+        logits_per_image = logit_scale * image_features @ text_features.t()
+        logits_per_text = logit_scale * text_features @ image_features.t()
+
+        # shape = [global_batch_size, global_batch_size]
+        return logits_per_image, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+ """Convert applicable model parameters to fp16."""
+
+ def _convert_weights_to_fp16(layer):
+ if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ layer.weight.data = layer.weight.data.half()
+ if layer.bias is not None:
+ layer.bias.data = layer.bias.data.half()
+
+ if isinstance(layer, nn.MultiheadAttention):
+ for attr in [
+ *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
+ 'in_proj_bias', 'bias_k', 'bias_v'
+ ]:
+ tensor = getattr(layer, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.half()
+
+ for name in ['text_projection', 'proj']:
+ if hasattr(layer, name):
+ attr = getattr(layer, name)
+ if attr is not None:
+ attr.data = attr.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict, prompt_depth=0, prompt_length=0):
+ """Build a CLIP model from given pretrained weights."""
+ vit = 'visual.proj' in state_dict
+
+ if vit:
+ vision_width = state_dict['visual.conv1.weight'].shape[0]
+ vision_layers = len([
+ k for k in state_dict.keys()
+ if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
+ ])
+ vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
+ grid_size = round(
+ (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
+ image_resolution = vision_patch_size * grid_size
+ else:
+ counts: list = [
+ len({
+ k.split('.')[2]
+ for k in state_dict if k.startswith(f'visual.layer{b}')
+ }) for b in [1, 2, 3, 4]
+ ]
+ vision_layers = tuple(counts)
+ vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0]
+ output_width = round(
+ (state_dict['visual.attnpool.positional_embedding'].shape[0] -
+ 1)**0.5)
+ vision_patch_size = None
+ assert output_width**2 + 1 == state_dict[
+ 'visual.attnpool.positional_embedding'].shape[0]
+ image_resolution = output_width * 32
+
+ embed_dim = state_dict['text_projection'].shape[1]
+ context_length = state_dict['positional_embedding'].shape[0]
+ vocab_size = state_dict['token_embedding.weight'].shape[0]
+ transformer_width = state_dict['ln_final.weight'].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len({
+ k.split('.')[2]
+ for k in state_dict if k.startswith('transformer.resblocks')
+ })
+
+ model = CLIP(
+ embed_dim,
+ image_resolution,
+ vision_layers,
+ vision_width,
+ vision_patch_size,
+ context_length,
+ vocab_size,
+ transformer_width,
+ transformer_heads,
+ transformer_layers,
+ prompt_depth=prompt_depth,
+ prompt_length=prompt_length,
+ )
+
+ for key in ['input_resolution', 'context_length', 'vocab_size']:
+ del state_dict[key]
+
+ convert_weights(model)
+ model.load_state_dict(state_dict, strict=False)
+ return model.eval()
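`build_model` above recovers the architecture hyper-parameters purely from checkpoint key names, with no config file. A small sketch of that key-counting logic on illustrative (made-up) keys:

```python
# Illustrative (made-up) checkpoint keys; real CLIP state dicts have many more.
state_dict_keys = [
    'visual.proj',
    'visual.conv1.weight',
    'visual.transformer.resblocks.0.attn.in_proj_weight',
    'visual.transformer.resblocks.1.attn.in_proj_weight',
    'transformer.resblocks.0.attn.in_proj_weight',
    'transformer.resblocks.1.attn.in_proj_weight',
    'transformer.resblocks.2.attn.in_proj_weight',
]

# ViT checkpoints are detected by the presence of 'visual.proj'.
is_vit = 'visual.proj' in state_dict_keys

# One visual resblock per '.attn.in_proj_weight' key under 'visual.'.
vision_layers = len([
    k for k in state_dict_keys
    if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
])

# Text transformer depth: number of distinct resblock indices.
transformer_layers = len({
    k.split('.')[2]
    for k in state_dict_keys if k.startswith('transformer.resblocks')
})

print(is_vit, vision_layers, transformer_layers)  # True 2 3
```

The same idea drives the ResNet branch, where per-stage block counts come from the distinct indices under `visual.layer{1..4}`.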
diff --git a/projects/CAT-Seg/cat_seg/utils/clip_templates.py b/projects/CAT-Seg/cat_seg/utils/clip_templates.py
new file mode 100644
index 00000000000..bfc32dfc56f
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/utils/clip_templates.py
@@ -0,0 +1,204 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Source: https://github.com/openai/CLIP.
+
+IMAGENET_TEMPLATES = [
+ 'a bad photo of a {}.',
+ 'a photo of many {}.',
+ 'a sculpture of a {}.',
+ 'a photo of the hard to see {}.',
+ 'a low resolution photo of the {}.',
+ 'a rendering of a {}.',
+ 'graffiti of a {}.',
+ 'a bad photo of the {}.',
+ 'a cropped photo of the {}.',
+ 'a tattoo of a {}.',
+ 'the embroidered {}.',
+ 'a photo of a hard to see {}.',
+ 'a bright photo of a {}.',
+ 'a photo of a clean {}.',
+ 'a photo of a dirty {}.',
+ 'a dark photo of the {}.',
+ 'a drawing of a {}.',
+ 'a photo of my {}.',
+ 'the plastic {}.',
+ 'a photo of the cool {}.',
+ 'a close-up photo of a {}.',
+ 'a black and white photo of the {}.',
+ 'a painting of the {}.',
+ 'a painting of a {}.',
+ 'a pixelated photo of the {}.',
+ 'a sculpture of the {}.',
+ 'a bright photo of the {}.',
+ 'a cropped photo of a {}.',
+ 'a plastic {}.',
+ 'a photo of the dirty {}.',
+ 'a jpeg corrupted photo of a {}.',
+ 'a blurry photo of the {}.',
+ 'a photo of the {}.',
+ 'a good photo of the {}.',
+ 'a rendering of the {}.',
+ 'a {} in a video game.',
+ 'a photo of one {}.',
+ 'a doodle of a {}.',
+ 'a close-up photo of the {}.',
+ 'a photo of a {}.',
+ 'the origami {}.',
+ 'the {} in a video game.',
+ 'a sketch of a {}.',
+ 'a doodle of the {}.',
+ 'a origami {}.',
+ 'a low resolution photo of a {}.',
+ 'the toy {}.',
+ 'a rendition of the {}.',
+ 'a photo of the clean {}.',
+ 'a photo of a large {}.',
+ 'a rendition of a {}.',
+ 'a photo of a nice {}.',
+ 'a photo of a weird {}.',
+ 'a blurry photo of a {}.',
+ 'a cartoon {}.',
+ 'art of a {}.',
+ 'a sketch of the {}.',
+ 'a embroidered {}.',
+ 'a pixelated photo of a {}.',
+ 'itap of the {}.',
+ 'a jpeg corrupted photo of the {}.',
+ 'a good photo of a {}.',
+ 'a plushie {}.',
+ 'a photo of the nice {}.',
+ 'a photo of the small {}.',
+ 'a photo of the weird {}.',
+ 'the cartoon {}.',
+ 'art of the {}.',
+ 'a drawing of the {}.',
+ 'a photo of the large {}.',
+ 'a black and white photo of a {}.',
+ 'the plushie {}.',
+ 'a dark photo of a {}.',
+ 'itap of a {}.',
+ 'graffiti of the {}.',
+ 'a toy {}.',
+ 'itap of my {}.',
+ 'a photo of a cool {}.',
+ 'a photo of a small {}.',
+ 'a tattoo of the {}.',
+ # 'A photo of a {} in the scene.',
+]
+
+# v1: 59.0875
+IMAGENET_TEMPLATES_SELECT = [
+ 'itap of a {}.',
+ 'a bad photo of the {}.',
+ 'a origami {}.',
+ 'a photo of the large {}.',
+ 'a {} in a video game.',
+ 'art of the {}.',
+ 'a photo of the small {}.',
+ 'A photo of a {} in the scene',
+]
+
+# v9
+IMAGENET_TEMPLATES_SELECT_CLIP = [
+ 'a bad photo of the {}.',
+ 'a photo of the large {}.',
+ 'a photo of the small {}.',
+ 'a cropped photo of a {}.',
+ 'This is a photo of a {}',
+ 'This is a photo of a small {}',
+ 'This is a photo of a medium {}',
+ 'This is a photo of a large {}',
+ 'This is a masked photo of a {}',
+ 'This is a masked photo of a small {}',
+ 'This is a masked photo of a medium {}',
+ 'This is a masked photo of a large {}',
+ 'This is a cropped photo of a {}',
+ 'This is a cropped photo of a small {}',
+ 'This is a cropped photo of a medium {}',
+ 'This is a cropped photo of a large {}',
+ 'A photo of a {} in the scene',
+ 'a bad photo of the {} in the scene',
+ 'a photo of the large {} in the scene',
+ 'a photo of the small {} in the scene',
+ 'a cropped photo of a {} in the scene',
+ 'a photo of a masked {} in the scene',
+ 'There is a {} in the scene',
+ 'There is the {} in the scene',
+ 'This is a {} in the scene',
+ 'This is the {} in the scene',
+ 'This is one {} in the scene',
+ 'There is a masked {} in the scene',
+ 'There is the masked {} in the scene',
+ 'This is a masked {} in the scene',
+ 'This is the masked {} in the scene',
+ 'This is one masked {} in the scene',
+]
+
+# v10, for comparison
+# IMAGENET_TEMPLATES_SELECT_CLIP = [
+# 'a photo of a {}.',
+#
+# 'This is a photo of a {}',
+# 'This is a photo of a small {}',
+# 'This is a photo of a medium {}',
+# 'This is a photo of a large {}',
+#
+# 'This is a photo of a {}',
+# 'This is a photo of a small {}',
+# 'This is a photo of a medium {}',
+# 'This is a photo of a large {}',
+#
+# 'a photo of a {} in the scene',
+# 'a photo of a {} in the scene',
+#
+# 'There is a {} in the scene',
+# 'There is the {} in the scene',
+# 'This is a {} in the scene',
+# 'This is the {} in the scene',
+# 'This is one {} in the scene',
+# ]
+
+ViLD_templates = [
+ 'There is {article} {category} in the scene.',
+ 'There is the {category} in the scene.',
+ 'a photo of {article} {category} in the scene.',
+ 'a photo of the {category} in the scene.',
+ 'a photo of one {category} in the scene.', 'itap of {article} {category}.',
+ 'itap of my {category}.', 'itap of the {category}.',
+ 'a photo of {article} {category}.', 'a photo of my {category}.',
+ 'a photo of the {category}.', 'a photo of one {category}.',
+ 'a photo of many {category}.', 'a good photo of {article} {category}.',
+ 'a good photo of the {category}.', 'a bad photo of {article} {category}.',
+ 'a bad photo of the {category}.', 'a photo of a nice {category}.',
+ 'a photo of the nice {category}.', 'a photo of a cool {category}.',
+ 'a photo of the cool {category}.', 'a photo of a weird {category}.',
+ 'a photo of the weird {category}.', 'a photo of a small {category}.',
+ 'a photo of the small {category}.', 'a photo of a large {category}.',
+ 'a photo of the large {category}.', 'a photo of a clean {category}.',
+ 'a photo of the clean {category}.', 'a photo of a dirty {category}.',
+ 'a photo of the dirty {category}.',
+ 'a bright photo of {article} {category}.',
+ 'a bright photo of the {category}.',
+ 'a dark photo of {article} {category}.', 'a dark photo of the {category}.',
+ 'a photo of a hard to see {category}.',
+ 'a photo of the hard to see {category}.',
+ 'a low resolution photo of {article} {category}.',
+ 'a low resolution photo of the {category}.',
+ 'a cropped photo of {article} {category}.',
+ 'a cropped photo of the {category}.',
+ 'a close-up photo of {article} {category}.',
+ 'a close-up photo of the {category}.',
+ 'a jpeg corrupted photo of {article} {category}.',
+ 'a jpeg corrupted photo of the {category}.',
+ 'a blurry photo of {article} {category}.',
+ 'a blurry photo of the {category}.',
+ 'a pixelated photo of {article} {category}.',
+ 'a pixelated photo of the {category}.',
+ 'a black and white photo of the {category}.',
+ 'a black and white photo of {article} {category}.',
+ 'a plastic {category}.', 'the plastic {category}.', 'a toy {category}.',
+ 'the toy {category}.', 'a plushie {category}.', 'the plushie {category}.',
+ 'a cartoon {category}.', 'the cartoon {category}.',
+ 'an embroidered {category}.', 'the embroidered {category}.',
+ 'a painting of the {category}.', 'a painting of a {category}.'
+]
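The templates above are plain Python format strings with `{category}` and `{article}` placeholders. A minimal sketch of how such templates could be expanded into per-class prompts before tokenization — the `expand_templates` helper and the sample class name are illustrative, not part of the CAT-Seg code:

```python
# Expand CLIP-style prompt templates for one class name.
# `templates` is a small subset of the list above; `expand_templates`
# is a hypothetical helper, not part of the CAT-Seg code.
templates = [
    'a photo of {article} {category}.',
    'a photo of the {category} in the scene.',
    'itap of my {category}.',
]


def expand_templates(category, article='a'):
    """Fill every template with the given class name and article."""
    return [t.format(category=category, article=article) for t in templates]


prompts = expand_templates('traffic light', article='a')
```

Each class name thus yields one prompt per template; the resulting strings are what `tokenize` later turns into token ids.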
diff --git a/projects/CAT-Seg/cat_seg/utils/clip_wrapper.py b/projects/CAT-Seg/cat_seg/utils/clip_wrapper.py
new file mode 100644
index 00000000000..f809d2b8280
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/utils/clip_wrapper.py
@@ -0,0 +1,275 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Referred to: https://github.com/KU-CVLAB/CAT-Seg/blob/main/cat_seg/third_party/clip.py # noqa
+import hashlib
+import os
+import urllib
+import warnings
+from typing import List, Union
+
+import torch
+from PIL import Image
+from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize,
+ ToTensor)
+from tqdm import tqdm
+
+from .clip_model import build_model
+from .tokenizer import SimpleTokenizer as _Tokenizer
+
+__all__ = ['available_models', 'load', 'tokenize']
+_tokenizer = _Tokenizer()
+
+_MODELS = {
+ 'RN50':
+ 'https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt', # noqa
+ 'RN101':
+ 'https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt', # noqa
+ 'RN50x4':
+ 'https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt', # noqa
+ 'RN50x16':
+ 'https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt', # noqa
+ 'RN50x64':
+ 'https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt', # noqa
+ 'ViT-B/32':
+ 'https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt', # noqa
+ 'ViT-B/16':
+ 'https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt', # noqa
+ 'ViT-L/14':
+ 'https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt', # noqa
+ 'ViT-L/14@336px':
+ 'https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt', # noqa
+}
+
+
+def _download(url: str, root: str = os.path.expanduser('~/.cache/clip')):
+ """Download clip pretrained weights."""
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ expected_sha256 = url.split('/')[-2]
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(
+ f'{download_target} exists and is not a regular file')
+
+ if os.path.isfile(download_target):
+ if hashlib.sha256(open(download_target,
+ 'rb').read()).hexdigest() == expected_sha256:
+ return download_target
+ else:
+            warnings.warn(
+                f'{download_target} exists, but the SHA256 checksum does '
+                'not match; re-downloading the file')
+
+ with urllib.request.urlopen(url) as source, open(download_target,
+ 'wb') as output:
+ with tqdm(
+ total=int(source.info().get('Content-Length')),
+ ncols=80) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if hashlib.sha256(open(download_target,
+ 'rb').read()).hexdigest() != expected_sha256:
+        raise RuntimeError(
+            'Model has been downloaded but the SHA256 checksum does '
+            'not match')
+
+ return download_target
+
+
+def available_models():
+ """Returns a list of available models."""
+ return list(_MODELS.keys())
+
+
+def load(name: str,
+ device: Union[str, torch.device] = 'cuda'
+ if torch.cuda.is_available() else 'cpu',
+ jit=True,
+ prompt_depth=0,
+ prompt_length=0):
+ """Load target clip model."""
+ if name not in _MODELS:
+ raise RuntimeError(
+ f'Model {name} not found; available models = {available_models()}')
+
+ model_path = _download(_MODELS[name])
+ model = torch.jit.load(
+ model_path, map_location=device if jit else 'cpu').eval()
+ n_px = model.input_resolution.item()
+
+ transform = Compose([
+ Resize(n_px, interpolation=Image.BICUBIC),
+ CenterCrop(n_px),
+ lambda image: image.convert('RGB'),
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ if not jit:
+ model = build_model(model.state_dict(), prompt_depth,
+ prompt_length).to(device)
+ return model, transform
+
+ # patch the device names
+ device_holder = torch.jit.trace(
+ lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [
+ n for n in device_holder.graph.findAllNodes('prim::Constant')
+ if 'Device' in repr(n)
+ ][-1]
+
+ def patch_device(module):
+ graphs = [module.graph] if hasattr(module, 'graph') else []
+ if hasattr(module, 'forward1'):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes('prim::Constant'):
+ if 'value' in node.attributeNames() and str(
+ node['value']).startswith('cuda'):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 on CPU
+ if device == 'cpu':
+ float_holder = torch.jit.trace(
+ lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ graphs = [module.graph] if hasattr(module, 'graph') else []
+ if hasattr(module, 'forward1'):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes('aten::to'):
+ inputs = list(node.inputs())
+ for i in [1, 2]:
+ # dtype can be the second or third argument to
+ # aten::to()
+ if inputs[i].node()['value'] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+
+ model.float()
+
+ return model, transform
+
+
+def load_custom(name: str,
+ device: Union[str, torch.device] = 'cuda'
+ if torch.cuda.is_available() else 'cpu',
+ jit=True,
+ n_px=224):
+ """Load a customized clip model."""
+ if name not in _MODELS:
+ raise RuntimeError(
+ f'Model {name} not found; available models = {available_models()}')
+
+ model_path = _download(_MODELS[name])
+ model = torch.jit.load(
+ model_path, map_location=device if jit else 'cpu').eval()
+    # unlike load(), n_px is taken from the function argument here
+
+ transform = Compose([
+ Resize(n_px, interpolation=Image.BICUBIC),
+ CenterCrop(n_px),
+ lambda image: image.convert('RGB'),
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ if not jit:
+ model = build_model(model.state_dict()).to(device)
+ return model, transform
+
+ # patch the device names
+ device_holder = torch.jit.trace(
+ lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [
+ n for n in device_holder.graph.findAllNodes('prim::Constant')
+ if 'Device' in repr(n)
+ ][-1]
+
+ def patch_device(module):
+ graphs = [module.graph] if hasattr(module, 'graph') else []
+ if hasattr(module, 'forward1'):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes('prim::Constant'):
+ if 'value' in node.attributeNames() and str(
+ node['value']).startswith('cuda'):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 on CPU
+ if device == 'cpu':
+ float_holder = torch.jit.trace(
+ lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ graphs = [module.graph] if hasattr(module, 'graph') else []
+ if hasattr(module, 'forward1'):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes('aten::to'):
+ inputs = list(node.inputs())
+                    for i in [1, 2]:
+                        # dtype can be the second or third argument to
+                        # aten::to()
+                        if inputs[i].node()['value'] == 5:
+                            inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+
+ model.float()
+
+ return model, transform
+
+
+def tokenize(texts: Union[str, List[str]], context_length: int = 77):
+ """Convert texts to tokens."""
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = _tokenizer.encoder['<|startoftext|>']
+ eot_token = _tokenizer.encoder['<|endoftext|>']
+ # encode each template text phrase
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
+ for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+            raise RuntimeError(
+                f'Input {texts[i]} is too long for context length '
+                f'{context_length}')
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
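`tokenize` above wraps each phrase in start/end tokens and right-pads every row to `context_length`. The padding logic can be sketched without the BPE vocabulary using stand-in token ids — the per-word ids below are invented for illustration, only the special-token ids mirror CLIP's:

```python
# Sketch of tokenize()'s padding behaviour with fake token ids.
# 49406/49407 mirror CLIP's <|startoftext|>/<|endoftext|> ids; the
# per-word ids are made up.
SOT, EOT = 49406, 49407
CONTEXT_LENGTH = 77


def pad_tokens(encoded_texts):
    """Wrap each id sequence with SOT/EOT and zero-pad to CONTEXT_LENGTH."""
    rows = []
    for ids in encoded_texts:
        tokens = [SOT] + ids + [EOT]
        if len(tokens) > CONTEXT_LENGTH:
            raise RuntimeError('input is too long for context length')
        rows.append(tokens + [0] * (CONTEXT_LENGTH - len(tokens)))
    return rows


rows = pad_tokens([[320, 1125], [320]])
```

The real function builds the same layout into a `torch.long` tensor of shape `(num_texts, 77)`.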
diff --git a/projects/CAT-Seg/cat_seg/utils/self_attention_block.py b/projects/CAT-Seg/cat_seg/utils/self_attention_block.py
new file mode 100644
index 00000000000..1c06cbd99e0
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/utils/self_attention_block.py
@@ -0,0 +1,79 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+class LinearAttention(nn.Module):
+ """Multi-Head linear attention proposed in "Transformers are RNNs".
+
+ Source: https://github.com/KU-CVLAB/CAT-Seg/blob/main/cat_seg/modeling/transformer/model.py#L247 # noqa
+ """
+
+ def __init__(self, eps=1e-6):
+ super().__init__()
+ self.eps = eps
+
+ def forward(self, queries, keys, values):
+ """
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+ Q = F.elu(queries) + 1
+ K = F.elu(keys) + 1
+
+ v_length = values.size(1)
+ values = values / v_length # prevent fp16 overflow
+ KV = torch.einsum('nshd,nshv->nhdv', K, values) # (S,D)' @ S,V
+ Z = 1 / (torch.einsum('nlhd,nhd->nlh', Q, K.sum(dim=1)) + self.eps)
+ queried_values = torch.einsum('nlhd,nhdv,nlh->nlhv', Q, KV,
+ Z) * v_length
+
+ return queried_values.contiguous()
+
+
+class FullAttention(nn.Module):
+ """Multi-head scaled dot-product attention, a.k.a full attention.
+
+ Source: https://github.com/KU-CVLAB/CAT-Seg/blob/main/cat_seg/modeling/transformer/model.py#L276 # noqa
+ """
+
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
+ super().__init__()
+ self.use_dropout = use_dropout
+ self.dropout = nn.Dropout(attention_dropout)
+
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
+ """
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+
+ # Compute the unnormalized attention and apply the masks
+ QK = torch.einsum('nlhd,nshd->nlsh', queries, keys)
+ if kv_mask is not None:
+ QK.masked_fill_(
+ ~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]),
+ float('-inf'))
+
+ # Compute the attention and the weighted average
+        softmax_temp = 1. / queries.size(3)**.5  # 1 / sqrt(D)
+ A = torch.softmax(softmax_temp * QK, dim=2)
+ if self.use_dropout:
+ A = self.dropout(A)
+
+ queried_values = torch.einsum('nlsh,nshd->nlhd', A, values)
+
+ return queried_values.contiguous()
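`LinearAttention` above relies on the associativity of matrix products: computing `K^T V` first costs `O(S·D·V)` instead of materialising the `L×S` attention matrix that `FullAttention` builds. A pure-Python check of that identity on a tiny single-head example (the ELU feature maps and the normaliser `Z` are omitted for brevity):

```python
# Verify (Q K^T) V == Q (K^T V) on a tiny example -- the identity that
# lets linear attention avoid the L x S attention matrix.
def matmul(A, B):
    """Naive list-of-lists matrix multiply."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]


Q = [[1., 2.], [3., 4.]]   # L=2 queries, D=2
K = [[0.5, 1.], [1., 0.]]  # S=2 keys,    D=2
V = [[2.], [3.]]           # S=2 values,  value dim 1

Kt = [list(col) for col in zip(*K)]  # D x S
left = matmul(matmul(Q, Kt), V)      # (L x S) @ (S x 1): full attention order
right = matmul(Q, matmul(Kt, V))     # (L x D) @ (D x 1): linear attention order
```

Both orders give the same result; linear attention exploits this after replacing softmax with the `elu(x) + 1` feature map, which makes the product factorable.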
diff --git a/projects/CAT-Seg/cat_seg/utils/tokenizer.py b/projects/CAT-Seg/cat_seg/utils/tokenizer.py
new file mode 100644
index 00000000000..c84711b0678
--- /dev/null
+++ b/projects/CAT-Seg/cat_seg/utils/tokenizer.py
@@ -0,0 +1,160 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+ """Return default BPE vocabulary path."""
+ return os.path.join(
+ os.path.dirname(os.path.abspath(__file__)),
+ 'bpe_vocab/bpe_simple_vocab_16e6.txt.gz')
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """Returns a mapping between utf-8 bytes and unicode strings.
+
+    The reversible bpe codes work on unicode strings. This means you need a
+    large number of unicode characters in your vocab if you want to avoid
+    UNKs. When you're at something like a 10B token dataset you end up
+    needing around 5K for decent coverage. This is a significant percentage
+    of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables
+    between utf-8 bytes and unicode strings, which also avoids mapping to
+    whitespace/control characters the bpe code barfs on.
+    """
+ bs = list(range(ord('!'),
+ ord('~') + 1)) + list(range(
+ ord('¡'),
+ ord('¬') + 1)) + list(range(ord('®'),
+ ord('ÿ') + 1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+
+ Word is represented as tuple of symbols (symbols being variable-length
+ strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ """Clean string."""
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ """Clean whitespace in string."""
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer:
+ """Customized Tokenizer implementation."""
+
+ def __init__(self, bpe_path: str = default_bpe()):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
+ merges = merges[1:49152 - 256 - 2 + 1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + '</w>' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {
+ '<|startoftext|>': '<|startoftext|>',
+ '<|endoftext|>': '<|endoftext|>'
+ }
+ self.pat = re.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d"""
+            r"""|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ def bpe(self, token):
+ """Refer to bpe vocabulary dictionary."""
+ if token in self.cache:
+ return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + '</w>', )
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token + '</w>'
+
+ while True:
+ bigram = min(
+ pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[
+ i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ """Encode text strings."""
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b]
+ for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token]
+ for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+        """Decode tokens to strings."""
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode(
+            'utf-8', errors='replace').replace('</w>', ' ')
+ return text
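`get_pairs` drives the BPE loop in the tokenizer above: at each step the lowest-ranked adjacent pair is merged until no ranked pair remains. The pair extraction can be checked in isolation with a standalone copy of the function:

```python
# Standalone copy of get_pairs() from the tokenizer above, exercised on
# a small symbol tuple.
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a symbol tuple."""
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


pairs = get_pairs(('h', 'e', 'l', 'l', 'o'))
```

Because the result is a set, the duplicate `('l', 'l')` adjacency appears once, which is all the merge-ranking step needs.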
diff --git a/projects/CAT-Seg/configs/_base_/datasets/ade20k_384x384.py b/projects/CAT-Seg/configs/_base_/datasets/ade20k_384x384.py
new file mode 100644
index 00000000000..488ba3d7f6f
--- /dev/null
+++ b/projects/CAT-Seg/configs/_base_/datasets/ade20k_384x384.py
@@ -0,0 +1,68 @@
+# dataset settings
+dataset_type = 'ADE20KDataset'
+data_root = 'data/ade/ADEChallengeData2016'
+crop_size = (384, 384)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(
+ type='RandomResize',
+ scale=(2048, 512),
+ ratio_range=(0.5, 2.0),
+ keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=(2048, 512), keep_ratio=True),
+ # add loading annotation after ``Resize`` because ground truth
+ # does not need to do resize data transform
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='PackSegInputs')
+]
+img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+tta_pipeline = [
+ dict(type='LoadImageFromFile', backend_args=None),
+ dict(
+ type='TestTimeAug',
+ transforms=[
+ [
+ dict(type='Resize', scale_factor=r, keep_ratio=True)
+ for r in img_ratios
+ ],
+ [
+ dict(type='RandomFlip', prob=0., direction='horizontal'),
+ dict(type='RandomFlip', prob=1., direction='horizontal')
+ ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
+ ])
+]
+train_dataloader = dict(
+ batch_size=4,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='images/training', seg_map_path='annotations/training'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='images/validation',
+ seg_map_path='annotations/validation'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
diff --git a/projects/CAT-Seg/configs/_base_/datasets/coco-stuff164k_384x384.py b/projects/CAT-Seg/configs/_base_/datasets/coco-stuff164k_384x384.py
new file mode 100644
index 00000000000..dd051761d47
--- /dev/null
+++ b/projects/CAT-Seg/configs/_base_/datasets/coco-stuff164k_384x384.py
@@ -0,0 +1,62 @@
+# dataset settings
+dataset_type = 'COCOStuffDataset'
+data_root = 'data/coco_stuff164k'
+crop_size = (384, 384)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=(2048, 512), keep_ratio=True),
+ # add loading annotation after ``Resize`` because ground truth
+ # does not need to do resize data transform
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+tta_pipeline = [
+ dict(type='LoadImageFromFile', backend_args=None),
+ dict(
+ type='TestTimeAug',
+ transforms=[
+ [
+ dict(type='Resize', scale_factor=r, keep_ratio=True)
+ for r in img_ratios
+ ],
+ [
+ dict(type='RandomFlip', prob=0., direction='horizontal'),
+ dict(type='RandomFlip', prob=1., direction='horizontal')
+ ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
+ ])
+]
+train_dataloader = dict(
+ batch_size=2,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='images/train2017', seg_map_path='annotations/train2017'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='images/val2017', seg_map_path='annotations/val2017'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
diff --git a/projects/CAT-Seg/configs/_base_/datasets/pascal_context_59_384x384.py b/projects/CAT-Seg/configs/_base_/datasets/pascal_context_59_384x384.py
new file mode 100644
index 00000000000..250c5990f6d
--- /dev/null
+++ b/projects/CAT-Seg/configs/_base_/datasets/pascal_context_59_384x384.py
@@ -0,0 +1,72 @@
+# dataset settings
+dataset_type = 'PascalContextDataset59'
+data_root = 'data/VOCdevkit/VOC2010/'
+
+img_scale = (520, 520)
+crop_size = (384, 384)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(
+ type='RandomResize',
+ scale=img_scale,
+ ratio_range=(0.5, 2.0),
+ keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=True),
+ # add loading annotation after ``Resize`` because ground truth
+ # does not need to do resize data transform
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='PackSegInputs')
+]
+img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+tta_pipeline = [
+ dict(type='LoadImageFromFile', backend_args=None),
+ dict(
+ type='TestTimeAug',
+ transforms=[
+ [
+ dict(type='Resize', scale_factor=r, keep_ratio=True)
+ for r in img_ratios
+ ],
+ [
+ dict(type='RandomFlip', prob=0., direction='horizontal'),
+ dict(type='RandomFlip', prob=1., direction='horizontal')
+ ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
+ ])
+]
+train_dataloader = dict(
+ batch_size=4,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
+ ann_file='ImageSets/SegmentationContext/train.txt',
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
+ ann_file='ImageSets/SegmentationContext/val.txt',
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
diff --git a/projects/CAT-Seg/configs/_base_/default_runtime.py b/projects/CAT-Seg/configs/_base_/default_runtime.py
new file mode 100644
index 00000000000..272b4d24679
--- /dev/null
+++ b/projects/CAT-Seg/configs/_base_/default_runtime.py
@@ -0,0 +1,15 @@
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+vis_backends = [dict(type='LocalVisBackend')]
+visualizer = dict(
+ type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+log_processor = dict(by_epoch=False)
+log_level = 'INFO'
+load_from = None
+resume = False
+
+tta_model = dict(type='SegTTAModel')
diff --git a/projects/CAT-Seg/configs/_base_/schedules/schedule_80k.py b/projects/CAT-Seg/configs/_base_/schedules/schedule_80k.py
new file mode 100644
index 00000000000..0dcd6c4d1bc
--- /dev/null
+++ b/projects/CAT-Seg/configs/_base_/schedules/schedule_80k.py
@@ -0,0 +1,24 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
+# learning policy
+param_scheduler = [
+ dict(
+ type='PolyLR',
+ eta_min=1e-4,
+ power=0.9,
+ begin=0,
+ end=80000,
+ by_epoch=False)
+]
+# training schedule for 80k
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=8000)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=8000),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'))
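The `PolyLR` scheduler above decays the learning rate from `lr` toward `eta_min` over the 80k iterations. A sketch of the usual poly closed form with the values from this schedule — the exact formula is an assumption about MMEngine's `PolyLR`, not copied from its source:

```python
# Poly decay: lr_t = (lr0 - eta_min) * (1 - t/T)^power + eta_min.
# Values taken from the schedule above; the closed form is an assumption
# about how MMEngine's PolyLR behaves.
def poly_lr(t, lr0=0.01, eta_min=1e-4, power=0.9, total=80000):
    return (lr0 - eta_min) * (1 - t / total)**power + eta_min


start = poly_lr(0)      # starts at the base lr
mid = poly_lr(40000)    # roughly half-decayed
end = poly_lr(80000)    # finishes at eta_min
```

With `power=0.9` the curve is close to linear but decays slightly faster early on.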
diff --git a/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_ade20k-384x384.py b/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_ade20k-384x384.py
new file mode 100644
index 00000000000..bab43a6a39d
--- /dev/null
+++ b/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_ade20k-384x384.py
@@ -0,0 +1,103 @@
+_base_ = [
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py',
+ '../_base_/datasets/ade20k_384x384.py'
+]
+
+custom_imports = dict(imports=['cat_seg'])
+
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+crop_size = (384, 384)
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ size=crop_size,
+ # due to the clip model, we do normalization in backbone forward()
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255)
+# model_cfg
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ type='CLIPOVCATSeg',
+ feature_extractor=dict(
+ type='ResNet',
+ depth=101,
+ # only use the first three layers
+ num_stages=3,
+ out_indices=(0, 1, 2),
+ dilations=(1, 1, 1),
+ strides=(1, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True,
+ init_cfg=dict(
+ type='Pretrained', checkpoint='torchvision://resnet101'),
+ ),
+ train_class_json='data/ade150.json',
+ test_class_json='data/ade150.json',
+ clip_pretrained='ViT-B/16',
+ clip_finetune='attention',
+ ),
+ neck=dict(
+ type='CATSegAggregator',
+ appearance_guidance_dim=1024,
+ num_layers=2,
+ pooling_size=(1, 1),
+ ),
+ decode_head=dict(
+ type='CATSegHead',
+ in_channels=128,
+ channels=128,
+ num_classes=150,
+ embed_dims=128,
+ decoder_dims=(64, 32),
+ decoder_guidance_dims=(512, 256),
+ decoder_guidance_proj_dims=(32, 16),
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0,
+ avg_non_ignore=True)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', stride=crop_size, crop_size=crop_size))
+
+# dataset settings
+train_dataloader = dict(
+ batch_size=2,
+ num_workers=4,
+)
+
+# training schedule for 80k
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
+
+default_hooks = dict(
+ checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000),
+ visualization=dict(type='SegVisualizationHook', draw=True, interval=4000))
+
+# optimizer
+optim_wrapper = dict(
+ _delete_=True,
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.0001),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'backbone.feature_extractor': dict(lr_mult=0.01),
+ 'backbone.clip_model.visual': dict(lr_mult=0.01)
+ }))
+
+# learning policy
+param_scheduler = [
+    # Use a linear warm-up at [0, 500) iterations
+    dict(type='LinearLR', start_factor=0.01, by_epoch=False, begin=0, end=500),
+    # Use a cosine learning rate at [500, 80000) iterations
+ dict(
+ type='CosineAnnealingLR',
+ T_max=79500,
+ by_epoch=False,
+ begin=500,
+ end=80000),
+]
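With `test_cfg=dict(mode='slide', stride=crop_size, crop_size=crop_size)`, inference tiles the test image into non-overlapping 384×384 windows. A sketch of the grid-count arithmetic — the ceil-style formula is an assumption based on the common `EncoderDecoder.slide_inference` implementation:

```python
# Number of sliding windows for a 2048x512 test image with 384x384
# crops and stride 384 (no overlap). The grid formula is the usual
# ceil-style count used by slide inference; treat it as an assumption.
def num_grids(img, crop, stride):
    return max(img - crop + stride - 1, 0) // stride + 1


h_grids = num_grids(512, 384, 384)   # vertical windows
w_grids = num_grids(2048, 384, 384)  # horizontal windows
total = h_grids * w_grids
```

Because stride equals crop size, windows only overlap at the image borders, where the last window is shifted inward to stay inside the image.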
diff --git a/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_pascal-context-59-384x384.py b/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_pascal-context-59-384x384.py
new file mode 100644
index 00000000000..8b412cb86fb
--- /dev/null
+++ b/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb1-warmcoslr2e-4-adamw-80k_pascal-context-59-384x384.py
@@ -0,0 +1,103 @@
+_base_ = [
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py',
+ '../_base_/datasets/pascal_context_59_384x384.py'
+]
+
+custom_imports = dict(imports=['cat_seg'])
+
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+crop_size = (384, 384)
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ size=crop_size,
+ # due to the clip model, we do normalization in backbone forward()
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255)
+# model_cfg
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ type='CLIPOVCATSeg',
+ feature_extractor=dict(
+ type='ResNet',
+ depth=101,
+ # only use the first three layers
+ num_stages=3,
+ out_indices=(0, 1, 2),
+ dilations=(1, 1, 1),
+ strides=(1, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True,
+ init_cfg=dict(
+ type='Pretrained', checkpoint='torchvision://resnet101'),
+ ),
+ train_class_json='data/pc59.json',
+ test_class_json='data/pc59.json',
+ clip_pretrained='ViT-B/16',
+ clip_finetune='attention',
+ ),
+ neck=dict(
+ type='CATSegAggregator',
+ appearance_guidance_dim=1024,
+ num_layers=2,
+ pooling_size=(1, 1),
+ ),
+ decode_head=dict(
+ type='CATSegHead',
+ in_channels=128,
+ channels=128,
+ num_classes=59,
+ embed_dims=128,
+ decoder_dims=(64, 32),
+ decoder_guidance_dims=(512, 256),
+ decoder_guidance_proj_dims=(32, 16),
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0,
+ avg_non_ignore=True)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', stride=crop_size, crop_size=crop_size))
+
+# dataset settings
+train_dataloader = dict(
+ batch_size=2,
+ num_workers=4,
+)
+
+# training schedule for 80k
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
+
+default_hooks = dict(
+ checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000),
+ visualization=dict(type='SegVisualizationHook', draw=True, interval=4000))
+
+# optimizer
+optim_wrapper = dict(
+ _delete_=True,
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.0001),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'backbone.feature_extractor': dict(lr_mult=0.01),
+ 'backbone.clip_model.visual': dict(lr_mult=0.01)
+ }))
+
+# learning policy
+param_scheduler = [
+    # Use a linear warm-up at [0, 500) iterations
+    dict(type='LinearLR', start_factor=0.01, by_epoch=False, begin=0, end=500),
+    # Use a cosine learning rate at [500, 80000) iterations
+ dict(
+ type='CosineAnnealingLR',
+ T_max=79500,
+ by_epoch=False,
+ begin=500,
+ end=80000),
+]
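The schedule above is a 500-iteration linear warm-up followed by cosine annealing over the remaining 79,500 iterations. A sketch of the combined curve — the interpolation formulas are assumptions; MMEngine's `LinearLR`/`CosineAnnealingLR` may differ in details such as `eta_min` handling:

```python
import math

# Combined warm-up + cosine schedule for base lr 2e-4, sketched from the
# param_scheduler above. The interpolation formulas are assumptions.
BASE_LR, WARMUP_ITERS, TOTAL_ITERS = 2e-4, 500, 80000


def lr_at(t):
    if t < WARMUP_ITERS:
        # LinearLR: factor ramps from start_factor (0.01) up to 1.0
        factor = 0.01 + (1.0 - 0.01) * t / WARMUP_ITERS
        return BASE_LR * factor
    # CosineAnnealingLR over the remaining iterations (eta_min = 0)
    progress = (t - WARMUP_ITERS) / (TOTAL_ITERS - WARMUP_ITERS)
    return BASE_LR * 0.5 * (1 + math.cos(math.pi * progress))


warm_start = lr_at(0)   # 1% of the base lr
peak = lr_at(500)       # warm-up hands off at the base lr
final = lr_at(80000)    # annealed to ~0
```

The warm-up hand-off is continuous: at iteration 500 the linear factor reaches 1.0 just as the cosine term starts at its maximum.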
diff --git a/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb2-warmcoslr2e-4-adamw-80k_coco-stuff164k-384x384.py b/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb2-warmcoslr2e-4-adamw-80k_coco-stuff164k-384x384.py
new file mode 100644
index 00000000000..52bf712feae
--- /dev/null
+++ b/projects/CAT-Seg/configs/cat_seg/catseg_vitb-r101_4xb2-warmcoslr2e-4-adamw-80k_coco-stuff164k-384x384.py
@@ -0,0 +1,102 @@
+_base_ = [
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py',
+ '../_base_/datasets/coco-stuff164k_384x384.py'
+]
+
+custom_imports = dict(imports=['cat_seg'])
+
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+crop_size = (384, 384)
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ size=crop_size,
+ # due to the clip model, we do normalization in backbone forward()
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255)
+# model_cfg
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ type='CLIPOVCATSeg',
+ feature_extractor=dict(
+ type='ResNet',
+ depth=101,
+ # only use the first three layers
+ num_stages=3,
+ out_indices=(0, 1, 2),
+ dilations=(1, 1, 1),
+ strides=(1, 2, 2),
+ norm_cfg=norm_cfg,
+ norm_eval=False,
+ style='pytorch',
+ contract_dilation=True,
+ init_cfg=dict(
+ type='Pretrained', checkpoint='torchvision://resnet101'),
+ ),
+ train_class_json='data/coco.json',
+ test_class_json='data/coco.json',
+ clip_pretrained='ViT-B/16',
+ clip_finetune='attention',
+ ),
+ neck=dict(
+ type='CATSegAggregator',
+ appearance_guidance_dim=1024,
+ num_layers=2,
+ ),
+ decode_head=dict(
+ type='CATSegHead',
+ in_channels=128,
+ channels=128,
+ num_classes=171,
+ embed_dims=128,
+ decoder_dims=(64, 32),
+ decoder_guidance_dims=(512, 256),
+ decoder_guidance_proj_dims=(32, 16),
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0,
+ avg_non_ignore=True)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='slide', stride=crop_size, crop_size=crop_size))
+
+# dataset settings
+train_dataloader = dict(
+ batch_size=2,
+ num_workers=4,
+)
+
+# training schedule for 80k
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
+
+default_hooks = dict(
+ checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000),
+ visualization=dict(type='SegVisualizationHook', draw=True, interval=4000))
+
+# optimizer
+optim_wrapper = dict(
+ _delete_=True,
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.0001),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'backbone.feature_extractor': dict(lr_mult=0.01),
+ 'backbone.clip_model.visual': dict(lr_mult=0.01)
+ }))
+
+# learning policy
+param_scheduler = [
+    # Use a linear warm-up at [0, 500) iterations
+ dict(type='LinearLR', start_factor=0.01, by_epoch=False, begin=0, end=500),
+    # Use a cosine learning rate at [500, 80000) iterations
+ dict(
+ type='CosineAnnealingLR',
+ T_max=79500,
+ by_epoch=False,
+ begin=500,
+ end=80000),
+]
diff --git a/projects/CAT-Seg/configs/cat_seg/catseg_vitg-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py b/projects/CAT-Seg/configs/cat_seg/catseg_vitg-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py
new file mode 100644
index 00000000000..345945d0284
--- /dev/null
+++ b/projects/CAT-Seg/configs/cat_seg/catseg_vitg-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py
@@ -0,0 +1,11 @@
+_base_ = './catseg_vitl-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py' # noqa
+
+model = dict(
+ backbone=dict(
+ type='CLIPOVCATSeg',
+ clip_pretrained='ViT-G',
+ custom_clip_weights='~/CLIP-ViT-bigG-14-laion2B-39B-b160k'),
+ neck=dict(
+ text_guidance_dim=1280,
+ appearance_guidance_dim=512,
+ ))
diff --git a/projects/CAT-Seg/configs/cat_seg/catseg_vith-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py b/projects/CAT-Seg/configs/cat_seg/catseg_vith-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py
new file mode 100644
index 00000000000..2f09b8c9ca2
--- /dev/null
+++ b/projects/CAT-Seg/configs/cat_seg/catseg_vith-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py
@@ -0,0 +1,11 @@
+_base_ = './catseg_vitl-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py' # noqa
+
+model = dict(
+ backbone=dict(
+ type='CLIPOVCATSeg',
+ clip_pretrained='ViT-H',
+ custom_clip_weights='~/CLIP-ViT-H-14-laion2B-s32B-b79K'),
+ neck=dict(
+ text_guidance_dim=1024,
+ appearance_guidance_dim=512,
+ ))
diff --git a/projects/CAT-Seg/configs/cat_seg/catseg_vitl-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py b/projects/CAT-Seg/configs/cat_seg/catseg_vitl-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py
new file mode 100644
index 00000000000..bb4d57ae219
--- /dev/null
+++ b/projects/CAT-Seg/configs/cat_seg/catseg_vitl-swin-b_4xb1-warmcoslr2e-4_adamw-80k_coco-stuff164k_384x384.py
@@ -0,0 +1,72 @@
+_base_ = './catseg_vitb-r101_4xb2-warmcoslr2e-4-adamw-80k_coco-stuff164k-384x384.py' # noqa
+
+pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_20220317-55b0104a.pth' # noqa
+crop_size = (384, 384)
+data_preprocessor = dict(size=crop_size)
+model = dict(
+ backbone=dict(
+ type='CLIPOVCATSeg',
+ feature_extractor=dict(
+ _delete_=True,
+ type='SwinTransformer',
+ pretrain_img_size=384,
+ embed_dims=128,
+ depths=[2, 2, 18],
+ num_heads=[4, 8, 16],
+ window_size=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.3,
+ patch_norm=True,
+ out_indices=(0, 1, 2),
+ init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
+ clip_pretrained='ViT-L/14@336px',
+ ),
+ neck=dict(
+ text_guidance_dim=768,
+ appearance_guidance_dim=512,
+ ),
+ decode_head=dict(
+ embed_dims=128,
+ decoder_guidance_dims=(256, 128),
+ ))
+
+# dataset settings
+train_dataloader = dict(
+ batch_size=1,
+ num_workers=2,
+)
+
+# training schedule for 80k
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
+
+default_hooks = dict(
+ visualization=dict(type='SegVisualizationHook', draw=True, interval=4000))
+
+# optimizer
+optim_wrapper = dict(
+ _delete_=True,
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.0001),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'backbone.feature_extractor': dict(lr_mult=0.01),
+ 'backbone.clip_model.visual': dict(lr_mult=0.01)
+ }))
+
+# learning policy
+param_scheduler = [
+    # Use a linear warm-up at [0, 500) iterations
+ dict(type='LinearLR', start_factor=0.01, by_epoch=False, begin=0, end=500),
+    # Use a cosine learning rate at [500, 80000) iterations
+ dict(
+ type='CosineAnnealingLR',
+ T_max=79500,
+ by_epoch=False,
+ begin=500,
+ end=80000),
+]
diff --git a/projects/CAT-Seg/utils/__init__.py b/projects/CAT-Seg/utils/__init__.py
new file mode 100644
index 00000000000..02d85f29cb1
--- /dev/null
+++ b/projects/CAT-Seg/utils/__init__.py
@@ -0,0 +1,7 @@
+from .clip_templates import (IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT,
+ IMAGENET_TEMPLATES_SELECT_CLIP, ViLD_templates)
+
+__all__ = [
+ 'IMAGENET_TEMPLATES', 'IMAGENET_TEMPLATES_SELECT',
+ 'IMAGENET_TEMPLATES_SELECT_CLIP', 'ViLD_templates'
+]
diff --git a/projects/XDecoder/README.md b/projects/XDecoder/README.md
new file mode 100644
index 00000000000..3d55575c6b8
--- /dev/null
+++ b/projects/XDecoder/README.md
@@ -0,0 +1,17 @@
+# X-Decoder
+
+> [X-Decoder: Generalized Decoding for Pixel, Image, and Language](https://arxiv.org/pdf/2212.11270.pdf)
+
+
+
+## Abstract
+
+We present X-Decoder, a generalized decoding model that can predict pixel-level segmentation and language tokens seamlessly. X-Decoder takes as input two types of queries: (i) generic non-semantic queries and (ii) semantic queries induced from text inputs, to decode different pixel-level and token-level outputs in the same semantic space. With such a novel design, X-Decoder is the first work that provides a unified way to support all types of image segmentation and a variety of vision-language (VL) tasks. Further, our design enables seamless interactions across tasks at different granularities and brings mutual benefits by learning a common and rich pixel-level visual-semantic understanding space, without any pseudo-labeling. After pretraining on a mixed set of a limited amount of segmentation data and millions of image-text pairs, X-Decoder exhibits strong transferability to a wide range of downstream tasks in both zero-shot and finetuning settings. Notably, it achieves (1) state-of-the-art results on open-vocabulary segmentation and referring segmentation on eight datasets; (2) better or competitive finetuned performance to other generalist and specialist models on segmentation and VL tasks; and (3) flexibility for efficient finetuning and novel task composition (e.g., referring captioning and image editing).
+
+
+
+
+
+## Usage
+
+We implement it based on [mmdetection](https://github.com/open-mmlab/mmdetection/); please refer to [mmdetection/projects/XDecoder](https://github.com/open-mmlab/mmdetection/tree/main/projects/XDecoder) for more details.
diff --git a/projects/bdd100k_dataset/README.md b/projects/bdd100k_dataset/README.md
new file mode 100644
index 00000000000..c7745258441
--- /dev/null
+++ b/projects/bdd100k_dataset/README.md
@@ -0,0 +1,50 @@
+# BDD100K Dataset
+
+Support **`BDD100K Dataset`**
+
+## Description
+
+Author: CastleDream
+
+This project implements support for the **`BDD100K Dataset`**
+
+### Dataset preparing
+
+Prepare the `BDD100K` dataset following the [BDD100K Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/tree/main/projects/mapillary_dataset/docs/en/user_guides/2_dataset_prepare.md#bdd100k)
+
+```none
+mmsegmentation/data
+└── bdd100k
+ ├── images
+ │ └── 10k
+ │ ├── test [2000 entries exceeds filelimit, not opening dir]
+ │ ├── train [7000 entries exceeds filelimit, not opening dir]
+ │ └── val [1000 entries exceeds filelimit, not opening dir]
+ └── labels
+ └── sem_seg
+ ├── colormaps
+ │ ├── train [7000 entries exceeds filelimit, not opening dir]
+ │ └── val [1000 entries exceeds filelimit, not opening dir]
+ ├── masks
+ │ ├── train [7000 entries exceeds filelimit, not opening dir]
+ │ └── val [1000 entries exceeds filelimit, not opening dir]
+ ├── polygons
+ │ ├── sem_seg_train.json
+ │ └── sem_seg_val.json
+ └── rles
+ ├── sem_seg_train.json
+ └── sem_seg_val.json
+```
+
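With this layout, each image in `images/10k/train` pairs with the mask in `labels/sem_seg/masks/train` that shares its stem (`.jpg` image, `.png` mask, mirroring the suffix defaults of the `BDD100KDataset` class in this project). A minimal sketch of that pairing rule; `mask_path_for` is a hypothetical helper, not an mmsegmentation API:

```python
import posixpath


def mask_path_for(img_path, split='train', seg_map_suffix='.png'):
    """Map a BDD100K image path to its semantic-segmentation mask path.

    Assumes the directory layout shown above; this helper is only an
    illustration of the same-stem pairing convention.
    """
    stem = posixpath.splitext(posixpath.basename(img_path))[0]
    return f'labels/sem_seg/masks/{split}/{stem}{seg_map_suffix}'


mask_path_for('images/10k/train/0004a4c0-d4dff0ad.jpg')
# -> 'labels/sem_seg/masks/train/0004a4c0-d4dff0ad.png'
```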
+### Training commands
+
+```bash
+cd mmsegmentation
+python tools/train.py projects/bdd100k_dataset/configs/pspnet_r50-d8_4xb2-80k_bdd100k-512x1024.py \
+    --work-dir your_work_dir
+```
+
+## Thanks
+
+- [\[Datasets\] Add Mapillary Vistas Datasets to MMSeg Core Package. #2576](https://github.com/open-mmlab/mmsegmentation/pull/2576/files)
+- [\[Feature\] Support CIHP dataset #1493](https://github.com/open-mmlab/mmsegmentation/pull/1493/files)
diff --git a/projects/bdd100k_dataset/configs/_base_/datasets/bdd100k.py b/projects/bdd100k_dataset/configs/_base_/datasets/bdd100k.py
new file mode 100644
index 00000000000..24cec69bfeb
--- /dev/null
+++ b/projects/bdd100k_dataset/configs/_base_/datasets/bdd100k.py
@@ -0,0 +1,70 @@
+# dataset settings
+dataset_type = 'BDD100KDataset'
+data_root = 'data/bdd100k/'
+
+crop_size = (512, 1024)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(
+ type='RandomResize',
+ scale=(2048, 1024),
+ ratio_range=(0.5, 2.0),
+ keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
+ # add loading annotation after ``Resize`` because ground truth
+ # does not need to do resize data transform
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+tta_pipeline = [
+ dict(type='LoadImageFromFile', backend_args=None),
+ dict(
+ type='TestTimeAug',
+ transforms=[
+ [
+ dict(type='Resize', scale_factor=r, keep_ratio=True)
+ for r in img_ratios
+ ],
+ [
+ dict(type='RandomFlip', prob=0., direction='horizontal'),
+ dict(type='RandomFlip', prob=1., direction='horizontal')
+ ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
+ ])
+]
+train_dataloader = dict(
+ batch_size=2,
+ num_workers=2,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='images/10k/train',
+ seg_map_path='labels/sem_seg/masks/train'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='images/10k/val',
+ seg_map_path='labels/sem_seg/masks/val'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
diff --git a/projects/bdd100k_dataset/configs/pspnet_r50-d8_4xb2-80k_bdd100k-512x1024.py b/projects/bdd100k_dataset/configs/pspnet_r50-d8_4xb2-80k_bdd100k-512x1024.py
new file mode 100644
index 00000000000..456d4c79839
--- /dev/null
+++ b/projects/bdd100k_dataset/configs/pspnet_r50-d8_4xb2-80k_bdd100k-512x1024.py
@@ -0,0 +1,11 @@
+_base_ = [
+ '../../../configs/_base_/models/pspnet_r50-d8.py',
+ './_base_/datasets/bdd100k.py',
+ '../../../configs/_base_/default_runtime.py',
+ '../../../configs/_base_/schedules/schedule_80k.py'
+]
+custom_imports = dict(
+ imports=['projects.bdd100k_dataset.mmseg.datasets.bdd100k'])
+crop_size = (512, 1024)
+data_preprocessor = dict(size=crop_size)
+model = dict(data_preprocessor=data_preprocessor)
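As a reading aid, the `4xb2-80k` fragment of the config name above follows the mmsegmentation naming convention of `<GPUs>xb<samples-per-GPU>-<iterations>`; a quick sketch of the training scale it implies:

```python
# 4 GPUs x batch_size 2 per GPU, trained for 80k iterations
# (the schedule comes from schedule_80k.py inherited via _base_).
gpus, batch_per_gpu, iters = 4, 2, 80_000
global_batch = gpus * batch_per_gpu
samples_seen = global_batch * iters
print(global_batch, samples_seen)  # 8 640000
```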
diff --git a/projects/bdd100k_dataset/docs/en/user_guides/2_dataset_prepare.md b/projects/bdd100k_dataset/docs/en/user_guides/2_dataset_prepare.md
new file mode 100644
index 00000000000..f2383cfcac2
--- /dev/null
+++ b/projects/bdd100k_dataset/docs/en/user_guides/2_dataset_prepare.md
@@ -0,0 +1,40 @@
+## BDD100K
+
+- You can download the BDD100K dataset from [here](https://bdd-data.berkeley.edu/) after registration.
+
+- Download the images and masks by clicking the `10K Images` and `Segmentation` buttons.
+
+- After downloading, unzip them with the following commands:
+
+ ```bash
+ unzip ~/bdd100k_images_10k.zip -d ~/mmsegmentation/data/
+ unzip ~/bdd100k_sem_seg_labels_trainval.zip -d ~/mmsegmentation/data/
+ ```
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│   ├── bdd100k
+│   │   ├── images
+│   │   │   └── 10k
+│   │   │       ├── test
+│   │   │       ├── train
+│   │   │       └── val
+│   │   └── labels
+│   │       └── sem_seg
+│   │           ├── colormaps
+│   │           │   ├── train
+│   │           │   └── val
+│   │           ├── masks
+│   │           │   ├── train
+│   │           │   └── val
+│   │           ├── polygons
+│   │           │   ├── sem_seg_train.json
+│   │           │   └── sem_seg_val.json
+│   │           └── rles
+│   │               ├── sem_seg_train.json
+│   │               └── sem_seg_val.json
+```
diff --git a/projects/bdd100k_dataset/docs/zh_cn/user_guides/2_dataset_prepare.md b/projects/bdd100k_dataset/docs/zh_cn/user_guides/2_dataset_prepare.md
new file mode 100644
index 00000000000..64fb763db4c
--- /dev/null
+++ b/projects/bdd100k_dataset/docs/zh_cn/user_guides/2_dataset_prepare.md
@@ -0,0 +1,42 @@
+## BDD100K
+
+- You can download the BDD100K dataset (the 10K subset is the one used for semantic segmentation) from the [official website](https://bdd-data.berkeley.edu/). After registering and logging in as required, the data can be found [here](https://bdd-data.berkeley.edu/portal.html#download).
+
+- The image data is named `10K Images`, and the semantic segmentation annotations are named `Segmentation`.
+
+- After downloading, unzip them with the following commands:
+
+  ```bash
+  unzip ~/bdd100k_images_10k.zip -d ~/mmsegmentation/data/
+  unzip ~/bdd100k_sem_seg_labels_trainval.zip -d ~/mmsegmentation/data/
+  ```
+
+This produces the following file structure:
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│   ├── bdd100k
+│   │   ├── images
+│   │   │   └── 10k
+│   │   │       ├── test
+│   │   │       ├── train
+│   │   │       └── val
+│   │   └── labels
+│   │       └── sem_seg
+│   │           ├── colormaps
+│   │           │   ├── train
+│   │           │   └── val
+│   │           ├── masks
+│   │           │   ├── train
+│   │           │   └── val
+│   │           ├── polygons
+│   │           │   ├── sem_seg_train.json
+│   │           │   └── sem_seg_val.json
+│   │           └── rles
+│   │               ├── sem_seg_train.json
+│   │               └── sem_seg_val.json
+```
diff --git a/projects/bdd100k_dataset/mmseg/datasets/bdd100k.py b/projects/bdd100k_dataset/mmseg/datasets/bdd100k.py
new file mode 100644
index 00000000000..e536de74614
--- /dev/null
+++ b/projects/bdd100k_dataset/mmseg/datasets/bdd100k.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmseg.datasets.basesegdataset import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class BDD100KDataset(BaseSegDataset):
+ METAINFO = dict(
+ classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+ 'traffic light', 'traffic sign', 'vegetation', 'terrain',
+ 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
+ 'motorcycle', 'bicycle'),
+ palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+ [190, 153, 153], [153, 153, 153], [250, 170,
+ 30], [220, 220, 0],
+ [107, 142, 35], [152, 251, 152], [70, 130, 180],
+ [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
+ [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]])
+
+ def __init__(self,
+ img_suffix='.jpg',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/example_project/README.md b/projects/example_project/README.md
index 4338b8acac1..e4fd03cf4a4 100644
--- a/projects/example_project/README.md
+++ b/projects/example_project/README.md
@@ -53,7 +53,7 @@ mim train mmsegmentation configs/fcn_dummy-r50-d8_4xb2-40k_cityscapes-512x1024.p
mim test mmsegmentation configs/fcn_dummy-r50-d8_4xb2-40k_cityscapes-512x1024.py --work-dir work_dirs/dummy_resnet --checkpoint ${CHECKPOINT_PATH}
```
-> List the results as usually done in other model's README. \[Example\](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/configs/fcn#results-and-models
+> List the results as usually done in other models' README. \[Example\](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/fcn#results-and-models)
> You should state whether the results are based on pre-trained weights converted from the official release, or reproduced by retraining the model in this project
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
@@ -113,11 +113,11 @@ Here is a checklist illustrating a usual development workflow of a successful pr
- [ ] Type hints and docstrings
-> Ideally *all* the methods should have [type hints](https://www.pythontutorial.net/python-basics/python-type-hints/) and [docstrings](https://google.github.io/styleguide/pyguide.html#381-docstrings). [Example](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/utils/io.py#L9)
+> Ideally *all* the methods should have [type hints](https://www.pythontutorial.net/python-basics/python-type-hints/) and [docstrings](https://google.github.io/styleguide/pyguide.html#381-docstrings). [Example](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/utils/io.py#L9)
- [ ] Unit tests
-> Unit tests for each module are required. [Example](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/tests/test_utils/test_io.py#L14)
+> Unit tests for each module are required. [Example](https://github.com/open-mmlab/mmsegmentation/blob/main/tests/test_utils/test_io.py#L14)
- [ ] Code polishing
@@ -125,10 +125,10 @@ Here is a checklist illustrating a usual development workflow of a successful pr
- [ ] Metafile.yml
-> It will be parsed by MIM and Inferencer. [Example](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/fcn/fcn.yml)
+> It will be parsed by MIM and Inferencer. [Example](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/fcn/fcn.yml)
- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
-> In particular, you may have to refactor this README into a standard one. [Example](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/fcn/README.md)
+> In particular, you may have to refactor this README into a standard one. [Example](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/fcn/README.md)
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/faq.md b/projects/faq.md
index 724c1cf6a5d..74b292b050c 100644
--- a/projects/faq.md
+++ b/projects/faq.md
@@ -1,6 +1,6 @@
Q1: Why set up `projects/` folder?
-Implementing new models and features into OpenMMLab's algorithm libraries could be troublesome due to the rigorous requirements on code quality, which could hinder the fast iteration of SOTA models and might discourage our members from sharing their latest outcomes here. And that's why we have this `projects/` folder now, where some experimental features, frameworks and models are placed, only needed to satisfy the minimum requirement on the code quality, and can be used as standalone libraries. Users are welcome to use them if they [use MMSegmentation from source](https://mmsegmentation.readthedocs.io/en/dev-1.x/get_started.html#best-practices).
+Implementing new models and features into OpenMMLab's algorithm libraries could be troublesome due to the rigorous requirements on code quality, which could hinder the fast iteration of SOTA models and might discourage our members from sharing their latest outcomes here. That is why we have the `projects/` folder, where experimental features, frameworks and models only need to satisfy the minimum code-quality requirements and can be used as standalone libraries. Users are welcome to use them if they [use MMSegmentation from source](https://mmsegmentation.readthedocs.io/en/latest/get_started.html#best-practices).
Q2: Why should there be a checklist for a project?
diff --git a/projects/gid_dataset/configs/_base_/datasets/gid.py b/projects/gid_dataset/configs/_base_/datasets/gid.py
new file mode 100644
index 00000000000..f7218105f2f
--- /dev/null
+++ b/projects/gid_dataset/configs/_base_/datasets/gid.py
@@ -0,0 +1,67 @@
+# dataset settings
+dataset_type = 'GID_Dataset'  # registered dataset class name
+data_root = 'data/gid/'  # dataset root directory
+crop_size = (256, 256)  # image crop size
+train_pipeline = [
+    dict(type='LoadImageFromFile'),  # load the image from file
+    dict(type='LoadAnnotations'),  # load the annotation from file
+    dict(
+        type='RandomResize',  # random resize
+        scale=(512, 512),  # target scale
+        ratio_range=(0.5, 2.0),  # resize ratio range
+        keep_ratio=True),  # whether to keep the aspect ratio
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),  # crop
+    dict(type='RandomFlip', prob=0.5),  # random flip
+    dict(type='PhotoMetricDistortion'),  # photometric distortion
+    dict(type='PackSegInputs')  # pack the data
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),  # load the image from file
+    dict(type='Resize', scale=(256, 256), keep_ratio=True),  # resize
+    # add loading annotation after ``Resize`` because ground truth
+    # does not need to do resize data transform
+    dict(type='LoadAnnotations'),  # load the annotation from file
+    dict(type='PackSegInputs')  # pack the data
+]
+img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]  # multi-scale test ratios
+tta_pipeline = [  # multi-scale testing
+    dict(type='LoadImageFromFile', backend_args=None),
+ dict(
+ type='TestTimeAug',
+ transforms=[
+ [
+ dict(type='Resize', scale_factor=r, keep_ratio=True)
+ for r in img_ratios
+ ],
+ [
+ dict(type='RandomFlip', prob=0., direction='horizontal'),
+ dict(type='RandomFlip', prob=1., direction='horizontal')
+ ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
+ ])
+]
+train_dataloader = dict(  # training dataloader
+    batch_size=2,  # batch size for training
+    num_workers=4,  # number of data-loading workers
+    persistent_workers=True,  # keep worker processes alive between epochs
+    sampler=dict(type='InfiniteSampler', shuffle=True),  # infinite sampler
+    dataset=dict(
+        type=dataset_type,  # dataset class name
+        data_root=data_root,  # dataset root directory
+        data_prefix=dict(
+            img_path='img_dir/train',
+            seg_map_path='ann_dir/train'),  # train image and annotation paths
+        pipeline=train_pipeline))
+val_dataloader = dict(
+    batch_size=1,  # batch size for validation
+    num_workers=4,  # number of data-loading workers
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
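For reference, the `tta_pipeline` above enumerates every combination of the six `img_ratios` with the two flip settings, i.e. 12 augmented views per test image. A quick standalone count:

```python
# Views generated by the TestTimeAug wrapper above:
# six resize ratios x two horizontal-flip settings.
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
flip_probs = [0.0, 1.0]  # never flip / always flip
views = [(r, f) for r in img_ratios for f in flip_probs]
print(len(views))  # 12
```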
diff --git a/projects/gid_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_gid-256x256.py b/projects/gid_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_gid-256x256.py
new file mode 100644
index 00000000000..70cb6005f81
--- /dev/null
+++ b/projects/gid_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_gid-256x256.py
@@ -0,0 +1,15 @@
+_base_ = [
+ '../../../configs/_base_/models/deeplabv3plus_r50-d8.py',
+ './_base_/datasets/gid.py', '../../../configs/_base_/default_runtime.py',
+ '../../../configs/_base_/schedules/schedule_240k.py'
+]
+custom_imports = dict(imports=['projects.gid_dataset.mmseg.datasets.gid'])
+
+crop_size = (256, 256)
+data_preprocessor = dict(size=crop_size)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ pretrained='open-mmlab://resnet101_v1c',
+ backbone=dict(depth=101),
+ decode_head=dict(num_classes=6),
+ auxiliary_head=dict(num_classes=6))
diff --git a/projects/gid_dataset/mmseg/datasets/gid.py b/projects/gid_dataset/mmseg/datasets/gid.py
new file mode 100644
index 00000000000..a9e8c510b46
--- /dev/null
+++ b/projects/gid_dataset/mmseg/datasets/gid.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmseg.datasets.basesegdataset import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+# Register the dataset class
+@DATASETS.register_module()
+class GID_Dataset(BaseSegDataset):
+ """Gaofen Image Dataset (GID)
+
+ Dataset paper link:
+ https://www.sciencedirect.com/science/article/pii/S0034425719303414
+ https://x-ytong.github.io/project/GID.html
+
+ GID 6 classes: others, built-up, farmland, forest, meadow, water
+
+    In this example, 15 images of the GID dataset are selected as the
+    training set, and 5 images are selected as the validation set.
+ The selected images are listed as follows:
+
+ GF2_PMS1__L1A0000647767-MSS1
+ GF2_PMS1__L1A0001064454-MSS1
+ GF2_PMS1__L1A0001348919-MSS1
+ GF2_PMS1__L1A0001680851-MSS1
+ GF2_PMS1__L1A0001680853-MSS1
+ GF2_PMS1__L1A0001680857-MSS1
+ GF2_PMS1__L1A0001757429-MSS1
+ GF2_PMS2__L1A0000607681-MSS2
+ GF2_PMS2__L1A0000635115-MSS2
+ GF2_PMS2__L1A0000658637-MSS2
+ GF2_PMS2__L1A0001206072-MSS2
+ GF2_PMS2__L1A0001471436-MSS2
+ GF2_PMS2__L1A0001642620-MSS2
+ GF2_PMS2__L1A0001787089-MSS2
+ GF2_PMS2__L1A0001838560-MSS2
+
+    The ``img_suffix`` and ``seg_map_suffix`` are both fixed to '.png',
+    matching the tiles produced by ``tools/dataset_converters/gid.py``.
+ """
+ METAINFO = dict(
+ classes=('Others', 'Built-up', 'Farmland', 'Forest', 'Meadow',
+ 'Water'),
+ palette=[[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 255, 255],
+ [255, 255, 0], [0, 0, 255]])
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+                 reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
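Since the converter in `tools/dataset_converters/gid.py` maps colors to class indices by exact RGB value, every entry in the palette above must be distinct. A standalone sanity check with the values copied from `METAINFO`:

```python
# Classes and palette copied from the METAINFO above.
classes = ('Others', 'Built-up', 'Farmland', 'Forest', 'Meadow', 'Water')
palette = [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 255, 255],
           [255, 255, 0], [0, 0, 255]]
# One unique color per class; otherwise an RGB-to-index lookup would
# silently merge two classes into one.
assert len(set(tuple(c) for c in palette)) == len(classes)
```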
diff --git a/projects/gid_dataset/tools/dataset_converters/gid.py b/projects/gid_dataset/tools/dataset_converters/gid.py
new file mode 100644
index 00000000000..d95654aa14b
--- /dev/null
+++ b/projects/gid_dataset/tools/dataset_converters/gid.py
@@ -0,0 +1,181 @@
+import argparse
+import glob
+import math
+import os
+import os.path as osp
+
+import mmcv
+import numpy as np
+from mmengine.utils import ProgressBar, mkdir_or_exist
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description='Convert GID dataset to mmsegmentation format')
+ parser.add_argument('dataset_img_path', help='GID images folder path')
+ parser.add_argument('dataset_label_path', help='GID labels folder path')
+ parser.add_argument('--tmp_dir', help='path of the temporary directory')
+ parser.add_argument(
+ '-o', '--out_dir', help='output path', default='data/gid')
+ parser.add_argument(
+ '--clip_size',
+ type=int,
+ help='clipped size of image after preparation',
+ default=256)
+ parser.add_argument(
+ '--stride_size',
+ type=int,
+ help='stride of clipping original images',
+ default=256)
+ args = parser.parse_args()
+ return args
+
+
+GID_COLORMAP = dict(
+    Background=(0, 0, 0),  # 0-background-black
+    Building=(255, 0, 0),  # 1-building-red
+    Farmland=(0, 255, 0),  # 2-farmland-green
+    Forest=(0, 255, 255),  # 3-forest-cyan (matches the METAINFO palette)
+    Meadow=(255, 255, 0),  # 4-meadow-yellow
+    Water=(0, 0, 255)  # 5-water-blue
+)
+palette = list(GID_COLORMAP.values())
+classes = list(GID_COLORMAP.keys())
+
+
+# Build a lookup table mapping each RGB triple to its class index
+def colormap2label(palette):
+ colormap2label_list = np.zeros(256**3, dtype=np.longlong)
+ for i, colormap in enumerate(palette):
+ colormap2label_list[(colormap[0] * 256 + colormap[1]) * 256 +
+ colormap[2]] = i
+ return colormap2label_list
+
+
+# Given the lookup table, convert an RGB label image to a class-index mask
+def label_indices(RGB_label, colormap2label_list):
+ RGB_label = RGB_label.astype('int32')
+ idx = (RGB_label[:, :, 0] * 256 +
+ RGB_label[:, :, 1]) * 256 + RGB_label[:, :, 2]
+ return colormap2label_list[idx]
+
+
+def RGB2mask(RGB_label, colormap2label_list):
+ mask_label = label_indices(RGB_label, colormap2label_list)
+ return mask_label
+
+
+colormap2label_list = colormap2label(palette)
+
+
+def clip_big_image(image_path, clip_save_dir, args, to_label=False):
+    """Clip a large GID image into fixed-size patches.
+
+    The original GID images are very large, so they are pre-processed into
+    patches. Given the clip size and stride size, the patch grid over the
+    width and height is determined. For example, for one 6800 x 7200
+    original image with a clip size of 256 and a stride size of 256,
+    29 x 27 = 783 patches of size 256 x 256 are generated.
+    """
+
+ image = mmcv.imread(image_path, channel_order='rgb')
+ # image = mmcv.bgr2gray(image)
+
+ h, w, c = image.shape
+ clip_size = args.clip_size
+ stride_size = args.stride_size
+
+ num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
+ (h - clip_size) /
+ stride_size) * stride_size + clip_size >= h else math.ceil(
+ (h - clip_size) / stride_size) + 1
+ num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
+ (w - clip_size) /
+ stride_size) * stride_size + clip_size >= w else math.ceil(
+ (w - clip_size) / stride_size) + 1
+
+ x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
+ xmin = x * clip_size
+ ymin = y * clip_size
+
+ xmin = xmin.ravel()
+ ymin = ymin.ravel()
+ xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
+ np.zeros_like(xmin))
+ ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
+ np.zeros_like(ymin))
+ boxes = np.stack([
+ xmin + xmin_offset, ymin + ymin_offset,
+ np.minimum(xmin + clip_size, w),
+ np.minimum(ymin + clip_size, h)
+ ],
+ axis=1)
+
+ if to_label:
+ image = RGB2mask(image, colormap2label_list)
+
+ for count, box in enumerate(boxes):
+ start_x, start_y, end_x, end_y = box
+ clipped_image = image[start_y:end_y,
+ start_x:end_x] if to_label else image[
+ start_y:end_y, start_x:end_x, :]
+ img_name = osp.basename(image_path).replace('.tif', '')
+ img_name = img_name.replace('_label', '')
+ if count % 3 == 0:
+ mmcv.imwrite(
+ clipped_image.astype(np.uint8),
+ osp.join(
+ clip_save_dir.replace('train', 'val'),
+ f'{img_name}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
+ else:
+ mmcv.imwrite(
+ clipped_image.astype(np.uint8),
+ osp.join(
+ clip_save_dir,
+ f'{img_name}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
+
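A standalone check of the patch-grid arithmetic described in the `clip_big_image` docstring (29 x 27 = 783 patches for a 6800 x 7200 image with a clip size and stride size of 256); `grid_len` is our illustrative helper, mirroring the `num_rows`/`num_cols` computation plus the extra row and column added by `np.arange(n + 1)`:

```python
import math


def grid_len(dim, clip=256, stride=256):
    """Number of patch rows/cols along one image dimension."""
    n = math.ceil((dim - clip) / stride)
    if n * stride + clip < dim:
        n += 1
    return n + 1  # np.arange(n + 1) yields n + 1 grid positions


print(grid_len(6800) * grid_len(7200))  # 27 * 29 = 783
```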
+
+def main():
+    """Generate the train and validation sets of the GID dataset.
+
+    Following https://ieeexplore.ieee.org/document/9343296/, 15 images of
+    GID, which cover all six categories, are selected and clipped.
+    """
+    args = parse_args()
+
+ if args.out_dir is None:
+ out_dir = osp.join('data', 'gid')
+ else:
+ out_dir = args.out_dir
+
+ print('Making directories...')
+ mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
+ mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
+ mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
+ mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
+
+ src_path_list = glob.glob(os.path.join(args.dataset_img_path, '*.tif'))
+ print(f'Find {len(src_path_list)} pictures')
+
+ prog_bar = ProgressBar(len(src_path_list))
+
+ dst_img_dir = osp.join(out_dir, 'img_dir', 'train')
+ dst_label_dir = osp.join(out_dir, 'ann_dir', 'train')
+
+ for i, img_path in enumerate(src_path_list):
+ label_path = osp.join(
+ args.dataset_label_path,
+ osp.basename(img_path.replace('.tif', '_label.tif')))
+
+ clip_big_image(img_path, dst_img_dir, args, to_label=False)
+ clip_big_image(label_path, dst_label_dir, args, to_label=True)
+ prog_bar.update()
+
+ print('Done!')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/projects/gid_dataset/tools/dataset_converters/gid_select15imgFromAll.py b/projects/gid_dataset/tools/dataset_converters/gid_select15imgFromAll.py
new file mode 100644
index 00000000000..d3eeff26902
--- /dev/null
+++ b/projects/gid_dataset/tools/dataset_converters/gid_select15imgFromAll.py
@@ -0,0 +1,75 @@
+import argparse
+import os
+import shutil
+
+# select 15 images from GID dataset
+
+img_list = [
+ 'GF2_PMS1__L1A0000647767-MSS1.tif', 'GF2_PMS1__L1A0001064454-MSS1.tif',
+ 'GF2_PMS1__L1A0001348919-MSS1.tif', 'GF2_PMS1__L1A0001680851-MSS1.tif',
+ 'GF2_PMS1__L1A0001680853-MSS1.tif', 'GF2_PMS1__L1A0001680857-MSS1.tif',
+ 'GF2_PMS1__L1A0001757429-MSS1.tif', 'GF2_PMS2__L1A0000607681-MSS2.tif',
+ 'GF2_PMS2__L1A0000635115-MSS2.tif', 'GF2_PMS2__L1A0000658637-MSS2.tif',
+ 'GF2_PMS2__L1A0001206072-MSS2.tif', 'GF2_PMS2__L1A0001471436-MSS2.tif',
+ 'GF2_PMS2__L1A0001642620-MSS2.tif', 'GF2_PMS2__L1A0001787089-MSS2.tif',
+ 'GF2_PMS2__L1A0001838560-MSS2.tif'
+]
+
+labels_list = [
+ 'GF2_PMS1__L1A0000647767-MSS1_label.tif',
+ 'GF2_PMS1__L1A0001064454-MSS1_label.tif',
+ 'GF2_PMS1__L1A0001348919-MSS1_label.tif',
+ 'GF2_PMS1__L1A0001680851-MSS1_label.tif',
+ 'GF2_PMS1__L1A0001680853-MSS1_label.tif',
+ 'GF2_PMS1__L1A0001680857-MSS1_label.tif',
+ 'GF2_PMS1__L1A0001757429-MSS1_label.tif',
+ 'GF2_PMS2__L1A0000607681-MSS2_label.tif',
+ 'GF2_PMS2__L1A0000635115-MSS2_label.tif',
+ 'GF2_PMS2__L1A0000658637-MSS2_label.tif',
+ 'GF2_PMS2__L1A0001206072-MSS2_label.tif',
+ 'GF2_PMS2__L1A0001471436-MSS2_label.tif',
+ 'GF2_PMS2__L1A0001642620-MSS2_label.tif',
+ 'GF2_PMS2__L1A0001787089-MSS2_label.tif',
+ 'GF2_PMS2__L1A0001838560-MSS2_label.tif'
+]
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+        description='Select 15 images from the 150 images of the GID dataset')
+ parser.add_argument('dataset_img_dir', help='150 GID images folder path')
+ parser.add_argument('dataset_label_dir', help='150 GID labels folder path')
+
+ parser.add_argument('dest_img_dir', help='15 GID images folder path')
+ parser.add_argument('dest_label_dir', help='15 GID labels folder path')
+
+ args = parser.parse_args()
+
+ return args
+
+
+def main():
+    """Select 15 images from the GID dataset, according to the paper:
+    https://ieeexplore.ieee.org/document/9343296/"""
+ args = parse_args()
+
+ img_path = args.dataset_img_dir
+ label_path = args.dataset_label_dir
+
+ dest_img_dir = args.dest_img_dir
+ dest_label_dir = args.dest_label_dir
+
+    # copy images in 'img_list' to 'dest_img_dir'
+    print('Copying images of img_list to dest_img_dir...')
+ for img in img_list:
+ shutil.copy(os.path.join(img_path, img), dest_img_dir)
+ print('Done!')
+
+    print('Copying labels of labels_list to dest_label_dir...')
+ for label in labels_list:
+ shutil.copy(os.path.join(label_path, label), dest_label_dir)
+ print('Done!')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/projects/gid_dataset/user_guides/2_dataset_prepare.md b/projects/gid_dataset/user_guides/2_dataset_prepare.md
new file mode 100644
index 00000000000..63bd4d46fc5
--- /dev/null
+++ b/projects/gid_dataset/user_guides/2_dataset_prepare.md
@@ -0,0 +1,53 @@
+## Gaofen Image Dataset (GID)
+
+- The GID dataset can be downloaded [here](https://x-ytong.github.io/project/GID.html).
+- The GID dataset contains 150 large-size images of 6800x7200 pixels, annotated with RGB labels.
+- According to [the paper](https://ieeexplore.ieee.org/document/9343296/), 15 images covering all six categories are selected here to generate the training and validation sets. The names of the selected images are as follows:
+
+```none
+ GF2_PMS1__L1A0000647767-MSS1
+ GF2_PMS1__L1A0001064454-MSS1
+ GF2_PMS1__L1A0001348919-MSS1
+ GF2_PMS1__L1A0001680851-MSS1
+ GF2_PMS1__L1A0001680853-MSS1
+ GF2_PMS1__L1A0001680857-MSS1
+ GF2_PMS1__L1A0001757429-MSS1
+ GF2_PMS2__L1A0000607681-MSS2
+ GF2_PMS2__L1A0000635115-MSS2
+ GF2_PMS2__L1A0000658637-MSS2
+ GF2_PMS2__L1A0001206072-MSS2
+ GF2_PMS2__L1A0001471436-MSS2
+ GF2_PMS2__L1A0001642620-MSS2
+ GF2_PMS2__L1A0001787089-MSS2
+ GF2_PMS2__L1A0001838560-MSS2
+```
+
+A script is also provided to conveniently select the 15 images:
+
+```shell
+python projects/gid_dataset/tools/dataset_converters/gid_select15imgFromAll.py {path of the 150 images} {path of the 150 labels} {path of the 15 images} {path of the 15 labels}
+```
+
+After the 15 images are selected, run the following command to clip the images and convert the labels. Replace the paths with those where your 15 images and labels are stored.
+
+```shell
+python projects/gid_dataset/tools/dataset_converters/gid.py {path of the 15 images} {path of the 15 labels}
+```
+
+The structure of the GID dataset after clipping is as follows:
+
+```none
+mmsegmentation
+├── mmseg
+├── tools
+├── configs
+├── data
+│ ├── gid
+│ │ ├── ann_dir
+| │ │ │ ├── train
+| │ │ │ ├── val
+│ │ ├── img_dir
+| │ │ │ ├── train
+| │ │ │ ├── val
+
+```
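
The label conversion performed during clipping turns GID's RGB annotations into single-channel class-index masks. Below is a minimal sketch of such a conversion; the palette shown is a hypothetical example for illustration only, and `gid.py` holds the authoritative colormap:

```python
import numpy as np

# Hypothetical RGB palette -> class index mapping (for illustration only;
# check the official GID colormap before reusing these values).
PALETTE = {
    (0, 0, 0): 0,      # unlabeled
    (255, 0, 0): 1,    # built-up
    (0, 255, 0): 2,    # farmland
    (0, 255, 255): 3,  # forest
    (255, 255, 0): 4,  # meadow
    (0, 0, 255): 5,    # water
}


def rgb_to_mask(rgb, palette=PALETTE):
    """Convert an (H, W, 3) RGB label image to an (H, W) index mask."""
    # Encode each RGB triple as a single integer for a fast table lookup.
    encoded = rgb[..., 0].astype(np.int64) * 65536 \
        + rgb[..., 1].astype(np.int64) * 256 + rgb[..., 2].astype(np.int64)
    lut = np.zeros(256 ** 3, dtype=np.uint8)
    for (r, g, b), idx in palette.items():
        lut[r * 65536 + g * 256 + b] = idx
    return lut[encoded]
```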
diff --git a/projects/hssn/README.md b/projects/hssn/README.md
index c2a74c69f9a..9dcbf37de0c 100644
--- a/projects/hssn/README.md
+++ b/projects/hssn/README.md
@@ -41,9 +41,9 @@ bash tools/dist_test.sh projects/hssn/configs/hssn/hieraseg_deeplabv3plus_r101-d
### Cityscapes
-| Method | Backbone | Crop Size | mIoU | mIoU (ms+flip) | config | model |
-| :--------: | :------: | :-------: | :---: | :------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------: |
-| DeeplabV3+ | R-101-D8 | 512x1024 | 81.61 | 82.71 | [config](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/HieraSeg/configs/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80l_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth) |
+| Method | Backbone | Crop Size | mIoU | mIoU (ms+flip) | config | model |
+| :--------: | :------: | :-------: | :---: | :------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| DeeplabV3+ | R-101-D8 | 512x1024 | 81.61 | 82.71 | [config](https://github.com/open-mmlab/mmsegmentation/tree/main/projects/HieraSeg/configs/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80l_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hieraseg/hieraseg_deeplabv3plus_r101-d8_4xb2-80k_cityscapes-512x1024_20230112_125023-bc59a3d1.pth) |
diff --git a/projects/isnet/README.md b/projects/isnet/README.md
index 3a3172a9d99..0a79ad6a4fa 100644
--- a/projects/isnet/README.md
+++ b/projects/isnet/README.md
@@ -96,11 +96,11 @@ A project does not necessarily have to be finished in a single PR, but it's esse
- [ ] Type hints and docstrings
-
+
- [ ] Unit tests
-
+
- [ ] Code polishing
@@ -108,10 +108,10 @@ A project does not necessarily have to be finished in a single PR, but it's esse
- [ ] Metafile.yml
-
+
- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
-
+
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/mapillary_dataset/README.md b/projects/mapillary_dataset/README.md
index cdc61d53a93..44a1e33ef93 100644
--- a/projects/mapillary_dataset/README.md
+++ b/projects/mapillary_dataset/README.md
@@ -10,7 +10,7 @@ This project implements **`Mapillary Vistas Dataset`**
### Dataset preparing
-Preparing `Mapillary Vistas Dataset` dataset following [Mapillary Vistas Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/mapillary_dataset/docs/en/user_guides/2_dataset_prepare.md)
+Preparing `Mapillary Vistas Dataset` dataset following [Mapillary Vistas Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/tree/main/projects/mapillary_dataset/docs/en/user_guides/2_dataset_prepare.md)
```none
mmsegmentation
diff --git a/projects/mapillary_dataset/docs/en/user_guides/2_dataset_prepare.md b/projects/mapillary_dataset/docs/en/user_guides/2_dataset_prepare.md
index fa074543300..c5cbc0f9b80 100644
--- a/projects/mapillary_dataset/docs/en/user_guides/2_dataset_prepare.md
+++ b/projects/mapillary_dataset/docs/en/user_guides/2_dataset_prepare.md
@@ -47,7 +47,7 @@
```
- You could set Datasets version with `MapillaryDataset_v1` and `MapillaryDataset_v2` in your configs.
- View the Mapillary Vistas Datasets config file here [V1.2](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/_base_/datasets/mapillary_v1.py) and [V2.0](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/_base_/datasets/mapillary_v2.py)
+ View the Mapillary Vistas Datasets config file here [V1.2](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/_base_/datasets/mapillary_v1.py) and [V2.0](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/_base_/datasets/mapillary_v2.py)
- **View datasets labels index and palette**
diff --git a/projects/medical/2d_image/ct/cranium/README.md b/projects/medical/2d_image/ct/cranium/README.md
new file mode 100644
index 00000000000..d3fa64ea406
--- /dev/null
+++ b/projects/medical/2d_image/ct/cranium/README.md
@@ -0,0 +1,142 @@
+# Brain CT Images with Intracranial Hemorrhage Masks (Cranium)
+
+## Description
+
+This project supports **`Brain CT Images with Intracranial Hemorrhage Masks (Cranium)`**, which can be downloaded from [here](https://www.kaggle.com/datasets/vbookshelf/computed-tomography-ct-images).
+
+### Dataset Overview
+
+This dataset consists of head CT (Computed Tomography) images in JPG format. There are 2500 brain window images and 2500 bone window images for 82 patients, with approximately 30 image slices per patient. 318 images have associated intracranial hemorrhage masks. Also included are CSV files containing hemorrhage diagnosis data and patient data.
+This is version 1.0.0 of this dataset. A full description of this dataset as well as updated versions can be found here:
+https://physionet.org/content/ct-ich/1.0.0/
+
+### Statistic Information
+
+| Dataset Name | Anatomical Region | Task Type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| ----------------------------------------------------------------------------------- | ----------------- | ------------ | -------- | ------------ | --------------------- | ---------------------- | ------------ | --------------------------------------------------------- |
+| [Cranium](https://www.kaggle.com/datasets/vbookshelf/computed-tomography-ct-images) | head_and_neck | segmentation | ct | 2 | 2501/-/- | yes/-/- | 2020 | [CC-BY 4.0](https://creativecommons.org/licenses/by/4.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 2501 | 99.93 | - | - | - | - |
+| hemorrhage | 318 | 0.07 | - | - | - | - |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
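
The `Pct` statistics above can be reproduced from the index masks. A small sketch, assuming masks are single-channel arrays with values in `[0, num_classes - 1]`:

```python
import numpy as np


def class_pixel_percentages(masks, num_classes=2):
    """Compute the share of pixels per class over a list of index masks."""
    counts = np.zeros(num_classes, dtype=np.int64)
    for mask in masks:
        # Accumulate per-class pixel counts for this mask.
        counts += np.bincount(
            mask.ravel(), minlength=num_classes)[:num_classes]
    return counts / counts.sum() * 100
```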
+
+### Visualization
+
+![cranium](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/ct/cranium/cranium_dataset.png?raw=true)
+
+## Dataset Citation
+
+```
+@article{hssayeni2020computed,
+ title={Computed tomography images for intracranial hemorrhage detection and segmentation},
+ author={Hssayeni, Murtadha and Croock, MS and Salman, AD and Al-khafaji, HF and Yahya, ZA and Ghoraani, B},
+ journal={Intracranial Hemorrhage Segmentation Using A Deep Convolutional Model. Data},
+ volume={5},
+ number={1},
+ pages={179},
+ year={2020}
+}
+```
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow (PIL) v9.3.0
+- scikit-learn (sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `cranium/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset Preparing
+
+- Download the dataset from [here](https://www.kaggle.com/datasets/vbookshelf/computed-tomography-ct-images) and decompress the data to the path `data/`.
+- Run the script `python tools/prepare_dataset.py` to format the data and change the folder structure as below.
+- Run the script `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. If the labels of the official validation set and test set cannot be obtained, we randomly generate `train.txt` and `val.txt` from the training set.
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── ct
+ │ │ │ │ ├── cranium
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── val.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ | │ │ │ ├── xxx.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── xxx.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ | │ │ │ ├── xxx.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── xxx.png
+```
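
The random train/val split mentioned above can be sketched as follows; `split_seg_dataset.py` is the authoritative script, and the helper name and split ratio below are assumptions for illustration:

```python
import os
import random


def split_dataset(img_dir, out_dir, val_ratio=0.2, seed=42):
    """Randomly split image names into train.txt and val.txt."""
    names = sorted(
        os.path.splitext(f)[0] for f in os.listdir(img_dir)
        if f.endswith('.png'))
    # Seeded shuffle so the split is reproducible across runs.
    random.Random(seed).shuffle(names)
    n_val = int(len(names) * val_ratio)
    splits = {'val.txt': names[:n_val], 'train.txt': names[n_val:]}
    for file_name, split in splits.items():
        with open(os.path.join(out_dir, file_name), 'w') as f:
            f.write('\n'.join(split) + '\n')
```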
+
+### Divided Dataset Information
+
+***Note: The split shown in the table below was made by ourselves.***
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 2000 | 99.93 | 501 | 99.92 | - | - |
+| hemorrhage | 260 | 0.07 | 260 | 0.08 | - | - |
+
+### Training commands
+
+To train models on a single server with one GPU. (default)
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+To test models on a single server with one GPU. (default)
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+ - [x] Basic docstrings & proper citation
+ - [ ] Test-time correctness
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+ - [ ] Unit tests
+ - [ ] Code polishing
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/ct/cranium/configs/cranium_512x512.py b/projects/medical/2d_image/ct/cranium/configs/cranium_512x512.py
new file mode 100644
index 00000000000..d9b44362a5c
--- /dev/null
+++ b/projects/medical/2d_image/ct/cranium/configs/cranium_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'CraniumDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/ct/cranium/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_cranium-512x512.py b/projects/medical/2d_image/ct/cranium/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_cranium-512x512.py
new file mode 100644
index 00000000000..ac013a215ae
--- /dev/null
+++ b/projects/medical/2d_image/ct/cranium/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_cranium-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './cranium_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.cranium_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(
+ num_classes=2, loss_decode=dict(use_sigmoid=True), out_channels=1),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/ct/cranium/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_cranium-512x512.py b/projects/medical/2d_image/ct/cranium/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_cranium-512x512.py
new file mode 100644
index 00000000000..c71110a21f7
--- /dev/null
+++ b/projects/medical/2d_image/ct/cranium/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_cranium-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './cranium_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.cranium_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/ct/cranium/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_cranium-512x512.py b/projects/medical/2d_image/ct/cranium/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_cranium-512x512.py
new file mode 100644
index 00000000000..abbdac285b2
--- /dev/null
+++ b/projects/medical/2d_image/ct/cranium/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_cranium-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './cranium_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.cranium_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/ct/cranium/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_cranium-512x512.py b/projects/medical/2d_image/ct/cranium/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_cranium-512x512.py
new file mode 100644
index 00000000000..418595268f9
--- /dev/null
+++ b/projects/medical/2d_image/ct/cranium/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_cranium-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './cranium_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.cranium_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/ct/cranium/datasets/cranium_dataset.py b/projects/medical/2d_image/ct/cranium/datasets/cranium_dataset.py
new file mode 100644
index 00000000000..d65f1cbfc68
--- /dev/null
+++ b/projects/medical/2d_image/ct/cranium/datasets/cranium_dataset.py
@@ -0,0 +1,31 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class CraniumDataset(BaseSegDataset):
+ """CraniumDataset dataset.
+
+ In segmentation map annotation for CraniumDataset,
+ 0 stands for background, which is included in 2 categories.
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default to False.
+ """
+ METAINFO = dict(classes=('background', 'hemorrhage'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/ct/cranium/tools/prepare_dataset.py b/projects/medical/2d_image/ct/cranium/tools/prepare_dataset.py
new file mode 100644
index 00000000000..1aa4e435614
--- /dev/null
+++ b/projects/medical/2d_image/ct/cranium/tools/prepare_dataset.py
@@ -0,0 +1,66 @@
+import os
+
+import numpy as np
+from PIL import Image
+
+root_path = 'data/'
+img_suffix = '.png'
+seg_map_suffix = '.png'
+save_img_suffix = '.png'
+save_seg_map_suffix = '.png'
+tgt_img_dir = os.path.join(root_path, 'images/train/')
+tgt_mask_dir = os.path.join(root_path, 'masks/train/')
+os.system('mkdir -p ' + tgt_img_dir)
+os.system('mkdir -p ' + tgt_mask_dir)
+
+
+def read_single_array_from_pil(path):
+ return np.asarray(Image.open(path))
+
+
+def save_png_from_array(arr, save_path, mode=None):
+ Image.fromarray(arr, mode=mode).save(save_path)
+
+
+def convert_label(img, convert_dict):
+ arr = np.zeros_like(img, dtype=np.uint8)
+ for c, i in convert_dict.items():
+ arr[img == c] = i
+ return arr
+
+
+patients_dir = os.path.join(
+ root_path, 'Cranium/computed-tomography-images-for-' +
+ 'intracranial-hemorrhage-detection-and-segmentation-1.0.0' +
+ '/Patients_CT')
+
+patients = sorted(os.listdir(patients_dir))
+for p in patients:
+ data_dir = os.path.join(patients_dir, p, 'brain')
+ file_names = os.listdir(data_dir)
+ img_w_mask_names = [
+ _.replace('_HGE_Seg', '') for _ in file_names if 'Seg' in _
+ ]
+ img_wo_mask_names = [
+ _ for _ in file_names if _ not in img_w_mask_names and 'Seg' not in _
+ ]
+
+ for file_name in file_names:
+ path = os.path.join(data_dir, file_name)
+ img = read_single_array_from_pil(path)
+ tgt_name = file_name.replace('.jpg', img_suffix)
+ tgt_name = p + '_' + tgt_name
+ if 'Seg' in file_name: # is a mask
+ tgt_name = tgt_name.replace('_HGE_Seg', '')
+ mask_path = os.path.join(tgt_mask_dir, tgt_name)
+ mask = convert_label(img, convert_dict={0: 0, 255: 1})
+ save_png_from_array(mask, mask_path)
+ else:
+ img_path = os.path.join(tgt_img_dir, tgt_name)
+ pil = Image.fromarray(img).convert('RGB')
+ pil.save(img_path)
+
+ if file_name in img_wo_mask_names:
+ mask = np.zeros_like(img, dtype=np.uint8)
+ mask_path = os.path.join(tgt_mask_dir, tgt_name)
+ save_png_from_array(mask, mask_path)
diff --git a/projects/medical/2d_image/dermoscopy/isic2016_task1/README.md b/projects/medical/2d_image/dermoscopy/isic2016_task1/README.md
new file mode 100644
index 00000000000..6e44e415ed6
--- /dev/null
+++ b/projects/medical/2d_image/dermoscopy/isic2016_task1/README.md
@@ -0,0 +1,149 @@
+# ISIC-2016 Task1
+
+## Description
+
+This project supports **`ISIC-2016 Task1`**, and the dataset used in this project can be downloaded from [here](https://challenge.isic-archive.com/data/#2016).
+
+### Dataset Overview
+
+The overarching goal of the challenge is to develop image analysis tools to enable the automated diagnosis of melanoma from dermoscopic images.
+
+This challenge provides training data (~900 images) for participants to engage in all 3 components of lesion image analysis. A separate test dataset (~350 images) will be provided for participants to generate and submit automated results.
+
+### Original Statistic Information
+
+| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| ---------------------------------------------------------------- | ----------------- | ------------ | ---------- | ------------ | --------------------- | ---------------------- | ------------ | ---------------------------------------------------------------------- |
+| [ISIC-2016 Task1](https://challenge.isic-archive.com/data/#2016) | full body | segmentation | dermoscopy | 2 | 900/-/379 | yes/-/yes | 2016 | [CC-0](https://creativecommons.org/share-your-work/public-domain/cc0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :---------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 900 | 82.08 | - | - | 379 | 81.98 |
+| skin lesion | 900 | 17.92 | - | - | 379 | 18.02 |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
+
+### Visualization
+
+![bac](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/dermoscopy/isic2016_task1/isic2016_task1.png)
+
+### Prerequisites
+
+- Python 3.8
+- PyTorch 1.10.0
+- pillow(PIL) 9.3.0
+- scikit-learn(sklearn) 1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `isic2016_task1/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset preparing
+
+- Download the dataset from [here](https://challenge.isic-archive.com/data/#2016) and decompress the data to the path `data/`.
+- Run the script `python tools/prepare_dataset.py` to split the dataset and change the folder structure as below.
+- Run the script `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt` and `test.txt`. If the labels of the official validation set and test set cannot be obtained, we randomly generate `train.txt` and `val.txt` from the training set.
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── dermoscopy
+ │ │ │ │ ├── isic2016_task1
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── test.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ | │ │ │ ├── xxx.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── xxx.png
+ │ │ │ │ │ │ │ ├── test
+ │ │ │ │ | │ │ │ ├── yyy.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── yyy.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ | │ │ │ ├── xxx.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── xxx.png
+ │ │ │ │ │ │ │ ├── test
+ │ │ │ │ | │ │ │ ├── yyy.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── yyy.png
+```
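
After preparation, each image is expected to have a mask with the same file name under the parallel `masks/` folder. A small sanity check, written under that assumption (the helper name is ours):

```python
import os


def check_pairs(data_root, split='train'):
    """Verify every image under images/<split> has a same-named mask."""
    imgs = set(os.listdir(os.path.join(data_root, 'images', split)))
    masks = set(os.listdir(os.path.join(data_root, 'masks', split)))
    missing = sorted(imgs - masks)
    if missing:
        raise RuntimeError(f'masks missing for: {missing[:5]} ...')
    return len(imgs)
```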
+
+### Training commands
+
+```shell
+mim train mmseg ./configs/${CONFIG_PATH}
+```
+
+To train on multiple GPUs, e.g. 8 GPUs, run the following command:
+
+```shell
+mim train mmseg ./configs/${CONFIG_PATH} --launcher pytorch --gpus 8
+```
+
+### Testing commands
+
+```shell
+mim test mmseg ./configs/${CONFIG_PATH} --checkpoint ${CHECKPOINT_PATH}
+```
+
+
+## Results
+
+### ISIC-2016 Task1
+
+| Method | Backbone | Crop Size | lr | mIoU | mDice | config |
+| :-------------: | :------: | :-------: | :----: | :--: | :---: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| fcn_unet_s5-d16 | unet | 512x512 | 0.01 | - | - | [config](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/medical/2d_image/dermoscopy/isic2016_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_isic2016-task1-512x512.py) |
+| fcn_unet_s5-d16 | unet | 512x512 | 0.001 | - | - | [config](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/medical/2d_image/dermoscopy/isic2016_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_isic2016-task1-512x512.py) |
+| fcn_unet_s5-d16 | unet | 512x512 | 0.0001 | - | - | [config](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/medical/2d_image/dermoscopy/isic2016_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_isic2016-task1-512x512.py) |
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+
+ - [x] Basic docstrings & proper citation
+
+ - [x] Test-time correctness
+
+ - [x] A full README
+
+- [x] Milestone 2: Indicates a successful model implementation.
+
+ - [x] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+
+ - [ ] Unit tests
+
+ - [ ] Code polishing
+
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/dermoscopy/isic2016_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_isic2016-task1-512x512.py b/projects/medical/2d_image/dermoscopy/isic2016_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_isic2016-task1-512x512.py
new file mode 100644
index 00000000000..5638de4d560
--- /dev/null
+++ b/projects/medical/2d_image/dermoscopy/isic2016_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_isic2016-task1-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './isic2016-task1_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.isic2016-task1_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/dermoscopy/isic2016_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_isic2016-task1-512x512.py b/projects/medical/2d_image/dermoscopy/isic2016_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_isic2016-task1-512x512.py
new file mode 100644
index 00000000000..bf17faa5380
--- /dev/null
+++ b/projects/medical/2d_image/dermoscopy/isic2016_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_isic2016-task1-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './isic2016-task1_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.isic2016-task1_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/dermoscopy/isic2016_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_isic2016-task1-512x512.py b/projects/medical/2d_image/dermoscopy/isic2016_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_isic2016-task1-512x512.py
new file mode 100644
index 00000000000..f7bfcf6158f
--- /dev/null
+++ b/projects/medical/2d_image/dermoscopy/isic2016_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_isic2016-task1-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './isic2016-task1_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.isic2016-task1_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/dermoscopy/isic2016_task1/configs/isic2016-task1_512x512.py b/projects/medical/2d_image/dermoscopy/isic2016_task1/configs/isic2016-task1_512x512.py
new file mode 100644
index 00000000000..029f5d4d7ec
--- /dev/null
+++ b/projects/medical/2d_image/dermoscopy/isic2016_task1/configs/isic2016-task1_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'ISIC2016Task1'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='test.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/dermoscopy/isic2016_task1/datasets/isic2016-task1_dataset.py b/projects/medical/2d_image/dermoscopy/isic2016_task1/datasets/isic2016-task1_dataset.py
new file mode 100644
index 00000000000..8f11bdd0ba9
--- /dev/null
+++ b/projects/medical/2d_image/dermoscopy/isic2016_task1/datasets/isic2016-task1_dataset.py
@@ -0,0 +1,30 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class ISIC2016Task1(BaseSegDataset):
+    """ISIC2016Task1 dataset.
+
+    In segmentation map annotation for ISIC2016Task1,
+    ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default to False.
+ """
+ METAINFO = dict(classes=('normal', 'skin lesion'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/dermoscopy/isic2016_task1/tools/prepare_dataset.py b/projects/medical/2d_image/dermoscopy/isic2016_task1/tools/prepare_dataset.py
new file mode 100755
index 00000000000..ef4dad54086
--- /dev/null
+++ b/projects/medical/2d_image/dermoscopy/isic2016_task1/tools/prepare_dataset.py
@@ -0,0 +1,120 @@
+import glob
+import os
+import shutil
+
+import numpy as np
+from PIL import Image
+
+
+def check_maskid(train_masks):
+    # print the unique label ids of each mask to verify the conversion
+    for i in train_masks:
+        img = Image.open(i)
+        print(np.unique(np.array(img)))
+
+
+def reformulate_file(image_list, mask_list):
+ file_list = []
+ for idx, (imgp,
+ maskp) in enumerate(zip(sorted(image_list), sorted(mask_list))):
+ item = {'image': imgp, 'label': maskp}
+ file_list.append(item)
+ return file_list
+
+
+def check_file_exist(pair_list):
+ rel_path = os.getcwd()
+ for idx, sample in enumerate(pair_list):
+ image_path = sample['image']
+ assert os.path.exists(os.path.join(rel_path, image_path))
+ if 'label' in sample:
+ mask_path = sample['label']
+ assert os.path.exists(os.path.join(rel_path, mask_path))
+ print('all file path ok!')
+
+
+def convert_maskid(mask):
+ # add mask id conversion
+ arr_mask = np.array(mask).astype(np.uint8)
+ arr_mask[arr_mask == 255] = 1
+ return Image.fromarray(arr_mask)
+
+
+def process_dataset(file_lists, part_dir_dict):
+ for ith, part in enumerate(file_lists):
+ part_dir = part_dir_dict[ith]
+ for sample in part:
+ # read image and mask
+            image_path = sample['image']
+            # mask_path stays None when the sample has no label
+            mask_path = sample.get('label')
+
+ basename = os.path.basename(image_path)
+ targetname = basename.split('.')[0] # from image name
+
+ # check image file
+ img_save_path = os.path.join(root_path, 'images', part_dir,
+ targetname + save_img_suffix)
+ if not os.path.exists(img_save_path):
+ if not image_path.endswith('.png'):
+ src = Image.open(image_path)
+ src.save(img_save_path)
+ else:
+ shutil.copy(image_path, img_save_path)
+
+ if mask_path is not None:
+ mask_save_path = os.path.join(root_path, 'masks', part_dir,
+ targetname + save_seg_map_suffix)
+ if not os.path.exists(mask_save_path):
+ # check mask file
+ mask = Image.open(mask_path).convert('L')
+ # convert mask id
+ mask = convert_maskid(mask)
+                mask.save(mask_save_path)
+
+ # print image num
+ part_dir_folder = os.path.join(root_path, 'images', part_dir)
+ print(
+ f'{part_dir} has {len(os.listdir(part_dir_folder))} images completed!' # noqa
+ )
+
+
+if __name__ == '__main__':
+
+ root_path = 'data/' # original file
+ img_suffix = '.jpg'
+ seg_map_suffix = '.png'
+ save_img_suffix = '.png'
+ save_seg_map_suffix = '.png'
+
+ train_imgs = glob.glob('data/ISBI2016_ISIC_Part1_Training_Data/*' # noqa
+ + img_suffix)
+ train_masks = glob.glob(
+ 'data/ISBI2016_ISIC_Part1_Training_GroundTruth/*' # noqa
+ + seg_map_suffix)
+
+ test_imgs = glob.glob('data/ISBI2016_ISIC_Part1_Test_Data/*' + img_suffix)
+ test_masks = glob.glob(
+ 'data/ISBI2016_ISIC_Part1_Test_GroundTruth/*' # noqa
+ + seg_map_suffix)
+
+ assert len(train_imgs) == len(train_masks)
+ assert len(test_imgs) == len(test_masks)
+
+ print(f'training images: {len(train_imgs)}, test images: {len(test_imgs)}')
+
+ os.system('mkdir -p ' + root_path + 'images/train/')
+ os.system('mkdir -p ' + root_path + 'images/test/')
+ os.system('mkdir -p ' + root_path + 'masks/train/')
+ os.system('mkdir -p ' + root_path + 'masks/test/')
+
+ train_pair_list = reformulate_file(train_imgs, train_masks)
+ test_pair_list = reformulate_file(test_imgs, test_masks)
+
+ check_file_exist(train_pair_list)
+ check_file_exist(test_pair_list)
+
+ part_dir_dict = {0: 'train/', 1: 'test/'}
+ process_dataset([train_pair_list, test_pair_list], part_dir_dict)
diff --git a/projects/medical/2d_image/dermoscopy/isic2017_task1/README.md b/projects/medical/2d_image/dermoscopy/isic2017_task1/README.md
new file mode 100644
index 00000000000..c7cc27096be
--- /dev/null
+++ b/projects/medical/2d_image/dermoscopy/isic2017_task1/README.md
@@ -0,0 +1,158 @@
+# ISIC-2017 Task1
+
+## Description
+
+This project supports **`ISIC-2017 Task1`**, and the dataset used in this project can be downloaded from [here](https://challenge.isic-archive.com/data/#2017).
+
+### Dataset Overview
+
+The goal of the challenge is to help participants develop image analysis tools to enable the automated diagnosis of melanoma from dermoscopic images.
+
+This challenge provides training data (~2000 images) for participants to engage in all 3 components of lesion image analysis. A separate public validation dataset (~150 images) and blind held-out test dataset (~600 images) will be provided for participants to generate and submit automated results.
+
+### Original Statistic Information
+
+| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| ---------------------------------------------------------------- | ----------------- | ------------ | ---------- | ------------ | --------------------- | ---------------------- | ------------ | ---------------------------------------------------------------------- |
+| [ISIC-2017 Task1](https://challenge.isic-archive.com/data/#2017) | full body | segmentation | dermoscopy | 2 | 2000/150/600 | yes/yes/yes | 2017 | [CC-0](https://creativecommons.org/share-your-work/public-domain/cc0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :---------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| normal | 2000 | 82.86 | 150 | 73.88 | 600 | 70.62 |
+| skin lesion | 2000 | 17.14 | 150 | 26.12 | 600 | 29.38 |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
+
+### Visualization
+
+![bac](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/dermoscopy/isic2017_task1/isic2017_task1.png)
+
+### Prerequisites
+
+- Python 3.8
+- PyTorch 1.10.0
+- pillow(PIL) 9.3.0
+- scikit-learn(sklearn) 1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In the `isic2017_task1/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset preparing
+
+- download the dataset from [here](https://challenge.isic-archive.com/data/#2017) and decompress the data to the path `'data/'`.
+- run the script `"python tools/prepare_dataset.py"` to reorganize the dataset into the folder structure below.
+- run the script `"python ../../tools/split_seg_dataset.py"` to split the dataset and generate `train.txt` and `test.txt`. If the labels of the official validation set and test set cannot be obtained, we randomly generate `train.txt` and `val.txt` from the training set.
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── dermoscopy
+ │ │ │ │ ├── isic2017_task1
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── val.txt
+ │ │ │ │ │ │ ├── test.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ | │ │ │ ├── xxx.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── xxx.png
+ │ │ │ │ │ │ │ ├── val
+ │ │ │ │ | │ │ │ ├── yyy.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── yyy.png
+ │ │ │ │ │ │ │ ├── test
+ │ │ │ │ | │ │ │ ├── yyy.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── yyy.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ | │ │ │ ├── xxx.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── xxx.png
+ │ │ │ │ │ │ │ ├── val
+ │ │ │ │ | │ │ │ ├── yyy.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── yyy.png
+ │ │ │ │ │ │ │ ├── test
+ │ │ │ │ | │ │ │ ├── yyy.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── yyy.png
+```
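The split step above boils down to a random, stem-level partition of the prepared image folder. The sketch below is a hypothetical, self-contained illustration of that idea only; the function name `split_dataset`, the `val_ratio`, and the seed are assumptions, not the actual interface of `split_seg_dataset.py`:

```python
# Hedged sketch of a random train/val split over image-name stems; the real
# logic lives in ../../tools/split_seg_dataset.py and may differ.
import os
import random
import tempfile


def split_dataset(img_dir, out_dir, val_ratio=0.2, seed=42):
    """Write train.txt/val.txt containing the file-name stems in img_dir."""
    stems = sorted(os.path.splitext(f)[0] for f in os.listdir(img_dir))
    random.Random(seed).shuffle(stems)
    n_val = int(len(stems) * val_ratio)
    splits = {'val.txt': stems[:n_val], 'train.txt': stems[n_val:]}
    for name, part in splits.items():
        with open(os.path.join(out_dir, name), 'w') as f:
            f.write('\n'.join(sorted(part)) + '\n')
    return splits


# demo on a synthetic folder of 10 empty "images"
root = tempfile.mkdtemp()
img_dir = os.path.join(root, 'images', 'train')
os.makedirs(img_dir)
for i in range(10):
    open(os.path.join(img_dir, f'ISIC_{i:07d}.png'), 'w').close()
splits = split_dataset(img_dir, root)
print(len(splits['train.txt']), len(splits['val.txt']))  # prints: 8 2
```

The configs then pick these text files up via `ann_file='train.txt'` (or via the `images/train/` directory prefix, depending on the project).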
+
+### Training commands
+
+```shell
+mim train mmseg ./configs/${CONFIG_PATH}
+```
+
+To train on multiple GPUs, e.g. 8 GPUs, run the following command:
+
+```shell
+mim train mmseg ./configs/${CONFIG_PATH} --launcher pytorch --gpus 8
+```
+
+### Testing commands
+
+```shell
+mim test mmseg ./configs/${CONFIG_PATH} --checkpoint ${CHECKPOINT_PATH}
+```
+
+
+
+## Results
+
+### ISIC-2017 Task1
+
+| Method | Backbone | Crop Size | lr | mIoU | mDice | config |
+| :-------------: | :------: | :-------: | :----: | :--: | :---: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| fcn_unet_s5-d16 | unet | 512x512 | 0.01 | - | - | [config](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/medical/2d_image/dermoscopy/isic2017_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_isic2017-task1-512x512.py) |
+| fcn_unet_s5-d16 | unet | 512x512 | 0.001 | - | - | [config](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/medical/2d_image/dermoscopy/isic2017_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_isic2017-task1-512x512.py) |
+| fcn_unet_s5-d16 | unet | 512x512 | 0.0001 | - | - | [config](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/medical/2d_image/dermoscopy/isic2017_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_isic2017-task1-512x512.py) |
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+
+ - [x] Basic docstrings & proper citation
+
+ - [ ] Test-time correctness
+
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+
+ - [ ] Unit tests
+
+ - [ ] Code polishing
+
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
diff --git a/projects/medical/2d_image/dermoscopy/isic2017_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_isic2017-task1-512x512.py b/projects/medical/2d_image/dermoscopy/isic2017_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_isic2017-task1-512x512.py
new file mode 100644
index 00000000000..58d0a125d33
--- /dev/null
+++ b/projects/medical/2d_image/dermoscopy/isic2017_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_isic2017-task1-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './isic2017-task1_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.isic2017-task1_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/dermoscopy/isic2017_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_isic2017-task1-512x512.py b/projects/medical/2d_image/dermoscopy/isic2017_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_isic2017-task1-512x512.py
new file mode 100644
index 00000000000..3becacf64fb
--- /dev/null
+++ b/projects/medical/2d_image/dermoscopy/isic2017_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_isic2017-task1-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './isic2017-task1_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.isic2017-task1_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/dermoscopy/isic2017_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_isic2017-task1-512x512.py b/projects/medical/2d_image/dermoscopy/isic2017_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_isic2017-task1-512x512.py
new file mode 100644
index 00000000000..654ef4dc3d5
--- /dev/null
+++ b/projects/medical/2d_image/dermoscopy/isic2017_task1/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_isic2017-task1-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './isic2017-task1_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.isic2017-task1_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/dermoscopy/isic2017_task1/configs/isic2017-task1_512x512.py b/projects/medical/2d_image/dermoscopy/isic2017_task1/configs/isic2017-task1_512x512.py
new file mode 100644
index 00000000000..95997a10997
--- /dev/null
+++ b/projects/medical/2d_image/dermoscopy/isic2017_task1/configs/isic2017-task1_512x512.py
@@ -0,0 +1,41 @@
+dataset_type = 'ISIC2017Task1'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='images/train/', seg_map_path='masks/train/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(img_path='images/val/', seg_map_path='masks/val/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/dermoscopy/isic2017_task1/datasets/isic2017-task1_dataset.py b/projects/medical/2d_image/dermoscopy/isic2017_task1/datasets/isic2017-task1_dataset.py
new file mode 100644
index 00000000000..8f11bdd0ba9
--- /dev/null
+++ b/projects/medical/2d_image/dermoscopy/isic2017_task1/datasets/isic2017-task1_dataset.py
@@ -0,0 +1,30 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class ISIC2017Task1(BaseSegDataset):
+ """ISIC2017Task1 dataset.
+
+ In segmentation map annotation for ISIC2017Task1,
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default to False.
+ """
+ METAINFO = dict(classes=('normal', 'skin lesion'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/dermoscopy/isic2017_task1/tools/prepare_dataset.py b/projects/medical/2d_image/dermoscopy/isic2017_task1/tools/prepare_dataset.py
new file mode 100755
index 00000000000..b3643c93590
--- /dev/null
+++ b/projects/medical/2d_image/dermoscopy/isic2017_task1/tools/prepare_dataset.py
@@ -0,0 +1,127 @@
+import glob
+import os
+import shutil
+
+import numpy as np
+from PIL import Image
+
+
+def check_maskid(train_masks):
+    # print the unique label ids of each mask to verify the conversion
+    for i in train_masks:
+        img = Image.open(i)
+        print(np.unique(np.array(img)))
+
+
+def reformulate_file(image_list, mask_list):
+ file_list = []
+ for idx, (imgp,
+ maskp) in enumerate(zip(sorted(image_list), sorted(mask_list))):
+ item = {'image': imgp, 'label': maskp}
+ file_list.append(item)
+ return file_list
+
+
+def convert_maskid(mask):
+ # add mask id conversion
+ arr_mask = np.array(mask).astype(np.uint8)
+ arr_mask[arr_mask == 255] = 1
+ return Image.fromarray(arr_mask)
+
+
+def check_file_exist(pair_list):
+ rel_path = os.getcwd()
+ for idx, sample in enumerate(pair_list):
+ image_path = sample['image']
+ assert os.path.exists(os.path.join(rel_path, image_path))
+ if 'label' in sample:
+ mask_path = sample['label']
+ assert os.path.exists(os.path.join(rel_path, mask_path))
+ print('all file path ok!')
+
+
+def process_dataset(file_lists, part_dir_dict):
+ for ith, part in enumerate(file_lists):
+ part_dir = part_dir_dict[ith]
+ for sample in part:
+ # read image and mask
+            image_path = sample['image']
+            # mask_path stays None when the sample has no label
+            mask_path = sample.get('label')
+
+ basename = os.path.basename(image_path)
+ targetname = basename.split('.')[0] # from image name
+
+ # check image file
+ img_save_path = os.path.join(root_path, 'images', part_dir,
+ targetname + save_img_suffix)
+ if not os.path.exists(img_save_path):
+ if not image_path.endswith('.png'):
+ src = Image.open(image_path)
+ src.save(img_save_path)
+ else:
+ shutil.copy(image_path, img_save_path)
+
+ if mask_path is not None:
+ mask_save_path = os.path.join(root_path, 'masks', part_dir,
+ targetname + save_seg_map_suffix)
+ if not os.path.exists(mask_save_path):
+ # check mask file
+ mask = Image.open(mask_path).convert('L')
+ # convert mask id
+ mask = convert_maskid(mask)
+                mask.save(mask_save_path)
+
+ # print image num
+ part_dir_folder = os.path.join(root_path, 'images', part_dir)
+ print(
+ f'{part_dir} has {len(os.listdir(part_dir_folder))} images completed!' # noqa
+ )
+
+
+if __name__ == '__main__':
+
+ root_path = 'data/' # original file
+ img_suffix = '.jpg'
+ seg_map_suffix = '.png'
+ save_img_suffix = '.png'
+ save_seg_map_suffix = '.png'
+
+ train_imgs = glob.glob('data/ISIC-2017_Training_Data/*' + img_suffix)
+ train_masks = glob.glob('data/ISIC-2017_Training_Part1_GroundTruth/*' +
+ seg_map_suffix)
+
+ val_imgs = glob.glob('data/ISIC-2017_Validation_Data/*' + img_suffix)
+ val_masks = glob.glob('data/ISIC-2017_Validation_Part1_GroundTruth/*' +
+ seg_map_suffix)
+
+ test_imgs = glob.glob('data/ISIC-2017_Test_v2_Data/*' + img_suffix)
+ test_masks = glob.glob('data/ISIC-2017_Test_v2_Part1_GroundTruth/*' +
+ seg_map_suffix)
+
+ assert len(train_imgs) == len(train_masks)
+ assert len(val_imgs) == len(val_masks)
+ assert len(test_imgs) == len(test_masks)
+
+ os.system('mkdir -p ' + root_path + 'images/train/')
+ os.system('mkdir -p ' + root_path + 'images/val/')
+ os.system('mkdir -p ' + root_path + 'images/test/')
+ os.system('mkdir -p ' + root_path + 'masks/train/')
+ os.system('mkdir -p ' + root_path + 'masks/val/')
+ os.system('mkdir -p ' + root_path + 'masks/test/')
+
+ train_pair_list = reformulate_file(train_imgs, train_masks)
+ val_pair_list = reformulate_file(val_imgs, val_masks)
+ test_pair_list = reformulate_file(test_imgs, test_masks)
+
+ check_file_exist(train_pair_list)
+ check_file_exist(val_pair_list)
+ check_file_exist(test_pair_list)
+
+ part_dir_dict = {0: 'train/', 1: 'val/', 2: 'test/'}
+ process_dataset([train_pair_list, val_pair_list, test_pair_list],
+ part_dir_dict)
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/README.md b/projects/medical/2d_image/endoscopy/kvasir_seg/README.md
new file mode 100644
index 00000000000..ea597bc4401
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg/README.md
@@ -0,0 +1,145 @@
+# Kvasir-Sessile Dataset (Kvasir SEG)
+
+## Description
+
+This project supports the **`Kvasir-Sessile Dataset (Kvasir SEG)`**, which can be downloaded from [here](https://opendatalab.com/Kvasir-Sessile_dataset).
+
+## Dataset Overview
+
+The Kvasir-SEG dataset contains polyp images and their corresponding ground truth from the Kvasir Dataset v2. The resolution of the images contained in Kvasir-SEG varies from 332x487 to 1920x1072 pixels.
+
+
+
+### Information Statistics
+
+| Dataset Name | Anatomical Region | Task Type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| ------------------------------------------------------------- | ----------------- | ------------ | --------- | ------------ | --------------------- | ---------------------- | ------------ | --------------------------------------------------------- |
+| [Kvasir-SEG](https://opendatalab.com/Kvasir-Sessile_dataset) | abdomen | segmentation | endoscopy | 2 | 196/-/- | yes/-/- | 2020 | [CC-BY 4.0](https://creativecommons.org/licenses/by/4.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 196 | 92.31 | - | - | - | - |
+| polyp | 196 | 7.69 | - | - | - | - |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
+
+### Visualization
+
+![kvasir-seg](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/endoscopy_images/kvasir_seg/kvasir_seg_dataset.png?raw=true)
+
+### Dataset Citation
+
+```
+@inproceedings{jha2020kvasir,
+ title={Kvasir-seg: A segmented polyp dataset},
+ author={Jha, Debesh and Smedsrud, Pia H and Riegler, Michael A and Halvorsen, P{\aa}l and Lange, Thomas de and Johansen, Dag and Johansen, H{\aa}vard D},
+ booktitle={International Conference on Multimedia Modeling},
+ pages={451--462},
+ year={2020},
+ organization={Springer}
+ }
+```
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow(PIL) v9.3.0
+- scikit-learn(sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In the `kvasir_seg/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset Preparing
+
+- download the dataset from [here](https://opendatalab.com/Kvasir-Sessile_dataset) and decompress the data to the path `'data/'`.
+- run the script `"python tools/prepare_dataset.py"` to format the data and change the folder structure as below.
+- run the script `"python ../../tools/split_seg_dataset.py"` to split the dataset and generate `train.txt` and `val.txt`. If the labels of the official validation set and test set cannot be obtained, we randomly generate `train.txt` and `val.txt` from the training set.
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── endoscopy
+ │ │ │ │ ├── kvasir_seg
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── val.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ | │ │ │ ├── xxx.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── xxx.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ | │ │ │ ├── xxx.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── xxx.png
+```
+
+### Divided Dataset Information
+
+***Note: the train/val split shown in the table below was made by ourselves.***
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 156 | 92.28 | 40 | 92.41 | - | - |
+| polyp | 156 | 7.72 | 40 | 7.59 | - | - |
+
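The `Pct` columns in the tables above are pixel-level class frequencies. As a minimal illustration of how such statistics can be computed (tiny synthetic 0/1 masks stand in for the real PNG files, which would normally be loaded with pillow; the function name `class_pixel_pct` is an assumption for this sketch, not a project API):

```python
# Illustrative per-class pixel-percentage computation, as reported in the
# statistics tables; real masks would be loaded with PIL.Image.open.
from collections import Counter


def class_pixel_pct(masks):
    """masks: iterable of 2-D lists of class ids -> {class_id: pct of pixels}."""
    counts = Counter()
    for mask in masks:
        for row in mask:
            counts.update(row)
    total = sum(counts.values())
    return {cls: 100.0 * n / total for cls, n in counts.items()}


# two tiny 4x4 masks: 0 = background, 1 = polyp
masks = [
    [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
    [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
]
pct = class_pixel_pct(masks)
print(round(pct[0], 2), round(pct[1], 2))  # prints: 84.38 15.62
```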
+### Training commands
+
+To train models on a single server with one GPU (default):
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+To test models on a single server with one GPU (default):
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+ - [x] Basic docstrings & proper citation
+ - [ ] Test-time correctness
+ - [x] A full README
+
+- [x] Milestone 2: Indicates a successful model implementation.
+
+ - [x] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+ - [ ] Unit tests
+ - [ ] Code polishing
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_kvasir-seg-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_kvasir-seg-512x512.py
new file mode 100644
index 00000000000..145d5a7a172
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_kvasir-seg-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './kvasir-seg_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.kvasir-seg_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(
+ num_classes=2, loss_decode=dict(use_sigmoid=True), out_channels=1),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_kvasir-seg-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_kvasir-seg-512x512.py
new file mode 100644
index 00000000000..3ea05c51090
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_kvasir-seg-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './kvasir-seg_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.kvasir-seg_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_kvasir-seg-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_kvasir-seg-512x512.py
new file mode 100644
index 00000000000..7e064a716aa
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_kvasir-seg-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './kvasir-seg_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.kvasir-seg_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_kvasir-seg-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_kvasir-seg-512x512.py
new file mode 100644
index 00000000000..0fc1d6e99d7
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_kvasir-seg-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './kvasir-seg_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.kvasir-seg_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/configs/kvasir-seg_512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/kvasir-seg_512x512.py
new file mode 100644
index 00000000000..e8b2467f8cf
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg/configs/kvasir-seg_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'KvasirSEGDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/datasets/kvasir-seg_dataset.py b/projects/medical/2d_image/endoscopy/kvasir_seg/datasets/kvasir-seg_dataset.py
new file mode 100644
index 00000000000..9d601328eb6
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg/datasets/kvasir-seg_dataset.py
@@ -0,0 +1,30 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class KvasirSEGDataset(BaseSegDataset):
+ """KvasirSEGDataset dataset.
+
+ In segmentation map annotation for KvasirSEGDataset, 0 stands for
+ background, which is included in 2 categories.
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix`` is
+ fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+            Default: False.
+ """
+ METAINFO = dict(classes=('background', 'polyp'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg/tools/prepare_dataset.py b/projects/medical/2d_image/endoscopy/kvasir_seg/tools/prepare_dataset.py
new file mode 100644
index 00000000000..74c43e96351
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg/tools/prepare_dataset.py
@@ -0,0 +1,87 @@
+import glob
+import os
+
+import numpy as np
+from PIL import Image
+
+root_path = 'data/'
+img_suffix = '.jpg'
+seg_map_suffix = '.jpg'
+save_img_suffix = '.png'
+save_seg_map_suffix = '.png'
+tgt_img_dir = os.path.join(root_path, 'images/train/')
+tgt_mask_dir = os.path.join(root_path, 'masks/train/')
+os.makedirs(tgt_img_dir, exist_ok=True)
+os.makedirs(tgt_mask_dir, exist_ok=True)
+
+
+def filter_suffix_recursive(src_dir, suffix):
+ # filter out file names and paths in source directory
+ suffix = '.' + suffix if '.' not in suffix else suffix
+ file_paths = glob.glob(
+ os.path.join(src_dir, '**', '*' + suffix), recursive=True)
+    file_names = [os.path.basename(p) for p in file_paths]
+ return sorted(file_paths), sorted(file_names)
+
+
+def convert_label(img, convert_dict):
+ arr = np.zeros_like(img, dtype=np.uint8)
+ for c, i in convert_dict.items():
+ arr[img == c] = i
+ return arr
+
+
+def convert_pics_into_pngs(src_dir, tgt_dir, suffix, convert='RGB'):
+ if not os.path.exists(tgt_dir):
+ os.makedirs(tgt_dir)
+
+    src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix)
+    num = len(src_paths)
+    for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)):
+        tgt_name = src_name.replace(suffix, save_img_suffix)
+        tgt_path = os.path.join(tgt_dir, tgt_name)
+ img = np.array(Image.open(src_path))
+ if len(img.shape) == 2:
+ pil = Image.fromarray(img).convert(convert)
+ elif len(img.shape) == 3:
+ pil = Image.fromarray(img)
+ else:
+            raise ValueError(f'Input image is not 2D/3D: {img.shape}')
+
+ pil.save(tgt_path)
+ print(f'processed {i+1}/{num}.')
+
+
+def convert_label_pics_into_pngs(src_dir,
+ tgt_dir,
+ suffix,
+ convert_dict={
+ 0: 0,
+ 255: 1
+ }):
+ if not os.path.exists(tgt_dir):
+ os.makedirs(tgt_dir)
+
+ src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix)
+ num = len(src_paths)
+ for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)):
+ tgt_name = src_name.replace(suffix, save_seg_map_suffix)
+ tgt_path = os.path.join(tgt_dir, tgt_name)
+
+        img = np.array(Image.open(src_path).convert('L'))
+ img = convert_label(img, convert_dict)
+ Image.fromarray(img).save(tgt_path)
+ print(f'processed {i+1}/{num}.')
+
+
+if __name__ == '__main__':
+
+ convert_pics_into_pngs(
+ os.path.join(root_path, 'sessile-main-Kvasir-SEG/images'),
+ tgt_img_dir,
+ suffix=img_suffix)
+
+ convert_label_pics_into_pngs(
+ os.path.join(root_path, 'sessile-main-Kvasir-SEG/masks'),
+ tgt_mask_dir,
+ suffix=seg_map_suffix)
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/README.md b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/README.md
new file mode 100644
index 00000000000..80eb00f51bd
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/README.md
@@ -0,0 +1,145 @@
+# Kvasir-SEG Segmented Polyp Dataset from Aliyun (Kvasir SEG Aliyun)
+
+## Description
+
+This project supports **`Kvasir-SEG Segmented Polyp Dataset from Aliyun (Kvasir SEG Aliyun)`**, which can be downloaded from [here](https://tianchi.aliyun.com/dataset/84385).
+
+### Dataset Overview
+
+Colorectal cancer is the second most common cancer type among women and the third most common among men. Polyps are precursors to colorectal cancer and are therefore important to detect and remove at an early stage. Polyps are found in nearly half of the individuals who undergo a colonoscopy screening at age 50, and their frequency increases with age. Polyps are abnormal tissue growths from the mucous membrane that lines the inside of the GI tract, and can sometimes be cancerous. Colonoscopy is the gold standard for the detection and assessment of these polyps, with subsequent biopsy and removal. Early disease detection has a huge impact on survival from colorectal cancer, and increasing the detection of polyps has been shown to decrease the risk of colorectal cancer. Thus, automatically detecting more polyps at an early stage can play a crucial role in the prevention of and survival from colorectal cancer.
+
+The Kvasir-SEG dataset is based on the previous Kvasir dataset, which is the first multi-class dataset for gastrointestinal (GI) tract disease detection and classification. It contains annotated polyp images and their corresponding masks. The pixels depicting polyp tissue, the ROI, are represented by the foreground (white mask), while the background (in black) does not contain positive pixels. These images were collected and verified by experienced gastroenterologists from Vestre Viken Health Trust in Norway. The classes include anatomical landmarks, pathological findings and endoscopic procedures.
+
+### Information Statistics
+
+| Dataset Name | Anatomical Region | Task Type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| ------------------------------------------------------ | ----------------- | ------------ | --------- | ------------ | --------------------- | ---------------------- | ------------ | --------------------------------------------------------- |
+| [kvasir-seg](https://tianchi.aliyun.com/dataset/84385) | abdomen | segmentation | endoscopy | 2 | 1000/-/- | yes/-/- | 2020 | [CC-BY 4.0](https://creativecommons.org/licenses/by/4.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 1000 | 84.72 | - | - | - | - |
+| polyp | 1000 | 15.28 | - | - | - | - |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
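As a reference for how `Pct` is computed, per-class pixel percentages can be derived from the label maps roughly as follows (a sketch, not the authors' exact statistics script; it assumes each mask is an integer array with 0 = background and 1 = polyp):

```python
import numpy as np


def class_pixel_percentages(masks, num_classes=2):
    """Per-class pixel percentage over an iterable of label maps.

    Sketch only: assumes each mask is an integer numpy array whose
    values are class indices (0 = background, 1 = polyp).
    """
    counts = np.zeros(num_classes, dtype=np.int64)
    for mask in masks:
        # Accumulate per-class pixel counts mask by mask.
        counts += np.bincount(mask.ravel(), minlength=num_classes)[:num_classes]
    return 100.0 * counts / counts.sum()
```

Accumulating `np.bincount` results mask by mask avoids holding all label maps in memory at once.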
+
+### Visualization
+
+![kvasir_seg_aliyun](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/endoscopy_images/kvasir_seg_aliyun/kvasir_seg_aliyun_dataset.png?raw=true)
+
+### Dataset Citation
+
+```
+@inproceedings{jha2020kvasir,
+ title={Kvasir-seg: A segmented polyp dataset},
+ author={Jha, Debesh and Smedsrud, Pia H and Riegler, Michael A and Halvorsen, P{\aa}l and Lange, Thomas de and Johansen, Dag and Johansen, H{\aa}vard D},
+ booktitle={International Conference on Multimedia Modeling},
+ pages={451--462},
+ year={2020},
+ organization={Springer}
+ }
+```
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow(PIL) v9.3.0
+- scikit-learn(sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `kvasir_seg_aliyun/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset Preparing
+
+- Download the dataset from [here](https://tianchi.aliyun.com/dataset/84385) and decompress the data to the path `data/`.
+- Run the script `python tools/prepare_dataset.py` to format the data and change the folder structure as below.
+- Run the script `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. Since the labels of the official validation and test sets cannot be obtained, we randomly generate `train.txt` and `val.txt` from the training set.
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── endoscopy
+ │ │ │ │ ├── kvasir_seg_aliyun
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── val.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ | │ │ │ ├── xxx.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── xxx.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ | │ │ │ ├── xxx.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── xxx.png
+```
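The key formatting step in `tools/prepare_dataset.py` is mapping the raw mask intensities 0/255 to the class indices {0, 1}; in isolation, it amounts to the following (a sketch mirroring the script's `convert_label`):

```python
import numpy as np


def convert_label(img, convert_dict={0: 0, 255: 1}):
    # Exact-match mapping: any intensity not listed in convert_dict
    # (e.g. JPEG compression artifacts near 0 or 255) falls back to 0.
    arr = np.zeros_like(img, dtype=np.uint8)
    for value, index in convert_dict.items():
        arr[img == value] = index
    return arr
```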
+
+### Divided Dataset Information
+
+***Note: The information in the table below reflects the dataset split we made ourselves.***
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 800 | 84.66 | 200 | 84.94 | - | - |
+| polyp | 800 | 15.34 | 200 | 15.06 | - | - |
+
+### Training commands
+
+To train models on a single server with one GPU. (default)
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+To test models on a single server with one GPU. (default)
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+ - [x] Basic docstrings & proper citation
+ - [ ] Test-time correctness
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+ - [ ] Unit tests
+ - [ ] Code polishing
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_kvasir-seg-aliyun-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_kvasir-seg-aliyun-512x512.py
new file mode 100644
index 00000000000..b59db95232b
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_kvasir-seg-aliyun-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ './kvasir-seg-aliyun_512x512.py', 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.kvasir-seg-aliyun_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_kvasir-seg-aliyun-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_kvasir-seg-aliyun-512x512.py
new file mode 100644
index 00000000000..6c526680cd4
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_kvasir-seg-aliyun-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ './kvasir-seg-aliyun_512x512.py', 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.kvasir-seg-aliyun_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_kvasir-seg-aliyun-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_kvasir-seg-aliyun-512x512.py
new file mode 100644
index 00000000000..a192a5bd240
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_kvasir-seg-aliyun-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ './kvasir-seg-aliyun_512x512.py', 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.kvasir-seg-aliyun_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_kvasir-seg-aliyun-512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_kvasir-seg-aliyun-512x512.py
new file mode 100644
index 00000000000..5325e1f0807
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_kvasir-seg-aliyun-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ './kvasir-seg-aliyun_512x512.py', 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.kvasir-seg-aliyun_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(
+ num_classes=2, loss_decode=dict(use_sigmoid=True), out_channels=1),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/kvasir-seg-aliyun_512x512.py b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/kvasir-seg-aliyun_512x512.py
new file mode 100644
index 00000000000..5f868804679
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/configs/kvasir-seg-aliyun_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'KvasirSEGAliyunDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/datasets/kvasir-seg-aliyun_dataset.py b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/datasets/kvasir-seg-aliyun_dataset.py
new file mode 100644
index 00000000000..198caf07bcd
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/datasets/kvasir-seg-aliyun_dataset.py
@@ -0,0 +1,30 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class KvasirSEGAliyunDataset(BaseSegDataset):
+ """KvasirSEGAliyunDataset dataset.
+
+ In segmentation map annotation for KvasirSEGAliyunDataset,
+    0 stands for background, which is included in 2 categories.
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+            Default: False.
+ """
+ METAINFO = dict(classes=('background', 'polyp'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/tools/prepare_dataset.py b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/tools/prepare_dataset.py
new file mode 100644
index 00000000000..b230e7fef58
--- /dev/null
+++ b/projects/medical/2d_image/endoscopy/kvasir_seg_aliyun/tools/prepare_dataset.py
@@ -0,0 +1,86 @@
+import glob
+import os
+
+import numpy as np
+from PIL import Image
+
+root_path = 'data/'
+img_suffix = '.jpg'
+seg_map_suffix = '.jpg'
+save_img_suffix = '.png'
+save_seg_map_suffix = '.png'
+tgt_img_dir = os.path.join(root_path, 'images/train/')
+tgt_mask_dir = os.path.join(root_path, 'masks/train/')
+os.makedirs(tgt_img_dir, exist_ok=True)
+os.makedirs(tgt_mask_dir, exist_ok=True)
+
+
+def filter_suffix_recursive(src_dir, suffix):
+ # filter out file names and paths in source directory
+ suffix = '.' + suffix if '.' not in suffix else suffix
+ file_paths = glob.glob(
+ os.path.join(src_dir, '**', '*' + suffix), recursive=True)
+    file_names = [os.path.basename(p) for p in file_paths]
+ return sorted(file_paths), sorted(file_names)
+
+
+def convert_label(img, convert_dict):
+ arr = np.zeros_like(img, dtype=np.uint8)
+ for c, i in convert_dict.items():
+ arr[img == c] = i
+ return arr
+
+
+def convert_pics_into_pngs(src_dir, tgt_dir, suffix, convert='RGB'):
+ if not os.path.exists(tgt_dir):
+ os.makedirs(tgt_dir)
+
+    src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix)
+    num = len(src_paths)
+    for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)):
+        tgt_name = src_name.replace(suffix, save_img_suffix)
+        tgt_path = os.path.join(tgt_dir, tgt_name)
+ img = np.array(Image.open(src_path))
+ if len(img.shape) == 2:
+ pil = Image.fromarray(img).convert(convert)
+ elif len(img.shape) == 3:
+ pil = Image.fromarray(img)
+ else:
+            raise ValueError(f'Input image is not 2D/3D: {img.shape}')
+
+ pil.save(tgt_path)
+ print(f'processed {i+1}/{num}.')
+
+
+def convert_label_pics_into_pngs(src_dir,
+ tgt_dir,
+ suffix,
+ convert_dict={
+ 0: 0,
+ 255: 1
+ }):
+ if not os.path.exists(tgt_dir):
+ os.makedirs(tgt_dir)
+
+ src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix)
+ num = len(src_paths)
+ for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)):
+ tgt_name = src_name.replace(suffix, save_seg_map_suffix)
+ tgt_path = os.path.join(tgt_dir, tgt_name)
+
+ img = np.array(Image.open(src_path).convert('L'))
+ img = convert_label(img, convert_dict)
+ Image.fromarray(img).save(tgt_path)
+ print(f'processed {i+1}/{num}.')
+
+
+if __name__ == '__main__':
+ convert_pics_into_pngs(
+ os.path.join(root_path, 'Kvasir-SEG/images'),
+ tgt_img_dir,
+ suffix=img_suffix)
+
+ convert_label_pics_into_pngs(
+ os.path.join(root_path, 'Kvasir-SEG/masks'),
+ tgt_mask_dir,
+ suffix=seg_map_suffix)
diff --git a/projects/medical/2d_image/fluorescein_angriogram/vampire/README.md b/projects/medical/2d_image/fluorescein_angriogram/vampire/README.md
new file mode 100644
index 00000000000..c2c61c46a0b
--- /dev/null
+++ b/projects/medical/2d_image/fluorescein_angriogram/vampire/README.md
@@ -0,0 +1,158 @@
+# Vessel Assessment and Measurement Platform for Images of the REtina
+
+## Description
+
+This project supports **`Vessel Assessment and Measurement Platform for Images of the REtina`**, and the dataset used in this project can be downloaded from [here](https://vampire.computing.dundee.ac.uk/vesselseg.html).
+
+### Dataset Overview
+
+In order to promote the evaluation of vessel segmentation on ultra-wide field-of-view (UWFV) fluorescein angiogram (FA) frames, we make public 8 frames from two different sequences, the manually annotated images, and the result of our automatic vessel segmentation algorithm.
+
+### Original Statistic Information
+
+| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| ---------------------------------------------------------------- | ----------------- | ------------ | ---------------------- | ------------ | --------------------- | ---------------------- | ------------ | --------------------------------------------------------------- |
+| [Vampire](https://vampire.computing.dundee.ac.uk/vesselseg.html) | vessel | segmentation | fluorescein angiogram | 2 | 8/-/- | yes/-/- | 2017 | [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-sa/4.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 8 | 96.75 | - | - | - | - |
+| vessel | 8 | 3.25 | - | - | - | - |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
+
+### Visualization
+
+![bac](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/fluorescein_angriogram/vampire/vampire_dataset.png)
+
+### Dataset Citation
+
+```bibtex
+
+@inproceedings{perez2011improving,
+ title={Improving vessel segmentation in ultra-wide field-of-view retinal fluorescein angiograms},
+ author={Perez-Rovira, Adria and Zutis, K and Hubschman, Jean Pierre and Trucco, Emanuele},
+ booktitle={2011 Annual International Conference of the IEEE Engineering in Medicine and Biology Society},
+ pages={2614--2617},
+ year={2011},
+ organization={IEEE}
+}
+
+@article{perez2011rerbee,
+ title={RERBEE: robust efficient registration via bifurcations and elongated elements applied to retinal fluorescein angiogram sequences},
+ author={Perez-Rovira, Adria and Cabido, Raul and Trucco, Emanuele and McKenna, Stephen J and Hubschman, Jean Pierre},
+ journal={IEEE Transactions on Medical Imaging},
+ volume={31},
+ number={1},
+ pages={140--150},
+ year={2011},
+ publisher={IEEE}
+}
+
+```
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow(PIL) v9.3.0
+- scikit-learn(sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `vampire/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset Preparing
+
+- Download the dataset from [here](https://vampire.computing.dundee.ac.uk/vesselseg.html) and decompress the data to the path `data/`.
+- Run the script `python tools/prepare_dataset.py` to format the data and change the folder structure as below.
+- Run the script `python ../../tools/split_seg_dataset.py` to split the dataset. As the Vampire dataset has no separate test or validation set, we randomly sample 20% of the images as the validation set and use the remaining 80% for training, producing the filename lists `train.txt` and `val.txt`. Because the random seed is hard-coded, the split is deterministic and reproducible.
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── fluorescein_angriogram
+ │ │ │ │ ├── vampire
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── val.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ | │ │ │ ├── xxx.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── xxx.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ | │ │ │ ├── xxx.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── xxx.png
+```
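The seeded split mentioned in the preparation steps can be sketched with the standard library as follows (a sketch only: the seed value of 42 and the shuffling details of `../../tools/split_seg_dataset.py` are assumptions):

```python
import random


def split_filenames(filenames, seed=42, val_ratio=0.2):
    """Reproducible train/val split of mask filenames.

    A hard-coded seed removes the randomness, so repeated runs always
    produce the same train.txt/val.txt contents.
    """
    names = sorted(filenames)      # fixed starting order
    rng = random.Random(seed)      # seeded RNG, independent of global state
    rng.shuffle(names)
    n_val = int(len(names) * val_ratio)
    return names[n_val:], names[:n_val]  # (train, val)
```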
+
+### Divided Dataset Information
+
+***Note: The information in the table below reflects the dataset split we made ourselves.***
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 6 | 97.48 | 2 | 94.54 | - | - |
+| vessel | 6 | 2.52 | 2 | 5.46 | - | - |
+
+### Training commands
+
+To train models on a single server with one GPU. (default)
+
+```shell
+mim train mmseg ./configs/${CONFIG_PATH}
+```
+
+### Testing commands
+
+To test models on a single server with one GPU. (default)
+
+```shell
+mim test mmseg ./configs/${CONFIG_PATH} --checkpoint ${CHECKPOINT_PATH}
+```
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+
+ - [x] Basic docstrings & proper citation
+
+ - [ ] Test-time correctness
+
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+
+ - [ ] Unit tests
+
+ - [ ] Code polishing
+
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/fluorescein_angriogram/vampire/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_vampire-512x512.py b/projects/medical/2d_image/fluorescein_angriogram/vampire/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_vampire-512x512.py
new file mode 100755
index 00000000000..7f5273aaff9
--- /dev/null
+++ b/projects/medical/2d_image/fluorescein_angriogram/vampire/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_vampire-512x512.py
@@ -0,0 +1,19 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './vampire_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.vampire_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ pretrained=None,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fluorescein_angriogram/vampire/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_vampire-512x512.py b/projects/medical/2d_image/fluorescein_angriogram/vampire/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_vampire-512x512.py
new file mode 100755
index 00000000000..4382229989b
--- /dev/null
+++ b/projects/medical/2d_image/fluorescein_angriogram/vampire/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_vampire-512x512.py
@@ -0,0 +1,19 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './vampire_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.vampire_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ type='EncoderDecoder',
+    data_preprocessor=data_preprocessor,
+ pretrained=None,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fluorescein_angriogram/vampire/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_vampire-512x512.py b/projects/medical/2d_image/fluorescein_angriogram/vampire/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_vampire-512x512.py
new file mode 100755
index 00000000000..8d93e17627f
--- /dev/null
+++ b/projects/medical/2d_image/fluorescein_angriogram/vampire/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_vampire-512x512.py
@@ -0,0 +1,22 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './vampire_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.vampire_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ pretrained=None,
+ decode_head=dict(
+ num_classes=2,
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=True),
+ out_channels=1),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fluorescein_angriogram/vampire/configs/vampire_512x512.py b/projects/medical/2d_image/fluorescein_angriogram/vampire/configs/vampire_512x512.py
new file mode 100755
index 00000000000..4eda92f9f22
--- /dev/null
+++ b/projects/medical/2d_image/fluorescein_angriogram/vampire/configs/vampire_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'VampireDataset'
+data_root = 'data'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/fluorescein_angriogram/vampire/datasets/__init__.py b/projects/medical/2d_image/fluorescein_angriogram/vampire/datasets/__init__.py
new file mode 100755
index 00000000000..93f9cbf0506
--- /dev/null
+++ b/projects/medical/2d_image/fluorescein_angriogram/vampire/datasets/__init__.py
@@ -0,0 +1,3 @@
+from .vampire_dataset import VampireDataset
+
+__all__ = ['VampireDataset']
diff --git a/projects/medical/2d_image/fluorescein_angriogram/vampire/datasets/vampire_dataset.py b/projects/medical/2d_image/fluorescein_angriogram/vampire/datasets/vampire_dataset.py
new file mode 100755
index 00000000000..4d38040f7f1
--- /dev/null
+++ b/projects/medical/2d_image/fluorescein_angriogram/vampire/datasets/vampire_dataset.py
@@ -0,0 +1,28 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class VampireDataset(BaseSegDataset):
+ """VampireDataset dataset.
+
+    In segmentation map annotation for VampireDataset, 0 stands for
+    background, which is one of the 2 categories. ``reduce_zero_label`` is
+    fixed to False. The ``img_suffix`` is fixed to '.png' and
+    ``seg_map_suffix`` is fixed to '.png'.
+
+    Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ """
+ METAINFO = dict(classes=('background', 'vessel'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/fluorescein_angriogram/vampire/tools/prepare_dataset.py b/projects/medical/2d_image/fluorescein_angriogram/vampire/tools/prepare_dataset.py
new file mode 100644
index 00000000000..2755b5d28bd
--- /dev/null
+++ b/projects/medical/2d_image/fluorescein_angriogram/vampire/tools/prepare_dataset.py
@@ -0,0 +1,44 @@
+import os
+import shutil
+
+from PIL import Image
+
+path = 'data'
+
+os.makedirs(os.path.join(path, 'images', 'train'), exist_ok=True)
+os.makedirs(os.path.join(path, 'masks', 'train'), exist_ok=True)
+
+origin_data_path = os.path.join(path, 'vesselSegmentation')
+
+imgs_amd14 = os.listdir(os.path.join(origin_data_path, 'AMD14'))
+imgs_ger7 = os.listdir(os.path.join(origin_data_path, 'GER7'))
+
+for img in imgs_amd14:
+ shutil.copy(
+ os.path.join(origin_data_path, 'AMD14', img),
+ os.path.join(path, 'images', 'train', img))
+ # copy GT
+ img_gt = img.replace('.png', '-GT.png')
+ shutil.copy(
+ os.path.join(origin_data_path, 'AMD14-GT', f'{img_gt}'),
+ os.path.join(path, 'masks', 'train', img))
+
+for img in imgs_ger7:
+ shutil.copy(
+ os.path.join(origin_data_path, 'GER7', img),
+ os.path.join(path, 'images', 'train', img))
+ # copy GT
+ img_gt = img.replace('.bmp', '-GT.png')
+ img = img.replace('bmp', 'png')
+ shutil.copy(
+ os.path.join(origin_data_path, 'GER7-GT', img_gt),
+ os.path.join(path, 'masks', 'train', img))
+
+imgs = os.listdir(os.path.join(path, 'images', 'train'))
+for img in imgs:
+    if not img.endswith('.png'):
+        # convert e.g. '.bmp' images to '.png' and drop the original file
+        im = Image.open(os.path.join(path, 'images', 'train', img))
+        im.save(os.path.join(path, 'images', 'train', img[:-4] + '.png'))
+        os.remove(os.path.join(path, 'images', 'train', img))
diff --git a/projects/medical/2d_image/fundus_photography/dr_hagis/README.md b/projects/medical/2d_image/fundus_photography/dr_hagis/README.md
new file mode 100644
index 00000000000..85d8a3e271c
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/dr_hagis/README.md
@@ -0,0 +1,155 @@
+# DR HAGIS: Diabetic Retinopathy, Hypertension, Age-related macular degeneration and Glaucoma ImageS
+
+## Description
+
+This project supports **`DR HAGIS: Diabetic Retinopathy, Hypertension, Age-related macular degeneration and Glaucoma ImageS`**, which can be downloaded from [here](https://paperswithcode.com/dataset/dr-hagis).
+
+### Dataset Overview
+
+The DR HAGIS database has been created to aid the development of vessel extraction algorithms suitable for retinal screening programmes. Researchers are encouraged to test their segmentation algorithms using this database. All thirty-nine fundus images were obtained from a diabetic retinopathy screening programme in the UK. Hence, all images were taken from diabetic patients.
+
+Besides the fundus images, the manual segmentation of the retinal surface vessels is provided by an expert grader. These manually segmented images can be used as the ground truth to compare and assess the automatic vessel extraction algorithms. Masks of the FOV are provided as well to quantify the accuracy of vessel extraction within the FOV only. The images were acquired in different screening centers, therefore reflecting the range of image resolutions, digital cameras and fundus cameras used in the clinic. The fundus images were captured using a Topcon TRC-NW6s, Topcon TRC-NW8 or a Canon CR DGi fundus camera with a horizontal 45 degree field-of-view (FOV). The images are 4752x3168 pixels, 3456x2304 pixels, 3126x2136 pixels, 2896x1944 pixels or 2816x1880 pixels in size.
+
+### Original Statistic Information
+
+| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| ------------------------------------------------------- | ----------------- | ------------ | ------------------ | ------------ | --------------------- | ---------------------- | ------------ | ------- |
+| [DR HAGIS](https://paperswithcode.com/dataset/dr-hagis) | eye | segmentation | fundus photography | 2 | 40/-/- | yes/-/- | 2017 | - |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 40 | 96.38 | - | - | - | - |
+| vessel | 40 | 3.62 | - | - | - | - |
+
+Note:
+
+- `Pct` means the percentage of pixels of this category among all pixels.
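
The `Pct` figures above can be reproduced by accumulating per-class pixel counts over the mask files. A minimal sketch (not part of this project; the helper name and the assumption that masks store raw class indices are ours):

```python
import numpy as np
from PIL import Image


def class_pixel_percentages(mask_paths, num_classes=2):
    """Sum per-class pixel counts over mask files and return each
    class's share of all pixels, in percent."""
    counts = np.zeros(num_classes, dtype=np.int64)
    for p in mask_paths:
        mask = np.asarray(Image.open(p))
        counts += np.bincount(mask.ravel(), minlength=num_classes)[:num_classes]
    return 100.0 * counts / counts.sum()
```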
+
+### Visualization
+
+![bac](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/fundus_photography/dr_hagis/dr_hagis_dataset.png)
+
+## Usage
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `dr_hagis/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset preparing
+
+- Download the dataset from [here](https://paperswithcode.com/dataset/dr-hagis) and decompress it to the `data/` directory.
+- Run the script `python tools/prepare_dataset.py` to format the data and arrange the folder structure as below.
+- Run the script `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. If the labels of the official validation and test sets cannot be obtained, we generate `train.txt` and `val.txt` randomly from the training set.
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── fundus_photography
+ │ │ │ │ ├── dr_hagis
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── val.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ │ │ │ │ ├── xxx.png
+ │ │ │ │ │ │ │ │ ├── ...
+ │ │ │ │ │ │ │ │ └── xxx.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ │ │ │ │ ├── xxx.png
+ │ │ │ │ │ │ │ │ ├── ...
+ │ │ │ │ │ │ │ │ └── xxx.png
+```
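
The script `../../tools/split_seg_dataset.py` is not shown in this excerpt; a minimal sketch of the random split it performs might look like the following (the function name, split ratio, and the `train/<stem>` line format are our assumptions):

```python
import os
import random


def split_dataset(data_root, val_ratio=0.2, seed=0):
    """Randomly split the image stems under images/train into train/val
    lists and write them (one stem per line) to train.txt and val.txt."""
    stems = sorted(
        os.path.splitext(f)[0]
        for f in os.listdir(os.path.join(data_root, 'images', 'train'))
        if f.endswith('.png'))
    random.Random(seed).shuffle(stems)
    n_val = int(len(stems) * val_ratio)
    splits = {'val.txt': stems[:n_val], 'train.txt': stems[n_val:]}
    for name, part in splits.items():
        with open(os.path.join(data_root, name), 'w') as f:
            f.writelines(f'train/{s}\n' for s in part)
```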
+
+### Divided Dataset Information
+
+***Note: The train/val split shown in the table below was made by ourselves.***
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 32 | 96.21 | 8 | 97.12 | - | - |
+| vessel | 32 | 3.79 | 8 | 2.88 | - | - |
+
+### Training commands
+
+Train models on a single server with one GPU.
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+Test models on a single server with one GPU.
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+## Dataset Citation
+
+If this work is helpful for your research, please consider citing the paper below.
+
+```bibtex
+@article{holm2017dr,
+ title={DR HAGIS—a fundus image database for the automatic extraction of retinal surface vessels from diabetic patients},
+ author={Holm, Sven and Russell, Greg and Nourrit, Vincent and McLoughlin, Niall},
+ journal={Journal of Medical Imaging},
+ volume={4},
+ number={1},
+ pages={014503--014503},
+ year={2017},
+ publisher={Society of Photo-Optical Instrumentation Engineers}
+}
+```
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+
+ - [x] Basic docstrings & proper citation
+
+ - [ ] Test-time correctness
+
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+
+ - [ ] Unit tests
+
+ - [ ] Code polishing
+
+ - [ ] Metafile.yml
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/fundus_photography/dr_hagis/configs/dr-hagis_512x512.py b/projects/medical/2d_image/fundus_photography/dr_hagis/configs/dr-hagis_512x512.py
new file mode 100644
index 00000000000..93b96384109
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/dr_hagis/configs/dr-hagis_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'DRHAGISDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/fundus_photography/dr_hagis/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_dr-hagis-512x512.py b/projects/medical/2d_image/fundus_photography/dr_hagis/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_dr-hagis-512x512.py
new file mode 100644
index 00000000000..9d14427c45f
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/dr_hagis/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_dr-hagis-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ './dr-hagis_512x512.py', 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.dr-hagis_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fundus_photography/dr_hagis/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_dr-hagis-512x512.py b/projects/medical/2d_image/fundus_photography/dr_hagis/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_dr-hagis-512x512.py
new file mode 100644
index 00000000000..507ec748bf5
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/dr_hagis/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_dr-hagis-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ './dr-hagis_512x512.py', 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.dr-hagis_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fundus_photography/dr_hagis/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_dr-hagis-512x512.py b/projects/medical/2d_image/fundus_photography/dr_hagis/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_dr-hagis-512x512.py
new file mode 100644
index 00000000000..092ae00a7d3
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/dr_hagis/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_dr-hagis-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ './dr-hagis_512x512.py', 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.dr-hagis_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fundus_photography/dr_hagis/datasets/dr-hagis_dataset.py b/projects/medical/2d_image/fundus_photography/dr_hagis/datasets/dr-hagis_dataset.py
new file mode 100644
index 00000000000..9659f0b8d77
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/dr_hagis/datasets/dr-hagis_dataset.py
@@ -0,0 +1,27 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class DRHAGISDataset(BaseSegDataset):
+ """DRHAGISDataset dataset.
+
+ In segmentation map annotation for DRHAGISDataset,
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ """
+ METAINFO = dict(classes=('background', 'vessel'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=False,
+ **kwargs)
diff --git a/projects/medical/2d_image/fundus_photography/dr_hagis/tools/prepare_dataset.py b/projects/medical/2d_image/fundus_photography/dr_hagis/tools/prepare_dataset.py
new file mode 100755
index 00000000000..51f4df7dac2
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/dr_hagis/tools/prepare_dataset.py
@@ -0,0 +1,41 @@
+import glob
+import os
+import shutil
+
+import mmengine
+import numpy as np
+from PIL import Image
+
+root_path = 'data/'
+img_suffix = '.jpg'
+seg_map_suffix = '_manual_orig.png'
+save_img_suffix = '.png'
+save_seg_map_suffix = '.png'
+
+x_train = glob.glob(os.path.join('data/DRHAGIS/**/*' + img_suffix))
+
+mmengine.mkdir_or_exist(root_path + 'images/train/')
+mmengine.mkdir_or_exist(root_path + 'masks/train/')
+
+
+part_dir_dict = {0: 'train/', 1: 'val/'}
+for ith, part in enumerate([x_train]):
+ part_dir = part_dir_dict[ith]
+ for img in part:
+ basename = os.path.basename(img)
+ shutil.copy(
+ img, root_path + 'images/' + part_dir + basename.split('.')[0] +
+ save_img_suffix)
+ mask_path = root_path + 'DRHAGIS/Manual_Segmentations/' + basename.split( # noqa
+ '.')[0] + seg_map_suffix
+
+ save_mask_path = root_path + 'masks/' + part_dir + basename.split(
+ '.')[0] + save_seg_map_suffix # noqa
+ mask = np.array(Image.open(mask_path)).astype(np.uint8)
+ mask[mask == 255] = 1
+ mask = Image.fromarray(mask)
+ mask.save(save_mask_path)
diff --git a/projects/medical/2d_image/fundus_photography/gamma3/README.md b/projects/medical/2d_image/fundus_photography/gamma3/README.md
new file mode 100644
index 00000000000..e834508fcb5
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/gamma3/README.md
@@ -0,0 +1,167 @@
+# Glaucoma grAding from Multi-Modality imAges Task3
+
+## Description
+
+This project supports **`Glaucoma grAding from Multi-Modality imAges Task3`**; the dataset used in this project can be downloaded from [here](https://aistudio.baidu.com/aistudio/competition/detail/121/0/datasets).
+
+### Dataset Overview
+
+This regular-challenge dataset was provided by Sun Yat-sen Ophthalmic Center, Sun Yat-sen University, Guangzhou, China. The dataset contains 200 fundus color images: 100 pairs in the training set and 100 pairs in the test set.
+
+### Original Statistic Information
+
+| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| ----------------------------------------------------------------------------------- | ----------------- | ------------ | --------------- | ------------ | --------------------- | ---------------------- | ------------ | --------------------------------------------------------------- |
+| [GammaTask3](https://aistudio.baidu.com/aistudio/competition/detail/121/0/datasets) | eye | segmentation | fundus photography | 3 | 100/-/100 | yes/-/- | 2021 | [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 100 | 99.02 | - | - | - | - |
+| optic disc | 100 | 0.67 | - | - | - | - |
+| optic cup | 100 | 0.31 | - | - | - | - |
+
+Note:
+
+- `Pct` means the percentage of pixels of this category among all pixels.
+
+### Visualization
+
+![bac](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/fundus_photography/gamma3/gamma3_dataset.png)
+
+## Dataset Citation
+
+```bibtex
+@article{fu2018joint,
+ title={Joint optic disc and cup segmentation based on multi-label deep network and polar transformation},
+ author={Fu, Huazhu and Cheng, Jun and Xu, Yanwu and Wong, Damon Wing Kee and Liu, Jiang and Cao, Xiaochun},
+ journal={IEEE transactions on medical imaging},
+ volume={37},
+ number={7},
+ pages={1597--1605},
+ year={2018},
+ publisher={IEEE}
+}
+
+@article{sevastopolsky2017optic,
+ title={Optic disc and cup segmentation methods for glaucoma detection with modification of U-Net convolutional neural network},
+ author={Sevastopolsky, Artem},
+ journal={Pattern Recognition and Image Analysis},
+ volume={27},
+ pages={618--624},
+ year={2017},
+ publisher={Springer}
+}
+```
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow(PIL) v9.3.0
+- scikit-learn(sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `gamma3/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset preparing
+
+- Download the dataset from [here](https://aistudio.baidu.com/aistudio/competition/detail/121/0/datasets) and decompress it to the `data/` directory.
+- Run the script `python tools/prepare_dataset.py` to format the data and arrange the folder structure as below.
+- Run the script `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. If the labels of the official validation and test sets cannot be obtained, we generate `train.txt` and `val.txt` randomly from the training set.
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── fundus_photography
+ │ │ │ │ ├── gamma3
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── val.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ │ │ │ │ ├── xxx.png
+ │ │ │ │ │ │ │ │ ├── ...
+ │ │ │ │ │ │ │ │ └── xxx.png
+ │ │ │ │ │ │ │ ├── test
+ │ │ │ │ │ │ │ │ ├── yyy.png
+ │ │ │ │ │ │ │ │ ├── ...
+ │ │ │ │ │ │ │ │ └── yyy.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ │ │ │ │ ├── xxx.png
+ │ │ │ │ │ │ │ │ ├── ...
+ │ │ │ │ │ │ │ │ └── xxx.png
+```
+
+### Divided Dataset Information
+
+***Note: The train/val split shown in the table below was made by ourselves.***
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 80 | 99.01 | 20 | 99.07 | - | - |
+| optic disc | 80 | 0.68 | 20 | 0.63 | - | - |
+| optic cup | 80 | 0.32 | 20 | 0.31 | - | - |
+
+### Training commands
+
+To train models on a single server with one GPU (default), run:
+
+```shell
+mim train mmseg ./configs/${CONFIG_PATH}
+```
+
+### Testing commands
+
+To test models on a single server with one GPU (default), run:
+
+```shell
+mim test mmseg ./configs/${CONFIG_PATH} --checkpoint ${CHECKPOINT_PATH}
+```
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+
+ - [x] Basic docstrings & proper citation
+
+ - [ ] Test-time correctness
+
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+
+ - [ ] Unit tests
+
+ - [ ] Code polishing
+
+ - [ ] Metafile.yml
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/fundus_photography/gamma3/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_gamma3-512x512.py b/projects/medical/2d_image/fundus_photography/gamma3/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_gamma3-512x512.py
new file mode 100644
index 00000000000..0daac51e10f
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/gamma3/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_gamma3-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './gamma3_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.gamma3_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=3),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fundus_photography/gamma3/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_gamma3-512x512.py b/projects/medical/2d_image/fundus_photography/gamma3/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_gamma3-512x512.py
new file mode 100644
index 00000000000..8a25cd0d266
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/gamma3/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_gamma3-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './gamma3_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.gamma3_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=3),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fundus_photography/gamma3/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_gamma3-512x512.py b/projects/medical/2d_image/fundus_photography/gamma3/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_gamma3-512x512.py
new file mode 100644
index 00000000000..ea648438672
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/gamma3/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_gamma3-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './gamma3_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.gamma3_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=3),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fundus_photography/gamma3/configs/gamma3_512x512.py b/projects/medical/2d_image/fundus_photography/gamma3/configs/gamma3_512x512.py
new file mode 100644
index 00000000000..d23ab55ca71
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/gamma3/configs/gamma3_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'Gamma3Dataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/fundus_photography/gamma3/datasets/gamma3_dataset.py b/projects/medical/2d_image/fundus_photography/gamma3/datasets/gamma3_dataset.py
new file mode 100644
index 00000000000..56cbdd63e61
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/gamma3/datasets/gamma3_dataset.py
@@ -0,0 +1,30 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class Gamma3Dataset(BaseSegDataset):
+ """Gamma3Dataset dataset.
+
+ In segmentation map annotation for Gamma3Dataset,
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default to False.
+ """
+ METAINFO = dict(classes=('background', 'disc', 'cup'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/fundus_photography/gamma3/tools/prepare_dataset.py b/projects/medical/2d_image/fundus_photography/gamma3/tools/prepare_dataset.py
new file mode 100644
index 00000000000..eb820b6b740
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/gamma3/tools/prepare_dataset.py
@@ -0,0 +1,107 @@
+import glob
+import os
+
+import numpy as np
+from PIL import Image
+
+root_path = 'data/'
+img_suffix = '.jpg'
+seg_map_suffix = '.png'
+save_img_suffix = '.png'
+save_seg_map_suffix = '.png'
+tgt_img_train_dir = os.path.join(root_path, 'images/train/')
+tgt_mask_train_dir = os.path.join(root_path, 'masks/train/')
+tgt_img_test_dir = os.path.join(root_path, 'images/test/')
+os.system('mkdir -p ' + tgt_img_train_dir)
+os.system('mkdir -p ' + tgt_mask_train_dir)
+os.system('mkdir -p ' + tgt_img_test_dir)
+
+
+def filter_suffix_recursive(src_dir, suffix):
+ # filter out file names and paths in source directory
+ suffix = '.' + suffix if '.' not in suffix else suffix
+ file_paths = glob.glob(
+ os.path.join(src_dir, '**/*' + suffix), recursive=True)
+ file_names = [_.split('/')[-1] for _ in file_paths]
+ return sorted(file_paths), sorted(file_names)
+
+
+def convert_label(img, convert_dict):
+ arr = np.zeros_like(img, dtype=np.uint8)
+ for c, i in convert_dict.items():
+ arr[img == c] = i
+ return arr
+
+
+def convert_pics_into_pngs(src_dir, tgt_dir, suffix, convert='RGB'):
+ if not os.path.exists(tgt_dir):
+ os.makedirs(tgt_dir)
+ src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix)
+
+ for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)):
+ tgt_name = src_name.replace(suffix, save_img_suffix)
+ tgt_path = os.path.join(tgt_dir, tgt_name)
+ num = len(src_paths)
+ img = np.array(Image.open(src_path))
+ if len(img.shape) == 2:
+ pil = Image.fromarray(img).convert(convert)
+ elif len(img.shape) == 3:
+ pil = Image.fromarray(img)
+ else:
+ raise ValueError('Input image not 2D/3D: ', img.shape)
+
+ pil.save(tgt_path)
+ print(f'processed {i+1}/{num}.')
+
+
+def convert_label_pics_into_pngs(src_dir,
+ tgt_dir,
+ suffix,
+ convert_dict={
+ 0: 2,
+ 128: 1,
+ 255: 0
+ }):
+ if not os.path.exists(tgt_dir):
+ os.makedirs(tgt_dir)
+
+ src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix)
+ num = len(src_paths)
+ for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)):
+ tgt_name = src_name.replace(suffix, save_seg_map_suffix)
+ tgt_path = os.path.join(tgt_dir, tgt_name)
+
+ img = np.array(Image.open(src_path))
+ img = convert_label(img, convert_dict)
+ Image.fromarray(img).save(tgt_path)
+ print(f'processed {i+1}/{num}.')
+
+
+if __name__ == '__main__':
+
+ convert_pics_into_pngs(
+ os.path.join(
+ root_path,
+ 'task3_disc_cup_segmentation/training/fundus color images/'),
+ tgt_img_train_dir,
+ suffix=img_suffix)
+
+ convert_pics_into_pngs(
+ os.path.join(
+ root_path,
+ 'task3_disc_cup_segmentation/testing/fundus color images/'),
+ tgt_img_test_dir,
+ suffix=img_suffix)
+
+ convert_label_pics_into_pngs(
+ os.path.join(root_path,
+ 'task3_disc_cup_segmentation/training/Disc_Cup_Mask/'),
+ tgt_mask_train_dir,
+ suffix=seg_map_suffix,
+ convert_dict={
+ 0: 2,
+ 128: 1,
+ 255: 0
+ })
+ # original: [0, 128, 255] for ['optic cup', 'optic disc', 'background']
+ # converted: [0, 1, 2] for ['background', 'optic disc', 'optic cup']
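The remapping described in the comment above can be sketched standalone. This is illustrative only: `convert_label` mirrors the helper defined earlier in this script, and the sample array is made up.

```python
import numpy as np

def convert_label(img, convert_dict):
    # map raw grayscale values to contiguous class ids
    arr = np.zeros_like(img, dtype=np.uint8)
    for c, i in convert_dict.items():
        arr[img == c] = i
    return arr

# original values [0, 128, 255] -> class ids [2, 1, 0]
raw = np.array([[255, 128], [0, 255]], dtype=np.uint8)
out = convert_label(raw, {0: 2, 128: 1, 255: 0})
print(out.tolist())  # [[0, 1], [2, 0]]
```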
diff --git a/projects/medical/2d_image/fundus_photography/orvs/README.md b/projects/medical/2d_image/fundus_photography/orvs/README.md
new file mode 100644
index 00000000000..6f09203ac4d
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/orvs/README.md
@@ -0,0 +1,140 @@
+# ORVS (Online Retinal image for Vessel Segmentation)
+
+## Description
+
+This project supports **`ORVS (Online Retinal image for Vessel Segmentation)`**, which can be downloaded from [here](https://opendatalab.org.cn/ORVS).
+
+### Dataset Overview
+
+The ORVS dataset was newly established through a collaboration between the Department of Computer Science and the Department of Vision Science at the University of Calgary. The dataset contains 49 images collected from a clinic in Calgary, Canada, consisting of 42 training images and 7 testing images. All images were obtained using a Zeiss Visucam 200 with a 30-degree field of view (FOV). The image size is 1444×1444 pixels with 24 bits per pixel. The images are stored in JPEG format with low compression, which is common in ophthalmic practice. All images were manually traced by an expert who has been working in the field of retinal image analysis and has been trained to mark all pixels belonging to retinal vessels. The Windows Paint 3D tool was used for manual image annotation.
+
+### Original Statistic Information
+
+| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| ------------------------------------------------------ | ----------------- | ------------ | ------------------ | ------------ | --------------------- | ---------------------- | ------------ | ------- |
+| [ORVS](https://opendatalab.org.cn/ORVS) | head_and_neck | segmentation | fundus photography | 2 | 130/-/72 | yes/-/yes | 2020 | - |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 130 | 94.83 | - | - | 72 | 94.25 |
+| vessel | 130 | 5.17 | - | - | 72 | 5.75 |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
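For reference, the per-class pixel percentages reported above can be computed with a short sketch like this (illustrative only; the official statistics are computed over the full dataset):

```python
import numpy as np

def class_pixel_pct(mask):
    """Percentage of pixels per label value in a segmentation mask."""
    values, counts = np.unique(mask, return_counts=True)
    return {int(v): 100.0 * int(c) / mask.size for v, c in zip(values, counts)}

mask = np.array([[0, 0, 1], [0, 1, 1]], dtype=np.uint8)
print(class_pixel_pct(mask))  # {0: 50.0, 1: 50.0}
```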
+
+### Visualization
+
+![orvs](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/fundus_photography/orvs/ORVS_dataset.png)
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on `PYTHONPATH` being configured correctly: it should point to the project directory so that Python can locate the module files. In the `orvs/` root directory, run the following command to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset preparing
+
+- Clone this [repository](https://github.com/AbdullahSarhan/ICPRVessels), then move `Vessels-Datasets` to `data/`.
+- Run the script `python tools/prepare_dataset.py` to format the data and arrange the folder structure as below.
+- Run the script `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. If the labels of the official validation and test sets are not available, `train.txt` and `val.txt` are generated randomly from the training set.
+
+```none
+  mmsegmentation
+  ├── mmseg
+  ├── projects
+  │   ├── medical
+  │   │   ├── 2d_image
+  │   │   │   ├── fundus_photography
+  │   │   │   │   ├── orvs
+  │   │   │   │   │   ├── configs
+  │   │   │   │   │   ├── datasets
+  │   │   │   │   │   ├── tools
+  │   │   │   │   │   ├── data
+  │   │   │   │   │   │   ├── train.txt
+  │   │   │   │   │   │   ├── test.txt
+  │   │   │   │   │   │   ├── images
+  │   │   │   │   │   │   │   ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+  │   │   │   │   │   │   ├── masks
+  │   │   │   │   │   │   │   ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+```
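After preparation, each mask should contain only the labels 0 (background) and 1 (vessel). A minimal sanity check (`is_binary_mask` is a hypothetical helper shown for illustration, not part of the project tooling):

```python
import numpy as np

def is_binary_mask(mask):
    # prepared vessel masks should only use labels 0 (background) and 1 (vessel)
    return set(np.unique(mask).tolist()) <= {0, 1}

print(is_binary_mask(np.array([[0, 1], [1, 0]], dtype=np.uint8)))  # True
print(is_binary_mask(np.array([[0, 255]], dtype=np.uint8)))        # False
```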
+
+### Training commands
+
+Train models on a single server with one GPU.
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+Test models on a single server with one GPU.
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+
+
+## Dataset Citation
+
+If this work is helpful for your research, please consider citing the paper below.
+
+```
+@inproceedings{sarhan2021transfer,
+ title={Transfer learning through weighted loss function and group normalization for vessel segmentation from retinal images},
+ author={Sarhan, Abdullah and Rokne, Jon and Alhajj, Reda and Crichton, Andrew},
+ booktitle={2020 25th International Conference on Pattern Recognition (ICPR)},
+ pages={9211--9218},
+ year={2021},
+ organization={IEEE}
+}
+```
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+
+ - [x] Basic docstrings & proper citation
+
+ - [ ] Test-time correctness
+
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+
+ - [ ] Unit tests
+
+ - [ ] Code polishing
+
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/fundus_photography/orvs/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_orvs-512x512.py b/projects/medical/2d_image/fundus_photography/orvs/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_orvs-512x512.py
new file mode 100644
index 00000000000..662f837158a
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/orvs/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_orvs-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ './orvs_512x512.py', 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.orvs_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fundus_photography/orvs/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_orvs-512x512.py b/projects/medical/2d_image/fundus_photography/orvs/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_orvs-512x512.py
new file mode 100644
index 00000000000..c47cdb6b24d
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/orvs/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_orvs-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ './orvs_512x512.py', 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.orvs_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fundus_photography/orvs/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_orvs-512x512.py b/projects/medical/2d_image/fundus_photography/orvs/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_orvs-512x512.py
new file mode 100644
index 00000000000..1097aade286
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/orvs/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_orvs-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ './orvs_512x512.py', 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.orvs_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fundus_photography/orvs/configs/orvs_512x512.py b/projects/medical/2d_image/fundus_photography/orvs/configs/orvs_512x512.py
new file mode 100644
index 00000000000..a5594dec388
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/orvs/configs/orvs_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'ORVSDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='test.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/fundus_photography/orvs/datasets/orvs_dataset.py b/projects/medical/2d_image/fundus_photography/orvs/datasets/orvs_dataset.py
new file mode 100644
index 00000000000..e915ae4cd2b
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/orvs/datasets/orvs_dataset.py
@@ -0,0 +1,27 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class ORVSDataset(BaseSegDataset):
+ """ORVSDataset dataset.
+
+ In segmentation map annotation for ORVSDataset,
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ """
+ METAINFO = dict(classes=('background', 'vessel'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=False,
+ **kwargs)
diff --git a/projects/medical/2d_image/fundus_photography/orvs/tools/prepare_dataset.py b/projects/medical/2d_image/fundus_photography/orvs/tools/prepare_dataset.py
new file mode 100755
index 00000000000..f902d871010
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/orvs/tools/prepare_dataset.py
@@ -0,0 +1,55 @@
+import glob
+import os
+
+import numpy as np
+from PIL import Image
+
+root_path = 'data/'
+img_suffix = '.jpg'
+seg_map_suffix_list = ['.jpg', '.png', '.tif']
+save_img_suffix = '.png'
+save_seg_map_suffix = '.png'
+
+x_train = glob.glob(
+ os.path.join('data/Vessels-Datasets/*/Train/Original/Images/*' +
+ img_suffix))
+x_test = glob.glob(
+ os.path.join('data/Vessels-Datasets/*/Test/Original/Images/*' +
+ img_suffix))
+
+os.system('mkdir -p ' + root_path + 'images/train/')
+os.system('mkdir -p ' + root_path + 'images/test/')
+os.system('mkdir -p ' + root_path + 'masks/train/')
+os.system('mkdir -p ' + root_path + 'masks/test/')
+
+part_dir_dict = {0: 'train/', 1: 'test/'}
+for ith, part in enumerate([x_train, x_test]):
+ part_dir = part_dir_dict[ith]
+ for img in part:
+ type_name = img.split('/')[-5]
+ basename = type_name + '_' + os.path.basename(img)
+ save_img_path = root_path + 'images/' + part_dir + basename.split(
+ '.')[0] + save_img_suffix
+ Image.open(img).save(save_img_path)
+
+ for seg_map_suffix in seg_map_suffix_list:
+ if os.path.exists('/'.join(img.split('/')[:-1]).replace(
+ 'Images', 'Labels')):
+ mask_path = img.replace('Images', 'Labels').replace(
+ img_suffix, seg_map_suffix)
+ else:
+ mask_path = img.replace('Images', 'labels').replace(
+ img_suffix, seg_map_suffix)
+ if os.path.exists(mask_path):
+ break
+ save_mask_path = root_path + 'masks/' + part_dir + basename.split(
+ '.')[0] + save_seg_map_suffix
+        masks = np.array(Image.open(mask_path).convert('L')).astype(np.uint8)
+        # Keep masks that are already binary with labels {0, 1};
+        # otherwise binarize grayscale values at a threshold of 128.
+        unique_labels = np.unique(masks)
+        if not (len(unique_labels) == 2 and 1 in unique_labels):
+            masks[masks < 128] = 0
+            masks[masks >= 128] = 1
+        Image.fromarray(masks).save(save_mask_path)
diff --git a/projects/medical/2d_image/fundus_photography/rite/README.md b/projects/medical/2d_image/fundus_photography/rite/README.md
new file mode 100644
index 00000000000..0aea9b00d17
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/rite/README.md
@@ -0,0 +1,135 @@
+# Retinal Images vessel Tree Extraction (RITE)
+
+## Description
+
+This project supports **`Retinal Images vessel Tree Extraction (RITE)`**, which can be downloaded from [here](https://opendatalab.com/RITE).
+
+### Dataset Overview
+
+RITE (Retinal Images vessel Tree Extraction) is a database for comparative studies on the segmentation or classification of arteries and veins in retinal fundus images. It is built on the publicly available DRIVE database (Digital Retinal Images for Vessel Extraction). RITE contains 40 sets of images, equally separated into a training subset and a test subset, the same as DRIVE. Each set consists of a fundus photograph and a vessel reference standard; the fundus photograph is inherited from DRIVE. For the training set, the vessel reference standard is a modified version of DRIVE's 1st_manual; for the test set, it is DRIVE's 2nd_manual.
+
+### Statistic Information
+
+| Dataset Name | Anatomical Region | Task Type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| ------------------------------------ | ----------------- | ------------ | ------------------ | ------------ | --------------------- | ---------------------- | ------------ | --------------------------------------------------------------- |
+| [RITE](https://opendatalab.com/RITE) | head_and_neck | segmentation | fundus_photography | 2 | 20/-/20 | yes/-/yes | 2013 | [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 20 | 91.61 | - | - | 20 | 91.58 |
+| vessel | 20 | 8.39 | - | - | 20 | 8.42 |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
+
+### Visualization
+
+![rite](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/fundus_photography/rite/rite_dataset.png?raw=true)
+
+### Dataset Citation
+
+```
+@InProceedings{10.1007/978-3-642-40763-5_54,
+ author={Hu, Qiao and Abr{\`a}moff, Michael D. and Garvin, Mona K.},
+ title={Automated Separation of Binary Overlapping Trees in Low-Contrast Color Retinal Images},
+ booktitle={Medical Image Computing and Computer-Assisted Intervention -- MICCAI 2013},
+ year={2013},
+ pages={436--443},
+}
+```
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow (PIL) v9.3.0
+- scikit-learn (sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on `PYTHONPATH` being configured correctly: it should point to the project directory so that Python can locate the module files. In the `rite/` root directory, run the following command to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset Preparing
+
+- Download the dataset from [here](https://opendatalab.com/RITE) and decompress it to the path `'data/'`.
+- Run the script `python tools/prepare_dataset.py` to format the data and arrange the folder structure as below.
+- Run the script `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. If the labels of the official validation and test sets are not available, `train.txt` and `val.txt` are generated randomly from the training set.
+
+```none
+  mmsegmentation
+  ├── mmseg
+  ├── projects
+  │   ├── medical
+  │   │   ├── 2d_image
+  │   │   │   ├── fundus_photography
+  │   │   │   │   ├── rite
+  │   │   │   │   │   ├── configs
+  │   │   │   │   │   ├── datasets
+  │   │   │   │   │   ├── tools
+  │   │   │   │   │   ├── data
+  │   │   │   │   │   │   ├── train.txt
+  │   │   │   │   │   │   ├── val.txt
+  │   │   │   │   │   │   ├── images
+  │   │   │   │   │   │   │   ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+  │   │   │   │   │   │   ├── masks
+  │   │   │   │   │   │   │   ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+```
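The split script is expected to produce disjoint train/val file lists. A sketch of the idea (`make_split` is a hypothetical helper for illustration, not the actual `split_seg_dataset.py` logic):

```python
import random

def make_split(names, train_ratio=0.8, seed=0):
    # shuffle deterministically, then cut into train/val lists
    names = sorted(names)
    random.Random(seed).shuffle(names)
    k = int(len(names) * train_ratio)
    return names[:k], names[k:]

train, val = make_split([f'img_{i:02d}' for i in range(10)])
print(len(train), len(val))  # 8 2
```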
+
+### Training commands
+
+Train models on a single server with one GPU (default).
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+Test models on a single server with one GPU (default).
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+ - [x] Basic docstrings & proper citation
+ - [ ] Test-time correctness
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+ - [ ] Unit tests
+ - [ ] Code polishing
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/fundus_photography/rite/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_rite-512x512.py b/projects/medical/2d_image/fundus_photography/rite/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_rite-512x512.py
new file mode 100644
index 00000000000..27dd4363b16
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/rite/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_rite-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './rite_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.rite_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fundus_photography/rite/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_rite-512x512.py b/projects/medical/2d_image/fundus_photography/rite/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_rite-512x512.py
new file mode 100644
index 00000000000..48f6f973a1a
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/rite/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_rite-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './rite_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.rite_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fundus_photography/rite/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_rite-512x512.py b/projects/medical/2d_image/fundus_photography/rite/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_rite-512x512.py
new file mode 100644
index 00000000000..5f5b24ba6a4
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/rite/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_rite-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './rite_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.rite_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fundus_photography/rite/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_rite-512x512.py b/projects/medical/2d_image/fundus_photography/rite/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_rite-512x512.py
new file mode 100644
index 00000000000..bf66b6f320c
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/rite/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_rite-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './rite_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.rite_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(
+ num_classes=2, loss_decode=dict(use_sigmoid=True), out_channels=1),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/fundus_photography/rite/configs/rite_512x512.py b/projects/medical/2d_image/fundus_photography/rite/configs/rite_512x512.py
new file mode 100644
index 00000000000..02f620c665f
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/rite/configs/rite_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'RITEDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='test.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/fundus_photography/rite/datasets/rite_dataset.py b/projects/medical/2d_image/fundus_photography/rite/datasets/rite_dataset.py
new file mode 100644
index 00000000000..99f688de949
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/rite/datasets/rite_dataset.py
@@ -0,0 +1,31 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class RITEDataset(BaseSegDataset):
+ """RITEDataset dataset.
+
+ In segmentation map annotation for RITEDataset,
+ 0 stands for background, which is included in 2 categories.
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default to False.
+ """
+ METAINFO = dict(classes=('background', 'vessel'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/fundus_photography/rite/tools/prepare_dataset.py b/projects/medical/2d_image/fundus_photography/rite/tools/prepare_dataset.py
new file mode 100644
index 00000000000..ca7e996961f
--- /dev/null
+++ b/projects/medical/2d_image/fundus_photography/rite/tools/prepare_dataset.py
@@ -0,0 +1,98 @@
+import glob
+import os
+
+import numpy as np
+from PIL import Image
+
+root_path = 'data/'
+img_suffix = '.tif'
+seg_map_suffix = '.png'
+save_img_suffix = '.png'
+save_seg_map_suffix = '.png'
+src_img_train_dir = os.path.join(root_path, 'AV_groundTruth/training/images/')
+src_img_test_dir = os.path.join(root_path, 'AV_groundTruth/test/images/')
+src_mask_train_dir = os.path.join(root_path, 'AV_groundTruth/training/vessel/')
+src_mask_test_dir = os.path.join(root_path, 'AV_groundTruth/test/vessel/')
+
+tgt_img_train_dir = os.path.join(root_path, 'images/train/')
+tgt_mask_train_dir = os.path.join(root_path, 'masks/train/')
+tgt_img_test_dir = os.path.join(root_path, 'images/test/')
+tgt_mask_test_dir = os.path.join(root_path, 'masks/test/')
+os.makedirs(tgt_img_train_dir, exist_ok=True)
+os.makedirs(tgt_mask_train_dir, exist_ok=True)
+os.makedirs(tgt_img_test_dir, exist_ok=True)
+os.makedirs(tgt_mask_test_dir, exist_ok=True)
+
+
+def filter_suffix_recursive(src_dir, suffix):
+ # filter out file names and paths in source directory
+ suffix = '.' + suffix if '.' not in suffix else suffix
+ file_paths = glob.glob(
+ os.path.join(src_dir, '**', '*' + suffix), recursive=True)
+    file_names = [os.path.basename(p) for p in file_paths]
+ return sorted(file_paths), sorted(file_names)
+
+
+def convert_label(img, convert_dict):
+ arr = np.zeros_like(img, dtype=np.uint8)
+ for c, i in convert_dict.items():
+ arr[img == c] = i
+ return arr
+
+
+def convert_pics_into_pngs(src_dir, tgt_dir, suffix, convert='RGB'):
+ if not os.path.exists(tgt_dir):
+ os.makedirs(tgt_dir)
+
+    src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix)
+    num = len(src_paths)
+    for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)):
+        tgt_name = src_name.replace(suffix, save_img_suffix)
+        tgt_path = os.path.join(tgt_dir, tgt_name)
+ img = np.array(Image.open(src_path))
+ if len(img.shape) == 2:
+ pil = Image.fromarray(img).convert(convert)
+ elif len(img.shape) == 3:
+ pil = Image.fromarray(img)
+ else:
+            raise ValueError(f'Input image is not 2D/3D: {img.shape}')
+
+ pil.save(tgt_path)
+ print(f'processed {i+1}/{num}.')
+
+
+def convert_label_pics_into_pngs(src_dir,
+ tgt_dir,
+ suffix,
+ convert_dict={
+ 0: 0,
+ 255: 1
+ }):
+ if not os.path.exists(tgt_dir):
+ os.makedirs(tgt_dir)
+
+ src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix)
+ num = len(src_paths)
+ for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)):
+ tgt_name = src_name.replace(suffix, save_seg_map_suffix)
+ tgt_path = os.path.join(tgt_dir, tgt_name)
+
+ img = np.array(Image.open(src_path))
+ img = convert_label(img, convert_dict)
+ Image.fromarray(img).save(tgt_path)
+ print(f'processed {i+1}/{num}.')
+
+
+if __name__ == '__main__':
+
+ convert_pics_into_pngs(
+ src_img_train_dir, tgt_img_train_dir, suffix=img_suffix)
+
+ convert_pics_into_pngs(
+ src_img_test_dir, tgt_img_test_dir, suffix=img_suffix)
+
+ convert_label_pics_into_pngs(
+ src_mask_train_dir, tgt_mask_train_dir, suffix=seg_map_suffix)
+
+ convert_label_pics_into_pngs(
+ src_mask_test_dir, tgt_mask_test_dir, suffix=seg_map_suffix)
diff --git a/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/README.md b/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/README.md
new file mode 100644
index 00000000000..97c4a0f0e55
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/README.md
@@ -0,0 +1,123 @@
+# breastCancerCellSegmentation
+
+## Description
+
+This project supports **`breastCancerCellSegmentation`**, which can be downloaded from [here](https://www.heywhale.com/mw/dataset/5e9e9b35ebb37f002c625423).
+
+### Dataset Overview
+
+This dataset contains 58 H&E-stained histopathology images used for breast cancer cell detection, together with the associated ground-truth data.
+Conventional histology uses a combination of hematoxylin and eosin stains, commonly referred to as H&E. These images are stained because most cells are inherently transparent with little or no intrinsic pigment.
+Certain special stains selectively bind to specific components and can be used to identify biological structures such as cells.
+
+### Original Statistic Information
+
+| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| -------------------------------------------------------------------------------------------- | ----------------- | ------------ | -------------- | ------------ | --------------------- | ---------------------- | ------------ | --------------------------------------------------------------- |
+| [breastCancerCellSegmentation](https://www.heywhale.com/mw/dataset/5e9e9b35ebb37f002c625423) | cell | segmentation | histopathology | 2 | 58/-/- | yes/-/- | 2020 | [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-sa/4.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 58 | 98.37 | - | - | - | - |
+| breastCancerCell | 58 | 1.63 | - | - | - | - |
+
+Note:
+
+- `Pct` means the percentage of pixels of this category among all pixels.
+
+### Visualization
+
+![bac](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/histopathology/breastCancerCellSegmentation/breastCancerCellSegmentation_dataset.png)
+
+## Usage
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow (PIL) v9.3.0
+- scikit-learn (sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In the `breastCancerCellSegmentation/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
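+
+As a quick check (assuming the prerequisites above are installed), you can verify that the project's dataset module is importable:
+
+```shell
+python -c "import datasets.breastCancerCellSegmentation_dataset"
+```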
+
+### Dataset Preparing
+
+- Download the dataset from [here](https://www.heywhale.com/mw/dataset/5e9e9b35ebb37f002c625423) and save it to the `data/` directory.
+- Decompress the data into `data/`. This creates a new folder, `data/breastCancerCellSegmentation/`, which contains the original image data.
+- Run `python tools/prepare_dataset.py` to format the data and arrange the folder structure as below.
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── histopathology
+ │ │ │ │ ├── breastCancerCellSegmentation
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── breastCancerCellSegmentation
+ │   │   │   │   │   │   │   ├── train.txt
+ │   │   │   │   │   │   │   ├── val.txt
+ │   │   │   │   │   │   │   ├── images
+ │   │   │   │   │   │   │   │   ├── xxx.tif
+ │   │   │   │   │   │   │   ├── masks
+ │   │   │   │   │   │   │   │   ├── xxx.TIF
+
+```
+
+### Training commands
+
+Train models on a single server with one GPU.
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
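+
+For example, to train the U-Net baseline with `lr=0.0001` using one of the configs shipped in this project:
+
+```shell
+mim train mmseg ./configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_breastCancerCellSegmentation-512x512.py
+```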
+
+### Testing commands
+
+Test models on a single server with one GPU.
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+
+ - [x] Basic docstrings & proper citation
+
+ - [x] Test-time correctness
+
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+
+ - [ ] Unit tests
+
+ - [ ] Code polishing
+
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/configs/breastCancerCellSegmentation_512x512.py b/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/configs/breastCancerCellSegmentation_512x512.py
new file mode 100644
index 00000000000..1cf0fccf5be
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/configs/breastCancerCellSegmentation_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'breastCancerCellSegmentationDataset'
+data_root = 'data/breastCancerCellSegmentation'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile', imdecode_backend='tifffile'),
+ dict(type='LoadAnnotations', imdecode_backend='tifffile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile', imdecode_backend='tifffile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations', imdecode_backend='tifffile'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images', seg_map_path='masks'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images', seg_map_path='masks'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_breastCancerCellSegmentation-512x512.py b/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_breastCancerCellSegmentation-512x512.py
new file mode 100644
index 00000000000..55d17089686
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_breastCancerCellSegmentation-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ './breastCancerCellSegmentation_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.breastCancerCellSegmentation_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_breastCancerCellSegmentation-512x512.py b/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_breastCancerCellSegmentation-512x512.py
new file mode 100644
index 00000000000..cf28aad739f
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_breastCancerCellSegmentation-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ './breastCancerCellSegmentation_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.breastCancerCellSegmentation_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_breastCancerCellSegmentation-512x512.py b/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_breastCancerCellSegmentation-512x512.py
new file mode 100644
index 00000000000..29aaff38941
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_breastCancerCellSegmentation-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ './breastCancerCellSegmentation_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.breastCancerCellSegmentation_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/datasets/breastCancerCellSegmentation_dataset.py b/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/datasets/breastCancerCellSegmentation_dataset.py
new file mode 100644
index 00000000000..eeceb6318c0
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/datasets/breastCancerCellSegmentation_dataset.py
@@ -0,0 +1,30 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class breastCancerCellSegmentationDataset(BaseSegDataset):
+ """breastCancerCellSegmentationDataset dataset.
+
+ In segmentation map annotation for breastCancerCellSegmentationDataset,
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+    is fixed to '_ccd.tif' and ``seg_map_suffix`` is fixed to '.TIF'.
+
+ Args:
+        img_suffix (str): Suffix of images. Default: '_ccd.tif'
+        seg_map_suffix (str): Suffix of segmentation maps. Default: '.TIF'
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default to False.
+ """
+ METAINFO = dict(classes=('background', 'breastCancerCell'))
+
+ def __init__(self,
+ img_suffix='_ccd.tif',
+ seg_map_suffix='.TIF',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/tools/prepare_dataset.py b/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/tools/prepare_dataset.py
new file mode 100644
index 00000000000..09cc689c862
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/breastCancerCellSegmentation/tools/prepare_dataset.py
@@ -0,0 +1,36 @@
+import argparse
+import glob
+import os
+
+from sklearn.model_selection import train_test_split
+
+
+def save_anno(img_list, file_path, suffix):
+    # keep only the file names, without the suffix
+    img_list = [os.path.basename(x)[:-len(suffix)] for x in img_list]
+
+ with open(file_path, 'w') as file_:
+        for x in img_list:
+ file_.write(x + '\n')
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--data_root', default='data/breastCancerCellSegmentation/')
+ args = parser.parse_args()
+ data_root = args.data_root
+
+    # 1. Split the dataset into training and validation sets
+    # 1.1 Collect all image paths
+ img_list = glob.glob(os.path.join(data_root, 'images', '*.tif'))
+ img_list.sort()
+ mask_list = glob.glob(os.path.join(data_root, 'masks', '*.TIF'))
+ mask_list.sort()
+ assert len(img_list) == len(mask_list)
+    # 1.2 Split into training and validation sets
+ train_img_list, val_img_list, train_mask_list, val_mask_list = train_test_split( # noqa
+ img_list, mask_list, test_size=0.2, random_state=42)
+    # 1.3 Save the split results
+ save_anno(train_img_list, os.path.join(data_root, 'train.txt'), '_ccd.tif')
+ save_anno(val_img_list, os.path.join(data_root, 'val.txt'), '_ccd.tif')
diff --git a/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/README.md b/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/README.md
new file mode 100644
index 00000000000..b6f1ca63419
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/README.md
@@ -0,0 +1,147 @@
+# Breast Cancer Cell Segmentation
+
+## Description
+
+This project supports **`Breast Cancer Cell Segmentation`**, and the dataset used in this project can be downloaded from [here](https://tianchi.aliyun.com/dataset/dataDetail?dataId=90152).
+
+### Dataset Overview
+
+In this dataset, there are 58 H&E stained histopathology images used in breast cancer cell detection with associated ground truth data available. Routine histology uses the stain combination of hematoxylin and eosin, commonly referred to as H&E. These images are stained since most cells are essentially transparent, with little or no intrinsic pigment. Certain special stains, which bind selectively to particular components, are used to identify biological structures such as cells. In these images, the challenging problem is cell segmentation for subsequent classification into benign and malignant cells.
+
+### Original Statistic Information
+
+| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| --------------------------------------------------------------------------------------------- | ----------------- | ------------ | -------------- | ------------ | --------------------- | ---------------------- | ------------ | ------------------------------------------------------------------------------------------------------ |
+| [Breast Cancer Cell Segmentation](https://tianchi.aliyun.com/dataset/dataDetail?dataId=90152) | thorax | segmentation | histopathology | 2 | 58/-/- | yes/-/- | 2021 | [CC-BY-SA-NC 4.0](http://creativecommons.org/licenses/by-sa/4.0/?spm=5176.12282016.0.0.3f5b5291ypBxb2) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :----------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| normal | 58 | 98.37 | - | - | - | - |
+| breast cancer cell | 58 | 1.63 | - | - | - | - |
+
+Note:
+
+- `Pct` means the percentage of pixels of this category among all pixels.
+
+### Visualization
+
+![bac](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/histopathology/breast_cancer_cell_seg/breast_cancer_cell_seg_dataset.png)
+
+## Dataset Citation
+
+```
+@inproceedings{gelasca2008evaluation,
+ title={Evaluation and benchmark for biological image segmentation},
+ author={Gelasca, Elisa Drelie and Byun, Jiyun and Obara, Boguslaw and Manjunath, BS},
+ booktitle={2008 15th IEEE international conference on image processing},
+ pages={1816--1819},
+ year={2008},
+ organization={IEEE}
+}
+```
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In the `breast_cancer_cell_seg/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset Preparing
+
+- Download the dataset from [here](https://tianchi.aliyun.com/dataset/dataDetail?dataId=90152) and decompress it to `data/`.
+- Run `python tools/prepare_dataset.py` to format the data and arrange the folder structure as below.
+- Run `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. If the labels of the official validation and test sets are not available, we randomly generate `train.txt` and `val.txt` from the training set.
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── histopathology
+ │ │ │ │ ├── breast_cancer_cell_seg
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── val.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+ │   │   │   │   │   │   │   │   ├── xxx.png
+ │   │   │   │   │   │   │   │   ├── ...
+ │   │   │   │   │   │   │   │   └── xxx.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+ │   │   │   │   │   │   │   │   ├── xxx.png
+ │   │   │   │   │   │   │   │   ├── ...
+ │   │   │   │   │   │   │   │   └── xxx.png
+```
+
+### Divided Dataset Information
+
+***Note: The train/val split shown in the table below was made by ourselves.***
+
+|     Class Name     | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :----------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+|       normal       |     46     |   98.36    |    12    |  98.41   |     -     |     -     |
+| breast cancer cell |     46     |    1.64    |    12    |   1.59   |     -     |     -     |
+
+### Training commands
+
+Train models on a single server with one GPU.
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
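+
+For example, to train the U-Net baseline with `lr=0.0001` using one of the configs shipped in this project:
+
+```shell
+mim train mmseg ./configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_breast-cancer-cell-seg-512x512.py
+```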
+
+### Testing commands
+
+Test models on a single server with one GPU.
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+
+ - [x] Basic docstrings & proper citation
+
+ - [x] Test-time correctness
+
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+
+ - [ ] Unit tests
+
+ - [ ] Code polishing
+
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/configs/breast-cancer-cell-seg_512x512.py b/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/configs/breast-cancer-cell-seg_512x512.py
new file mode 100644
index 00000000000..ead40e4345c
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/configs/breast-cancer-cell-seg_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'BreastCancerCellSegDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_breast-cancer-cell-seg-512x512.py b/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_breast-cancer-cell-seg-512x512.py
new file mode 100644
index 00000000000..691a0ff613d
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_breast-cancer-cell-seg-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ './breast-cancer-cell-seg_512x512.py',
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.breast-cancer-cell-seg_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_breast-cancer-cell-seg-512x512.py b/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_breast-cancer-cell-seg-512x512.py
new file mode 100644
index 00000000000..719b767ab1f
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_breast-cancer-cell-seg-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ './breast-cancer-cell-seg_512x512.py',
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.breast-cancer-cell-seg_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_breast-cancer-cell-seg-512x512.py b/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_breast-cancer-cell-seg-512x512.py
new file mode 100644
index 00000000000..9dfe70f761f
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_breast-cancer-cell-seg-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ './breast-cancer-cell-seg_512x512.py',
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.breast-cancer-cell-seg_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/datasets/breast-cancer-cell-seg_dataset.py b/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/datasets/breast-cancer-cell-seg_dataset.py
new file mode 100644
index 00000000000..6f27029d396
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/datasets/breast-cancer-cell-seg_dataset.py
@@ -0,0 +1,29 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class BreastCancerCellSegDataset(BaseSegDataset):
+ """BreastCancerCellSegDataset dataset.
+
+ In segmentation map annotation for BreastCancerCellSegDataset,
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default to False.
+ """
+ METAINFO = dict(classes=('normal', 'breast cancer cell'))
+
+    def __init__(self,
+                 img_suffix='.png',
+                 seg_map_suffix='.png',
+                 reduce_zero_label=False,
+                 **kwargs) -> None:
+        super().__init__(
+            img_suffix=img_suffix,
+            seg_map_suffix=seg_map_suffix,
+            reduce_zero_label=reduce_zero_label,
+            **kwargs)
diff --git a/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/tools/prepare_dataset.py b/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/tools/prepare_dataset.py
new file mode 100755
index 00000000000..775f2eed18a
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/breast_cancer_cell_seg/tools/prepare_dataset.py
@@ -0,0 +1,47 @@
+import glob
+import os
+
+import numpy as np
+from PIL import Image
+
+root_path = 'data/'
+img_suffix = '.tif'
+seg_map_suffix = '.TIF'
+save_img_suffix = '.png'
+save_seg_map_suffix = '.png'
+
+x_train = glob.glob(
+ os.path.join('data/Breast Cancer Cell Segmentation_datasets/Images/*' +
+ img_suffix))
+
+os.makedirs(root_path + 'images/train/', exist_ok=True)
+os.makedirs(root_path + 'masks/train/', exist_ok=True)
+
+D2_255_convert_dict = {0: 0, 255: 1}
+
+
+def convert_2d(img, convert_dict=D2_255_convert_dict):
+ arr_2d = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
+ for c, i in convert_dict.items():
+ arr_2d[img == c] = i
+ return arr_2d
+
+
+part_dir_dict = {0: 'train/'}
+for ith, part in enumerate([x_train]):
+ part_dir = part_dir_dict[ith]
+ for img in part:
+ basename = os.path.basename(img)
+ img_save_path = root_path + 'images/' + part_dir + basename.split(
+ '.')[0] + save_img_suffix
+ Image.open(img).save(img_save_path)
+ mask_path = root_path + 'Breast Cancer Cell Segmentation_datasets/Masks/' + '_'.join( # noqa
+ basename.split('_')[:-1]) + seg_map_suffix
+ label = np.array(Image.open(mask_path))
+
+ save_mask_path = root_path + 'masks/' + part_dir + basename.split(
+ '.')[0] + save_seg_map_suffix
+ assert len(label.shape) == 2 and 255 in label and 1 not in label
+ mask = convert_2d(label)
+ mask = Image.fromarray(mask.astype(np.uint8))
+ mask.save(save_mask_path)
diff --git a/projects/medical/2d_image/histopathology/conic2022_seg/README.md b/projects/medical/2d_image/histopathology/conic2022_seg/README.md
new file mode 100644
index 00000000000..1f55b44ed6d
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/conic2022_seg/README.md
@@ -0,0 +1,207 @@
+# CoNIC: Colon Nuclei Identification and Counting Challenge
+
+## Description
+
+This project supports **`CoNIC: Colon Nuclei Identification and Counting Challenge`**, which can be downloaded from [here](https://drive.google.com/drive/folders/1il9jG7uA4-ebQ_lNmXbbF2eOK9uNwheb).
+
+### Dataset Overview
+
+Nuclear segmentation, classification and quantification within Haematoxylin & Eosin stained histology images enables the extraction of interpretable cell-based features that can be used in downstream explainable models in computational pathology (CPath). To help drive forward research and innovation for automatic nuclei recognition in CPath, we organise the Colon Nuclei Identification and Counting (CoNIC) Challenge. The challenge requires researchers to develop algorithms that perform segmentation, classification and counting of 6 different types of nuclei within the current largest known publicly available nuclei-level dataset in CPath, containing around half a million labelled nuclei.
+
+### Task Information
+
+The CoNIC challenge has 2 tasks:
+
+- Task 1: Nuclear segmentation and classification.
+
+The first task requires participants to segment nuclei within the tissue, while also classifying each nucleus into one of the following categories: epithelial, lymphocyte, plasma, eosinophil, neutrophil or connective tissue.
+
+- Task 2: Prediction of cellular composition.
+
+For the second task, we ask participants to predict how many nuclei of each class are present in each input image.
+
+The output of Task 1 can be directly used to perform Task 2, but these can be treated as independent tasks. Therefore, if it is preferred, prediction of cellular composition can be treated as a standalone regression task.
+
+***NOTE: We only consider `Task 1` in the following sections.***
+
+### Original Statistic Information
+
+| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| -------------------------------------------------------- | ----------------- | ------------ | -------------- | ------------ | --------------------- | ---------------------- | ------------ | ------------------------------------------------------------------------------------------------------------ |
+| [CoNIC2022](https://conic-challenge.grand-challenge.org/) | abdomen | segmentation | histopathology | 7 | 4981/-/- | yes/-/- | 2022 | [Attribution-NonCommercial-ShareAlike 4.0 International](https://creativecommons.org/licenses/by-nc-sa/4.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 4981 | 83.97 | - | - | - | - |
+| neutrophil | 1218 | 0.13 | - | - | - | - |
+| epithelial | 4256 | 10.31 | - | - | - | - |
+| lymphocyte | 4473 | 1.85 | - | - | - | - |
+| plasma | 3316 | 0.55 | - | - | - | - |
+| eosinophil | 1456 | 0.1 | - | - | - | - |
+| connective | 4613 | 3.08 | - | - | - | - |
+
+Note:
+
+- `Pct` means the percentage of pixels belonging to this category among all pixels.
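For reference, the `Pct` values can be reproduced from the generated masks roughly as follows. This is a minimal sketch, not the script actually used: it assumes single-channel masks whose pixel values are class ids in `[0, num_classes - 1]`.

```python
import numpy as np


def class_pixel_percentages(masks, num_classes=7):
    """Accumulate per-class pixel counts over 2-D class-id masks and
    return the percentage of pixels per class."""
    counts = np.zeros(num_classes, dtype=np.int64)
    for mask in masks:
        # bincount over flattened class ids; minlength guards against
        # masks that miss some classes entirely
        counts += np.bincount(mask.ravel(), minlength=num_classes)[:num_classes]
    return 100.0 * counts / counts.sum()


if __name__ == '__main__':
    # tiny demo mask: 2 background pixels, 1 pixel each of classes 1 and 2
    demo = [np.array([[0, 0], [1, 2]])]
    print(class_pixel_percentages(demo, num_classes=3))
```

In practice one would load each mask under `masks/` with `PIL.Image.open` and pass the resulting arrays in.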
+
+### Visualization
+
+![bac](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/histopathology/conic2022_seg/conic2022_seg_dataset.png)
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow(PIL) v9.3.0
+- scikit-learn(sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `conic2022_seg/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset preparing
+
+- download the dataset from [here](https://drive.google.com/drive/folders/1il9jG7uA4-ebQ_lNmXbbF2eOK9uNwheb/) and move the data to `data/CoNIC_Challenge`. The directory structure should look like:
+ ```shell
+ data/CoNIC_Challenge
+ ├── README.txt
+ ├── by-nc-sa.md
+ ├── counts.csv
+ ├── images.npy
+ ├── labels.npy
+ └── patch_info.csv
+ ```
+- run `python tools/prepare_dataset.py` to format the data and change the folder structure as below.
+- run `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. If the labels of the official validation and test sets cannot be obtained, we randomly generate `train.txt` and `val.txt` from the training set.
+
+```none
+  mmsegmentation
+  ├── mmseg
+  ├── projects
+  │   ├── medical
+  │   │   ├── 2d_image
+  │   │   │   ├── histopathology
+  │   │   │   │   ├── conic2022_seg
+  │   │   │   │   │   ├── configs
+  │   │   │   │   │   ├── datasets
+  │   │   │   │   │   ├── tools
+  │   │   │   │   │   ├── data
+  │   │   │   │   │   │   ├── train.txt
+  │   │   │   │   │   │   ├── val.txt
+  │   │   │   │   │   │   ├── images
+  │   │   │   │   │   │   │   ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+  │   │   │   │   │   │   ├── masks
+  │   │   │   │   │   │   │   ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+```
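The random split performed by `split_seg_dataset.py` can be sketched as follows. This is a simplified illustration, not the exact script: the 80/20 ratio, the fixed seed, and the demo file names are assumptions.

```python
import random


def split_names(names, val_ratio=0.2, seed=42):
    """Randomly split a list of image basenames into train/val subsets."""
    rng = random.Random(seed)
    names = sorted(names)  # deterministic order before shuffling
    rng.shuffle(names)
    n_val = int(len(names) * val_ratio)
    return names[n_val:], names[:n_val]  # (train, val)


if __name__ == '__main__':
    # e.g. basenames of the files under data/images/train/
    all_names = [f'{i}.png' for i in range(4981)]
    train, val = split_names(all_names)
    # each basename would then be written, one per line,
    # to data/train.txt and data/val.txt respectively
    print(len(train), len(val))
```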
+
+### Divided Dataset Information
+
+***Note: The statistics in the table below are computed on our own random split of the official training set.***
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 3984 | 84.06 | 997 | 83.65 | - | - |
+| neutrophil | 956 | 0.12 | 262 | 0.13 | - | - |
+| epithelial | 3400 | 10.26 | 856 | 10.52 | - | - |
+| lymphocyte | 3567 | 1.83 | 906 | 1.96 | - | - |
+| plasma | 2645 | 0.55 | 671 | 0.56 | - | - |
+| eosinophil | 1154 | 0.1 | 302 | 0.1 | - | - |
+| connective | 3680 | 3.08 | 933 | 3.08 | - | - |
+
+### Training commands
+
+Train models on a single server with one GPU.
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+Test models on a single server with one GPU.
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+
+
+## Organizers
+
+- Simon Graham (TIA, PathLAKE)
+- Mostafa Jahanifar (TIA, PathLAKE)
+- Dang Vu (TIA)
+- Giorgos Hadjigeorghiou (TIA, PathLAKE)
+- Thomas Leech (TIA, PathLAKE)
+- David Snead (UHCW, PathLAKE)
+- Shan Raza (TIA, PathLAKE)
+- Fayyaz Minhas (TIA, PathLAKE)
+- Nasir Rajpoot (TIA, PathLAKE)
+
+TIA: Tissue Image Analytics Centre, Department of Computer Science, University of Warwick, United Kingdom
+
+UHCW: Department of Pathology, University Hospitals Coventry and Warwickshire, United Kingdom
+
+PathLAKE: Pathology Image Data Lake for Analytics, Knowledge & Education, University Hospitals Coventry and Warwickshire, United Kingdom
+
+## Dataset Citation
+
+If this work is helpful for your research, please consider citing the following papers.
+
+```
+@inproceedings{graham2021lizard,
+ title={Lizard: A large-scale dataset for colonic nuclear instance segmentation and classification},
+ author={Graham, Simon and Jahanifar, Mostafa and Azam, Ayesha and Nimir, Mohammed and Tsang, Yee-Wah and Dodd, Katherine and Hero, Emily and Sahota, Harvir and Tank, Atisha and Benes, Ksenija and others},
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
+ pages={684--693},
+ year={2021}
+}
+@article{graham2021conic,
+ title={Conic: Colon nuclei identification and counting challenge 2022},
+ author={Graham, Simon and Jahanifar, Mostafa and Vu, Quoc Dang and Hadjigeorghiou, Giorgos and Leech, Thomas and Snead, David and Raza, Shan E Ahmed and Minhas, Fayyaz and Rajpoot, Nasir},
+ journal={arXiv preprint arXiv:2111.14485},
+ year={2021}
+}
+```
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+
+ - [x] Basic docstrings & proper citation
+
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+
+ - [ ] Unit tests
+
+ - [ ] Code polishing
+
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/histopathology/conic2022_seg/configs/conic2022-seg_512x512.py b/projects/medical/2d_image/histopathology/conic2022_seg/configs/conic2022-seg_512x512.py
new file mode 100644
index 00000000000..51b4e5782ab
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/conic2022_seg/configs/conic2022-seg_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'Conic2022SegDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/histopathology/conic2022_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_conic2022-512x512.py b/projects/medical/2d_image/histopathology/conic2022_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_conic2022-512x512.py
new file mode 100644
index 00000000000..3e0248c78ce
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/conic2022_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_conic2022-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ './conic2022-seg_512x512.py', 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.conic2022-seg_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=7),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/conic2022_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_conic2022-512x512.py b/projects/medical/2d_image/histopathology/conic2022_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_conic2022-512x512.py
new file mode 100644
index 00000000000..fd0e9d8d28b
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/conic2022_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_conic2022-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ './conic2022-seg_512x512.py', 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.conic2022-seg_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=7),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/conic2022_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_conic2022-512x512.py b/projects/medical/2d_image/histopathology/conic2022_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_conic2022-512x512.py
new file mode 100644
index 00000000000..bb667f14fd4
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/conic2022_seg/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_conic2022-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ './conic2022-seg_512x512.py', 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.conic2022-seg_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=7),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/conic2022_seg/conic2022_seg_dataset.png b/projects/medical/2d_image/histopathology/conic2022_seg/conic2022_seg_dataset.png
new file mode 100644
index 00000000000..65bb0bbe0a5
Binary files /dev/null and b/projects/medical/2d_image/histopathology/conic2022_seg/conic2022_seg_dataset.png differ
diff --git a/projects/medical/2d_image/histopathology/conic2022_seg/datasets/conic2022-seg_dataset.py b/projects/medical/2d_image/histopathology/conic2022_seg/datasets/conic2022-seg_dataset.py
new file mode 100644
index 00000000000..9af0958ab34
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/conic2022_seg/datasets/conic2022-seg_dataset.py
@@ -0,0 +1,29 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class Conic2022SegDataset(BaseSegDataset):
+ """Conic2022SegDataset dataset.
+
+ In segmentation map annotation for Conic2022SegDataset,
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ """
+ METAINFO = dict(
+ classes=('background', 'neutrophil', 'epithelial', 'lymphocyte',
+ 'plasma', 'eosinophil', 'connective'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=False,
+ **kwargs)
diff --git a/projects/medical/2d_image/histopathology/conic2022_seg/tools/prepare_dataset.py b/projects/medical/2d_image/histopathology/conic2022_seg/tools/prepare_dataset.py
new file mode 100755
index 00000000000..89cfb4aae24
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/conic2022_seg/tools/prepare_dataset.py
@@ -0,0 +1,65 @@
+import glob
+import os
+import shutil
+
+import numpy as np
+from PIL import Image
+
+img_save_root = 'data/'
+root_path = 'data/'
+img_suffix = '.png'
+seg_map_suffix = '.png'
+save_img_suffix = '.png'
+save_seg_map_suffix = '.png'
+
+label_set = set()
+
+
+def save_masks_from_npz(data, save_root, part='masks/'):
+ global label_set
+ num = data.shape[0]
+ for i in range(num):
+ # np_img = data[i, :, :, :]
+ np_mask = data[i, :, :, 1]
+ label_set = set.union(label_set, set(np.unique(np_mask)))
+ img = Image.fromarray(np_mask)
+ save_path = os.path.join(save_root, part, str(i) + save_seg_map_suffix)
+ img.save(save_path)
+
+
+def save_images_from_npz(data, save_root, part='images/'):
+ num = data.shape[0]
+ for i in range(num):
+ np_img = data[i, :, :, :]
+ img = Image.fromarray(np_img)
+ save_path = os.path.join(save_root, part, str(i) + save_img_suffix)
+ img.save(save_path)
+
+
+images_npy = np.load('data/CoNIC_Challenge/images.npy')
+labels_npy = np.load('data/CoNIC_Challenge/labels.npy')
+
+os.system('mkdir -p ' + img_save_root + 'images_ori')
+os.system('mkdir -p ' + img_save_root + 'labels')
+save_images_from_npz(images_npy, img_save_root, 'images_ori')
+save_masks_from_npz(labels_npy, img_save_root, 'labels')
+print(label_set)
+
+x_train = glob.glob(os.path.join('data/images_ori', '*' + img_suffix))
+
+os.system('mkdir -p ' + root_path + 'images/train/')
+os.system('mkdir -p ' + root_path + 'masks/train/')
+
+part_dir_dict = {0: 'train/', 1: 'val/'}
+for ith, part in enumerate([x_train]):
+ part_dir = part_dir_dict[ith]
+ for img in part:
+ basename = os.path.basename(img)
+ shutil.copy(
+ img, root_path + 'images/' + part_dir + basename.split('.')[0] +
+ save_img_suffix)
+ mask_path = root_path + 'labels/' + basename.split(
+ '.')[0] + seg_map_suffix
+ save_mask_path = root_path + 'masks/' + part_dir + basename.split(
+ '.')[0] + save_seg_map_suffix
+ shutil.copy(mask_path, save_mask_path)
diff --git a/projects/medical/2d_image/histopathology/consep/README.md b/projects/medical/2d_image/histopathology/consep/README.md
new file mode 100644
index 00000000000..ca3d7aa1089
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/consep/README.md
@@ -0,0 +1,147 @@
+# Colorectal Nuclear Segmentation and Phenotypes (CoNSeP) Dataset
+
+## Description
+
+This project supports **`Colorectal Nuclear Segmentation and Phenotypes (CoNSeP) Dataset`**, which can be downloaded from [here](https://warwick.ac.uk/fac/cross_fac/tia/data/hovernet/).
+
+### Dataset Overview
+
+The CoNSeP (Colorectal Nuclear Segmentation and Phenotypes) dataset consists of 41 H&E stained image tiles, each with a size of 1,000×1,000 pixels at 40x magnification. These images were extracted from 16 colorectal adenocarcinoma (CRA) whole slide images (WSI), each of which belonged to a separate patient and was scanned with an Omnyx VL120 scanner at the Pathology Department of the University Hospitals Coventry and Warwickshire NHS Trust, UK. This dataset was first used in the paper "HoVer-Net: Simultaneous Segmentation and Classification of Nuclei in Multi-Tissue Histology Images".
+
+### Original Statistic Information
+
+| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| -------------------------------------------------------- | ----------------- | ------------ | -------------- | ------------ | --------------------- | ---------------------- | ------------ | ------- |
+| [CoNSeP](https://warwick.ac.uk/fac/cross_fac/tia/data/hovernet/) | abdomen | segmentation | histopathology | 8 | 27/14/- | yes/yes/- | 2019 | - |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :-----------------------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 27 | 83.61 | 14 | 80.4 | - | - |
+| other | 17 | 0.17 | 9 | 0.52 | - | - |
+| inflammatory | 25 | 2.66 | 14 | 2.14 | - | - |
+| healthy epithelial | 3 | 1.47 | 2 | 1.58 | - | - |
+| dysplastic/malignant epithelial | 10 | 7.17 | 8 | 9.16 | - | - |
+| fibroblast | 23 | 3.84 | 14 | 4.63 | - | - |
+| muscle | 8 | 1.05 | 3 | 1.42 | - | - |
+| endothelial | 7 | 0.02 | 4 | 0.15 | - | - |
+
+Note:
+
+- `Pct` means the percentage of pixels belonging to this category among all pixels.
+
+### Visualization
+
+![bac](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/histopathology/consep/consep_dataset.png)
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `consep/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset preparing
+
+- download the dataset from [here](https://opendatalab.com/CoNSeP) and decompress the data to `data/`.
+- run `python tools/prepare_dataset.py` to format the data and change the folder structure as below.
+- run `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. If the labels of the official validation and test sets cannot be obtained, we randomly generate `train.txt` and `val.txt` from the training set.
+
+```none
+  mmsegmentation
+  ├── mmseg
+  ├── projects
+  │   ├── medical
+  │   │   ├── 2d_image
+  │   │   │   ├── histopathology
+  │   │   │   │   ├── consep
+  │   │   │   │   │   ├── configs
+  │   │   │   │   │   ├── datasets
+  │   │   │   │   │   ├── tools
+  │   │   │   │   │   ├── data
+  │   │   │   │   │   │   ├── train.txt
+  │   │   │   │   │   │   ├── val.txt
+  │   │   │   │   │   │   ├── images
+  │   │   │   │   │   │   │   ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+  │   │   │   │   │   │   ├── masks
+  │   │   │   │   │   │   │   ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+```
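The core of `tools/prepare_dataset.py` turns the instance-level `.mat` labels (`inst_map` plus per-instance `inst_type`) into semantic masks. The mapping it performs can be sketched like this; the actual `.mat` I/O via `scipy.io.loadmat` is omitted, and the demo arrays are made up for illustration.

```python
import numpy as np


def instances_to_semantic(inst_map, inst_types):
    """Map each instance id (1..N) in ``inst_map`` to the class id given by
    ``inst_types``; pixels with id 0 stay background."""
    semantic = np.zeros_like(inst_map, dtype=np.uint8)
    for inst_id, cls in enumerate(inst_types, start=1):
        semantic[inst_map == inst_id] = int(cls)
    return semantic


if __name__ == '__main__':
    # two instances: id 1 is class 3 (healthy epithelial),
    # id 2 is class 5 (fibroblast)
    inst_map = np.array([[0, 1, 1], [2, 2, 0]])
    print(instances_to_semantic(inst_map, [3, 5]))
```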
+
+### Training commands
+
+Train models on a single server with one GPU.
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+Test models on a single server with one GPU.
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+
+
+## Dataset Citation
+
+If this work is helpful for your research, please consider citing the following paper.
+
+```
+@article{graham2019hover,
+ title={Hover-net: Simultaneous segmentation and classification of nuclei in multi-tissue histology images},
+ author={Graham, Simon and Vu, Quoc Dang and Raza, Shan E Ahmed and Azam, Ayesha and Tsang, Yee Wah and Kwak, Jin Tae and Rajpoot, Nasir},
+ journal={Medical Image Analysis},
+ volume={58},
+ pages={101563},
+ year={2019},
+ publisher={Elsevier}
+}
+```
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+
+ - [x] Basic docstrings & proper citation
+
+ - [x] Test-time correctness
+
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+
+ - [ ] Unit tests
+
+ - [ ] Code polishing
+
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/histopathology/consep/configs/consep_512x512.py b/projects/medical/2d_image/histopathology/consep/configs/consep_512x512.py
new file mode 100644
index 00000000000..0d9b8948b01
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/consep/configs/consep_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'ConsepDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/histopathology/consep/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_consep-512x512.py b/projects/medical/2d_image/histopathology/consep/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_consep-512x512.py
new file mode 100644
index 00000000000..cbcf5db775b
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/consep/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_consep-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ './consep_512x512.py', 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.consep_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=8),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/consep/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_consep-512x512.py b/projects/medical/2d_image/histopathology/consep/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_consep-512x512.py
new file mode 100644
index 00000000000..b374566e6ef
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/consep/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_consep-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ './consep_512x512.py', 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.consep_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=8),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/consep/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_consep-512x512.py b/projects/medical/2d_image/histopathology/consep/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_consep-512x512.py
new file mode 100644
index 00000000000..35bdaa34c84
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/consep/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_consep-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ './consep_512x512.py', 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.consep_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=8),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/consep/datasets/consep_dataset.py b/projects/medical/2d_image/histopathology/consep/datasets/consep_dataset.py
new file mode 100644
index 00000000000..ceb2b3ab25b
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/consep/datasets/consep_dataset.py
@@ -0,0 +1,30 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class ConsepDataset(BaseSegDataset):
+ """ConsepDataset dataset.
+
+ In segmentation map annotation for ConsepDataset,
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ """
+ METAINFO = dict(
+ classes=('background', 'other', 'inflammatory', 'healthy epithelial',
+ 'dysplastic/malignant epithelial', 'fibroblast', 'muscle',
+ 'endothelial'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=False,
+ **kwargs)
diff --git a/projects/medical/2d_image/histopathology/consep/tools/prepare_dataset.py b/projects/medical/2d_image/histopathology/consep/tools/prepare_dataset.py
new file mode 100755
index 00000000000..83a2e18ce10
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/consep/tools/prepare_dataset.py
@@ -0,0 +1,54 @@
+import glob
+import os
+import shutil
+
+import numpy as np
+from PIL import Image
+from scipy.io import loadmat
+
+root_path = 'data/'
+img_suffix = '.png'
+seg_map_suffix = '.mat'
+save_img_suffix = '.png'
+save_seg_map_suffix = '.png'
+
+x_train = glob.glob(os.path.join('data/CoNSeP/Train/Images/*' + img_suffix))
+x_test = glob.glob(os.path.join('data/CoNSeP/Test/Images/*' + img_suffix))
+
+os.system('mkdir -p ' + root_path + 'images/train/')
+os.system('mkdir -p ' + root_path + 'images/val/')
+os.system('mkdir -p ' + root_path + 'masks/train/')
+os.system('mkdir -p ' + root_path + 'masks/val/')
+D2_255_convert_dict = {0: 0, 255: 1}
+
+
+def convert_2d(img, convert_dict=D2_255_convert_dict):
+ arr_2d = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
+ for c, i in convert_dict.items():
+ arr_2d[img == c] = i
+ return arr_2d
+
+
+part_dir_dict = {0: 'CoNSeP/Train/', 1: 'CoNSeP/Test/'}
+save_dir_dict = {0: 'train/', 1: 'val/'}
+for ith, part in enumerate([x_train, x_test]):
+ part_dir = part_dir_dict[ith]
+ for img in part:
+ basename = os.path.basename(img)
+ shutil.copy(
+ img, root_path + 'images/' + save_dir_dict[ith] +
+ basename.split('.')[0] + save_img_suffix)
+
+ mask_path = root_path + part_dir + 'Labels/' + basename.split(
+ '.')[0] + seg_map_suffix
+ label_ = loadmat(mask_path)
+ label = label_['inst_map']
+ label_type = label_['inst_type']
+ label_dict = {i + 1: int(val) for i, val in enumerate(label_type)}
+
+ save_mask_path = root_path + 'masks/' + save_dir_dict[
+ ith] + basename.split('.')[0] + save_seg_map_suffix
+
+ res = convert_2d(label, convert_dict=label_dict)
+ res = Image.fromarray(res.astype(np.uint8))
+ res.save(save_mask_path)
diff --git a/projects/medical/2d_image/histopathology/fusc2021/README.md b/projects/medical/2d_image/histopathology/fusc2021/README.md
new file mode 100644
index 00000000000..8130d593503
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/fusc2021/README.md
@@ -0,0 +1,136 @@
+# Foot Ulcer Segmentation Challenge 2021 (FUSC 2021)
+
+## Description
+
+This project supports **`Foot Ulcer Segmentation Challenge 2021 (FUSC 2021)`**, which can be downloaded from [here](https://fusc.grand-challenge.org/).
+
+### Dataset Overview
+
+This chronic wound dataset was collected at the center over 2 years, from October 2019 to April 2021, and contains 1,210 foot ulcer images taken from 889 patients during multiple clinical visits. The raw images were taken with a Canon SX 620 HS digital camera and an iPad Pro under uncontrolled illumination conditions, with various backgrounds. The images are randomly split into 3 subsets: a training set with 810 images, a validation set with 200 images, and a testing set with 200 images; the annotations of the testing set are kept private. The collected data were de-identified in accordance with relevant guidelines and regulations, and the patients' informed consent was waived by the institutional review board of the University of Wisconsin-Milwaukee.
+
+### Information Statistics
+
+| Dataset Name | Anatomical Region | Task Type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| --------------------------------------------- | ----------------- | ------------ | -------------- | ------------ | --------------------- | ---------------------- | ------------ | ------------------------------------------------------------- |
+| [fusc2021](https://fusc.grand-challenge.org/) | lower limb | segmentation | histopathology | 2 | 810/200/200 | yes/yes/no | 2021 | [CC0 1.0](https://creativecommons.org/publicdomain/zero/1.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 810 | 98.71 | 200 | 98.78 | - | - |
+| wound | 791 | 1.29 | 195 | 1.22 | - | - |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
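
The per-class pixel percentages above can be recomputed from the prepared masks. Below is a minimal sketch, assuming the masks live in `data/masks/train/` with labels `0` (background) and `1` (wound) after running `tools/prepare_dataset.py`:

```python
import glob
import os

import numpy as np
from PIL import Image


def class_pixel_percentages(mask_dir, num_classes=2):
    """Accumulate a pixel histogram over all mask PNGs in ``mask_dir``
    and return the percentage of pixels belonging to each class."""
    counts = np.zeros(num_classes, dtype=np.int64)
    for path in glob.glob(os.path.join(mask_dir, '*.png')):
        mask = np.array(Image.open(path))
        counts += np.bincount(mask.ravel(), minlength=num_classes)[:num_classes]
    return counts / max(counts.sum(), 1) * 100
```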
+
+### Visualization
+
+![fusc2021](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/histopathology/fusc2021/fusc2021_dataset.png?raw=true)
+
+### Dataset Citation
+
+```bibtex
+@article{s41598-020-78799-w,
+ title={Fully automatic wound segmentation with deep convolutional neural networks},
+ author={Chuanbo Wang and D. M. Anisuzzaman and Victor Williamson and Mrinal Kanti Dhar and Behrouz Rostami and Jeffrey Niezgoda and Sandeep Gopalakrishnan and Zeyun Yu},
+ journal={Scientific Reports},
+ volume={10},
+ number={1},
+ pages={21897},
+ year={2020}
+}
+```
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow(PIL) v9.3.0
+- scikit-learn(sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `fusc2021/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset Preparing
+
+- Download the dataset from [here](https://fusc.grand-challenge.org/) and decompress it to the path `'data/'`.
+- Run `python tools/prepare_dataset.py` to format the data and arrange the folder structure as below.
+- Run `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. If the labels of the official validation and test sets are not available, `train.txt` and `val.txt` are generated randomly from the training set.
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── histopathology
+ │ │ │ │ ├── fusc2021
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── val.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+ │   │   │   │   │   │   │   │   ├── xxx.png
+ │   │   │   │   │   │   │   │   ├── ...
+ │   │   │   │   │   │   │   │   └── xxx.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+ │   │   │   │   │   │   │   │   ├── xxx.png
+ │   │   │   │   │   │   │   │   ├── ...
+ │   │   │   │   │   │   │   │   └── xxx.png
+```
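
The random split performed by `../../tools/split_seg_dataset.py` can be sketched as below. This is only an illustration of the idea, not the script's exact behavior; the 4:1 train/val ratio and the fixed seed are assumptions:

```python
import glob
import os
import random


def write_split(img_dir, out_dir, val_ratio=0.2, seed=42):
    """Randomly split image file stems into train.txt and val.txt."""
    stems = sorted(
        os.path.splitext(os.path.basename(p))[0]
        for p in glob.glob(os.path.join(img_dir, '*.png')))
    random.Random(seed).shuffle(stems)
    n_val = int(len(stems) * val_ratio)
    with open(os.path.join(out_dir, 'val.txt'), 'w') as f:
        f.write('\n'.join(stems[:n_val]) + '\n')
    with open(os.path.join(out_dir, 'train.txt'), 'w') as f:
        f.write('\n'.join(stems[n_val:]) + '\n')
```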
+
+### Training commands
+
+To train models on a single server with one GPU (default):
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+To test models on a single server with one GPU (default):
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+ - [x] Basic docstrings & proper citation
+ - [ ] Test-time correctness
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+ - [ ] Unit tests
+ - [ ] Code polishing
+ - [ ] Metafile.yml
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/histopathology/fusc2021/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_fusc2021-512x512.py b/projects/medical/2d_image/histopathology/fusc2021/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_fusc2021-512x512.py
new file mode 100644
index 00000000000..c3f42751124
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/fusc2021/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_fusc2021-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './fusc2021_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.fusc2021_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/fusc2021/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_fusc2021-512x512.py b/projects/medical/2d_image/histopathology/fusc2021/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_fusc2021-512x512.py
new file mode 100644
index 00000000000..ed870303fff
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/fusc2021/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_fusc2021-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './fusc2021_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.fusc2021_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/fusc2021/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_fusc2021-512x512.py b/projects/medical/2d_image/histopathology/fusc2021/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_fusc2021-512x512.py
new file mode 100644
index 00000000000..cbc09ae6cdd
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/fusc2021/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_fusc2021-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './fusc2021_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.fusc2021_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/fusc2021/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_fusc2021-512x512.py b/projects/medical/2d_image/histopathology/fusc2021/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_fusc2021-512x512.py
new file mode 100644
index 00000000000..f1477ee725e
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/fusc2021/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_fusc2021-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './fusc2021_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.fusc2021_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(
+ num_classes=2, loss_decode=dict(use_sigmoid=True), out_channels=1),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/fusc2021/configs/fusc2021_512x512.py b/projects/medical/2d_image/histopathology/fusc2021/configs/fusc2021_512x512.py
new file mode 100644
index 00000000000..e650474cea8
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/fusc2021/configs/fusc2021_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'FUSC2021Dataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/histopathology/fusc2021/datasets/fusc2021_dataset.py b/projects/medical/2d_image/histopathology/fusc2021/datasets/fusc2021_dataset.py
new file mode 100644
index 00000000000..d331ac8c3a2
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/fusc2021/datasets/fusc2021_dataset.py
@@ -0,0 +1,30 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class FUSC2021Dataset(BaseSegDataset):
+ """FUSC2021Dataset dataset.
+
+    In segmentation map annotation for FUSC2021Dataset, 0 stands for
+    background, which is included in 2 categories. ``reduce_zero_label``
+    is fixed to False. The ``img_suffix`` is fixed to '.png' and
+    ``seg_map_suffix`` is fixed to '.png'.
+
+    Args:
+        img_suffix (str): Suffix of images. Default: '.png'
+        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+        reduce_zero_label (bool): Whether to mark label zero as ignored.
+            Default to False.
+ """
+ METAINFO = dict(classes=('background', 'wound'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/histopathology/fusc2021/tools/prepare_dataset.py b/projects/medical/2d_image/histopathology/fusc2021/tools/prepare_dataset.py
new file mode 100644
index 00000000000..8f2de3daa95
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/fusc2021/tools/prepare_dataset.py
@@ -0,0 +1,114 @@
+import glob
+import os
+
+import numpy as np
+from PIL import Image
+
+root_path = 'data/'
+img_suffix = '.png'
+seg_map_suffix = '.png'
+save_img_suffix = '.png'
+save_seg_map_suffix = '.png'
+src_img_train_dir = os.path.join(
+ root_path, 'wound-segmentation/data/' +
+ 'Foot Ulcer Segmentation Challenge/train/images')
+src_img_val_dir = os.path.join(
+ root_path, 'wound-segmentation/data/' +
+ 'Foot Ulcer Segmentation Challenge/validation/images')
+src_img_test_dir = os.path.join(
+ root_path, 'wound-segmentation/data/' +
+ 'Foot Ulcer Segmentation Challenge/test/images')
+src_mask_train_dir = os.path.join(
+ root_path, 'wound-segmentation/data/' +
+ 'Foot Ulcer Segmentation Challenge/train/labels')
+src_mask_val_dir = os.path.join(
+ root_path, 'wound-segmentation/data/' +
+ 'Foot Ulcer Segmentation Challenge/validation/labels')
+
+tgt_img_train_dir = os.path.join(root_path, 'images/train/')
+tgt_mask_train_dir = os.path.join(root_path, 'masks/train/')
+tgt_img_val_dir = os.path.join(root_path, 'images/val/')
+tgt_mask_val_dir = os.path.join(root_path, 'masks/val/')
+tgt_img_test_dir = os.path.join(root_path, 'images/test/')
+os.makedirs(tgt_img_train_dir, exist_ok=True)
+os.makedirs(tgt_img_val_dir, exist_ok=True)
+os.makedirs(tgt_img_test_dir, exist_ok=True)
+os.makedirs(tgt_mask_train_dir, exist_ok=True)
+os.makedirs(tgt_mask_val_dir, exist_ok=True)
+
+
+def filter_suffix_recursive(src_dir, suffix):
+ # filter out file names and paths in source directory
+ suffix = '.' + suffix if '.' not in suffix else suffix
+ file_paths = glob.glob(
+ os.path.join(src_dir, '**', '*' + suffix), recursive=True)
+    file_names = [os.path.basename(p) for p in file_paths]
+ return sorted(file_paths), sorted(file_names)
+
+
+def convert_label(img, convert_dict):
+ arr = np.zeros_like(img, dtype=np.uint8)
+ for c, i in convert_dict.items():
+ arr[img == c] = i
+ return arr
+
+
+def convert_pics_into_pngs(src_dir, tgt_dir, suffix, convert='RGB'):
+ if not os.path.exists(tgt_dir):
+ os.makedirs(tgt_dir)
+
+    src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix)
+    num = len(src_paths)
+
+    for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)):
+        tgt_name = src_name.replace(suffix, save_img_suffix)
+        tgt_path = os.path.join(tgt_dir, tgt_name)
+ img = np.array(Image.open(src_path))
+ if len(img.shape) == 2:
+ pil = Image.fromarray(img).convert(convert)
+ elif len(img.shape) == 3:
+ pil = Image.fromarray(img)
+ else:
+ raise ValueError('Input image not 2D/3D: ', img.shape)
+
+ pil.save(tgt_path)
+ print(f'processed {i+1}/{num}.')
+
+
+def convert_label_pics_into_pngs(src_dir,
+ tgt_dir,
+ suffix,
+ convert_dict={
+ 0: 0,
+ 255: 1
+ }):
+ if not os.path.exists(tgt_dir):
+ os.makedirs(tgt_dir)
+
+ src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix)
+ num = len(src_paths)
+ for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)):
+ tgt_name = src_name.replace(suffix, save_seg_map_suffix)
+ tgt_path = os.path.join(tgt_dir, tgt_name)
+
+ img = np.array(Image.open(src_path).convert('L'))
+ img = convert_label(img, convert_dict)
+ Image.fromarray(img).save(tgt_path)
+ print(f'processed {i+1}/{num}.')
+
+
+if __name__ == '__main__':
+
+ convert_pics_into_pngs(
+ src_img_train_dir, tgt_img_train_dir, suffix=img_suffix)
+
+ convert_pics_into_pngs(src_img_val_dir, tgt_img_val_dir, suffix=img_suffix)
+
+ convert_pics_into_pngs(
+ src_img_test_dir, tgt_img_test_dir, suffix=img_suffix)
+
+ convert_label_pics_into_pngs(
+ src_mask_train_dir, tgt_mask_train_dir, suffix=seg_map_suffix)
+
+ convert_label_pics_into_pngs(
+ src_mask_val_dir, tgt_mask_val_dir, suffix=seg_map_suffix)
diff --git a/projects/medical/2d_image/histopathology/pannuke/README.md b/projects/medical/2d_image/histopathology/pannuke/README.md
new file mode 100644
index 00000000000..e0cade7536d
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pannuke/README.md
@@ -0,0 +1,146 @@
+# Pan-Cancer Histology Dataset for Nuclei Instance Segmentation and Classification (PanNuke)
+
+## Description
+
+This project supports **`Pan-Cancer Histology Dataset for Nuclei Instance Segmentation and Classification (PanNuke)`**, which can be downloaded from [here](https://academictorrents.com/details/99f2c7b57b95500711e33f2ee4d14c9fd7c7366c).
+
+### Dataset Overview
+
+PanNuke is a semi-automatically generated nuclei instance segmentation and classification dataset with exhaustive nuclei labels across 19 different tissue types. The dataset consists of 481 visual fields, of which 312 are randomly sampled from more than 20K whole slide images at different magnifications, from multiple data sources. In total, the dataset contains 205,343 labeled nuclei, each with an instance segmentation mask. Models trained on PanNuke can aid in whole slide image tissue type segmentation and generalise to new tissues. PanNuke is one of the first successfully semi-automatically generated histology datasets.
+
+### Statistic Information
+
+| Dataset Name | Anatomical Region | Task Type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| ---------------------------------------------------------------------------------------- | ----------------- | ------------ | -------------- | ------------ | --------------------- | ---------------------- | ------------ | --------------------------------------------------------------- |
+| [Pannuke](https://academictorrents.com/details/99f2c7b57b95500711e33f2ee4d14c9fd7c7366c) | full_body | segmentation | histopathology | 6 | 7901/-/- | yes/-/- | 2019 | [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :-----------------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 7901 | 83.32 | - | - | - | - |
+| neoplastic | 4190 | 8.64 | - | - | - | - |
+| non-neoplastic epithelial | 4126 | 1.77 | - | - | - | - |
+| inflammatory | 6137 | 3.73 | - | - | - | - |
+| connective | 232 | 0.07 | - | - | - | - |
+| dead | 1528 | 2.47 | - | - | - | - |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
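
PanNuke ships its annotations as 6-channel instance masks (`masks.npy`); `tools/prepare_dataset.py` collapses them into single-channel semantic maps before statistics like the above are computed. The core remapping is sketched below, using the same cyclic shift as that script, so the last channel (background) becomes label 0:

```python
import numpy as np


def collapse_channels(mask_6ch):
    """Convert an (H, W, 6) multi-channel mask to an (H, W) label map.

    Channel j is assigned label (j + 1) % 6, so the last channel
    (background) maps to 0 and channels 0-4 map to labels 1-5.
    """
    binary = np.minimum(np.uint8(mask_6ch), 1)
    out = np.zeros(binary.shape[:2], dtype=np.uint8)
    for j in range(binary.shape[-1]):
        out[binary[..., j] == 1] = (j + 1) % binary.shape[-1]
    return out
```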
+
+### Visualization
+
+![pannuke](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/histopathology/pannuke/pannuke_dataset.png?raw=true)
+
+### Dataset Citation
+
+```bibtex
+@inproceedings{gamper2019pannuke,
+ title={PanNuke: an open pan-cancer histology dataset for nuclei instance segmentation and classification},
+ author={Gamper, Jevgenij and Koohbanani, Navid Alemi and Benet, Ksenija and Khuram, Ali and Rajpoot, Nasir},
+ booktitle={European Congress on Digital Pathology},
+ pages={11--19},
+ year={2019},
+}
+```
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow(PIL) v9.3.0
+- scikit-learn(sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `pannuke/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset Preparing
+
+- Download the dataset from [here](https://academictorrents.com/details/99f2c7b57b95500711e33f2ee4d14c9fd7c7366c) and decompress it to the path `'data/'`.
+- Run `python tools/prepare_dataset.py` to format the data and arrange the folder structure as below.
+- Run `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. If the labels of the official validation and test sets are not available, `train.txt` and `val.txt` are generated randomly from the training set.
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── histopathology
+ │ │ │ │ ├── pannuke
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── val.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+ │   │   │   │   │   │   │   │   ├── xxx.png
+ │   │   │   │   │   │   │   │   ├── ...
+ │   │   │   │   │   │   │   │   └── xxx.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+ │   │   │   │   │   │   │   │   ├── xxx.png
+ │   │   │   │   │   │   │   │   ├── ...
+ │   │   │   │   │   │   │   │   └── xxx.png
+```
+
+### Divided Dataset Information
+
+***Note: The train/val split in the table below was made by ourselves, as no official split is provided.***
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :-----------------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 6320 | 83.38 | 1581 | 83.1 | - | - |
+| neoplastic | 3339 | 8.55 | 851 | 9.0 | - | - |
+| non-neoplastic epithelial | 3293 | 1.77 | 833 | 1.76 | - | - |
+| inflammatory | 4914 | 3.72 | 1223 | 3.76 | - | - |
+| connective | 170 | 0.06 | 62 | 0.09 | - | - |
+| dead | 1235 | 2.51 | 293 | 2.29 | - | - |
+
+### Training commands
+
+To train models on a single server with one GPU (default):
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+To test models on a single server with one GPU (default):
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+ - [x] Basic docstrings & proper citation
+ - [ ] Test-time correctness
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+ - [ ] Unit tests
+ - [ ] Code polishing
+ - [ ] Metafile.yml
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_bactteria-detection-512x512.py b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_bactteria-detection-512x512.py
new file mode 100644
index 00000000000..92584e9a684
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_bactteria-detection-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ './bactteria-detection_512x512.py', 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.bactteria-detection_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(
+ num_classes=2, loss_decode=dict(use_sigmoid=True), out_channels=1),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_pannuke-512x512.py b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_pannuke-512x512.py
new file mode 100644
index 00000000000..042a08ce008
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_pannuke-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './pannuke_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.pannuke_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=6),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_pannuke-512x512.py b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_pannuke-512x512.py
new file mode 100644
index 00000000000..e92514c9132
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_pannuke-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './pannuke_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.pannuke_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=6),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_pannuke-512x512.py b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_pannuke-512x512.py
new file mode 100644
index 00000000000..a9403c849fa
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pannuke/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_pannuke-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './pannuke_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.pannuke_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=6),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/pannuke/configs/pannuke_512x512.py b/projects/medical/2d_image/histopathology/pannuke/configs/pannuke_512x512.py
new file mode 100644
index 00000000000..316ac1ac443
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pannuke/configs/pannuke_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'PanNukeDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/histopathology/pannuke/datasets/pannuke_dataset.py b/projects/medical/2d_image/histopathology/pannuke/datasets/pannuke_dataset.py
new file mode 100644
index 00000000000..4d3c687ff31
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pannuke/datasets/pannuke_dataset.py
@@ -0,0 +1,33 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class PanNukeDataset(BaseSegDataset):
+ """PanNukeDataset dataset.
+
+ In segmentation map annotation for PanNukeDataset,
+ 0 stands for background, which is included in 6 categories.
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default to False.
+ """
+ METAINFO = dict(
+ classes=('background', 'neoplastic', 'non-neoplastic epithelial',
+ 'inflammatory', 'connective', 'dead'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/histopathology/pannuke/tools/prepare_dataset.py b/projects/medical/2d_image/histopathology/pannuke/tools/prepare_dataset.py
new file mode 100644
index 00000000000..7213b181f40
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pannuke/tools/prepare_dataset.py
@@ -0,0 +1,49 @@
+import os
+
+import numpy as np
+from PIL import Image
+
+root_path = 'data/'
+
+tgt_img_dir = os.path.join(root_path, 'images/train')
+tgt_mask_dir = os.path.join(root_path, 'masks/train')
+os.makedirs(tgt_img_dir, exist_ok=True)
+os.makedirs(tgt_mask_dir, exist_ok=True)
+
+fold_img_paths = sorted([
+ os.path.join(root_path, 'pannuke/Fold 1/images/fold1/images.npy'),
+ os.path.join(root_path, 'pannuke/Fold 2/images/fold2/images.npy'),
+ os.path.join(root_path, 'pannuke/Fold 3/images/fold3/images.npy')
+])
+
+fold_mask_paths = sorted([
+ os.path.join(root_path, 'pannuke/Fold 1/masks/fold1/masks.npy'),
+ os.path.join(root_path, 'pannuke/Fold 2/masks/fold2/masks.npy'),
+ os.path.join(root_path, 'pannuke/Fold 3/masks/fold3/masks.npy')
+])
+
+for n, (img_path,
+ mask_path) in enumerate(zip(fold_img_paths, fold_mask_paths)):
+ fold_name = str(n + 1)
+ imgs = np.load(img_path)
+ masks = np.load(mask_path)
+
+ for i in range(imgs.shape[0]):
+ img = np.uint8(imgs[i])
+ mask_multichannel = np.minimum(np.uint8(masks[i]), 1)
+ mask = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
+ for j in range(mask_multichannel.shape[-1]):
+ factor = (j + 1) % mask_multichannel.shape[-1]
+ # convert [0,1,2,3,4,5] to [1,2,3,4,5,0],
+ # with the last label being background
+ mask[mask_multichannel[..., j] == 1] = factor
+
+ file_name = 'fold' + fold_name + '_' + str(i).rjust(4, '0') + '.png'
+ print('Processing: ', file_name)
+ tgt_img_path = os.path.join(tgt_img_dir, file_name)
+ tgt_mask_path = os.path.join(tgt_mask_dir, file_name)
+ Image.fromarray(img).save(tgt_img_path)
+ Image.fromarray(mask).save(tgt_mask_path)
+
+ del imgs
+ del masks
diff --git a/projects/medical/2d_image/histopathology/pcam/README.md b/projects/medical/2d_image/histopathology/pcam/README.md
new file mode 100644
index 00000000000..5a8094950c5
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pcam/README.md
@@ -0,0 +1,153 @@
+# PCam (PatchCamelyon)
+
+## Description
+
+This project supports **`Patch Camelyon (PCam)`**, which can be downloaded from [here](https://opendatalab.com/PCam).
+
+### Dataset Overview
+
+PatchCamelyon is an image classification dataset. It consists of 327,680 color images (96 × 96 px) extracted from histopathologic scans of lymph node sections. Each image is annotated with a binary label indicating the presence of metastatic tissue. PCam provides a new benchmark for machine learning models: bigger than CIFAR10, smaller than ImageNet, and trainable on a single GPU.
+
+### Statistic Information
+
+| Dataset Name | Anatomical Region | Task Type | Modality | Num. Classes | Train/Val/Test images | Train/Val/Test Labeled | Release Date | License |
+| ------------------------------------ | ----------------- | ------------ | -------------- | ------------ | --------------------- | ---------------------- | ------------ | ------------------------------------------------------------- |
+| [Pcam](https://opendatalab.com/PCam) | thorax | segmentation | histopathology | 2 | 327680/-/- | yes/-/- | 2018 | [CC0 1.0](https://creativecommons.org/publicdomain/zero/1.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :---------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 214849 | 63.77 | - | - | - | - |
+| metastatic tissue | 131832 | 36.22 | - | - | - | - |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
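The `Pct` values above can be recomputed from the mask files by counting pixels per class. Below is a minimal sketch; the helper name `class_pixel_pct` is ours and not part of the project tools:

```python
import numpy as np


def class_pixel_pct(masks, num_classes):
    """Percentage of pixels belonging to each class across a list of
    integer-label mask arrays."""
    counts = np.zeros(num_classes, dtype=np.int64)
    for mask in masks:
        # accumulate per-class pixel counts for this mask
        counts += np.bincount(
            np.asarray(mask).ravel(), minlength=num_classes)[:num_classes]
    return 100.0 * counts / counts.sum()
```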
+
+### Visualization
+
+![pcam](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/histopathology/pcam/pcam_dataset.png?raw=true)
+
+### Dataset Citation
+
+```bibtex
+@inproceedings{veeling2018rotation,
+ title={Rotation equivariant CNNs for digital pathology},
+ author={Veeling, Bastiaan S and Linmans, Jasper and Winkens, Jim and Cohen, Taco and Welling, Max},
+ booktitle={International Conference on Medical image computing and computer-assisted intervention},
+ pages={210--218},
+ year={2018},
+}
+```
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow(PIL) v9.3.0
+- scikit-learn(sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `pcam/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset Preparing
+
+- download dataset from [here](https://opendatalab.com/PCam) and decompress data to path `'data/'`.
+- run script `"python tools/prepare_dataset.py"` to format data and change folder structure as below.
+- run script `"python ../../tools/split_seg_dataset.py"` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. Since the labels of the official validation set and test set are not available, we randomly generate `train.txt` and `val.txt` from the training set.
+
+```shell
+mkdir data && cd data
+pip install opendatalab
+odl get PCam
+mv ./PCam/raw/pcamv1 ./
+rm -rf PCam
+cd ..
+python tools/prepare_dataset.py
+python ../../tools/split_seg_dataset.py
+```
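The random split performed by `split_seg_dataset.py` can be sketched as follows. This is a hypothetical illustration assuming an 80/20 train/val ratio; the actual script may use a different ratio and output format:

```python
import os
import random


def split_train_val(img_dir, out_dir, val_ratio=0.2, seed=0):
    """Randomly split image file stems into train.txt and val.txt."""
    stems = sorted(
        os.path.splitext(name)[0] for name in os.listdir(img_dir))
    # deterministic shuffle so the split is reproducible
    random.Random(seed).shuffle(stems)
    num_val = int(len(stems) * val_ratio)
    splits = {'val.txt': stems[:num_val], 'train.txt': stems[num_val:]}
    for filename, names in splits.items():
        with open(os.path.join(out_dir, filename), 'w') as f:
            f.write('\n'.join(names) + '\n')
    return splits
```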
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── histopathology
+ │ │ │ │ ├── pcam
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── val.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+```
+
+### Divided Dataset Information
+
+***Note: The table below describes the dataset split generated by ourselves.***
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :---------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 171948 | 63.82 | 42901 | 63.6 | - | - |
+| metastatic tissue | 105371 | 36.18 | 26461 | 36.4 | - | - |
+
+### Training commands
+
+To train models on a single server with one GPU. (default)
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+To test models on a single server with one GPU. (default)
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+ - [x] Basic docstrings & proper citation
+ - [ ] Test-time correctness
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+ - [ ] Unit tests
+ - [ ] Code polishing
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/histopathology/pcam/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_pcam-512x512.py b/projects/medical/2d_image/histopathology/pcam/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_pcam-512x512.py
new file mode 100644
index 00000000000..20601f1ea56
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pcam/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_pcam-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './pcam_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.pcam_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/pcam/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_pcam-512x512.py b/projects/medical/2d_image/histopathology/pcam/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_pcam-512x512.py
new file mode 100644
index 00000000000..c057535409f
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pcam/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_pcam-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './pcam_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.pcam_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/pcam/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_pcam-512x512.py b/projects/medical/2d_image/histopathology/pcam/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_pcam-512x512.py
new file mode 100644
index 00000000000..4c1d5fe421c
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pcam/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_pcam-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './pcam_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.pcam_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/pcam/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_pcam-512x512.py b/projects/medical/2d_image/histopathology/pcam/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_pcam-512x512.py
new file mode 100644
index 00000000000..25e3734795c
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pcam/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_pcam-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './pcam_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.pcam_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(
+ num_classes=2, loss_decode=dict(use_sigmoid=True), out_channels=1),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/histopathology/pcam/configs/pcam_512x512.py b/projects/medical/2d_image/histopathology/pcam/configs/pcam_512x512.py
new file mode 100644
index 00000000000..04efc23eb51
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pcam/configs/pcam_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'PCamDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/histopathology/pcam/datasets/pcam_dataset.py b/projects/medical/2d_image/histopathology/pcam/datasets/pcam_dataset.py
new file mode 100644
index 00000000000..1c27de543ab
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pcam/datasets/pcam_dataset.py
@@ -0,0 +1,31 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class PCamDataset(BaseSegDataset):
+ """PCamDataset dataset.
+
+ In segmentation map annotation for PCamDataset,
+ 0 stands for background, which is included in 2 categories.
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default to False.
+ """
+ METAINFO = dict(classes=('background', 'metastatic tissue'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/histopathology/pcam/tools/prepare_dataset.py b/projects/medical/2d_image/histopathology/pcam/tools/prepare_dataset.py
new file mode 100644
index 00000000000..75038e6fb49
--- /dev/null
+++ b/projects/medical/2d_image/histopathology/pcam/tools/prepare_dataset.py
@@ -0,0 +1,49 @@
+import os
+
+import h5py
+import numpy as np
+from PIL import Image
+
+root_path = 'data/'
+
+tgt_img_train_dir = os.path.join(root_path, 'images/train/')
+tgt_mask_train_dir = os.path.join(root_path, 'masks/train/')
+tgt_img_val_dir = os.path.join(root_path, 'images/val/')
+tgt_img_test_dir = os.path.join(root_path, 'images/test/')
+
+os.system('mkdir -p ' + tgt_img_train_dir)
+os.system('mkdir -p ' + tgt_mask_train_dir)
+os.system('mkdir -p ' + tgt_img_val_dir)
+os.system('mkdir -p ' + tgt_img_test_dir)
+
+
+def extract_pics_from_h5(h5_path, h5_key, save_dir):
+ f = h5py.File(h5_path, 'r')
+ for i, img in enumerate(f[h5_key]):
+ img = img.astype(np.uint8).squeeze()
+ img = Image.fromarray(img)
+ save_image_path = os.path.join(save_dir, str(i).zfill(8) + '.png')
+ img.save(save_image_path)
+
+
+if __name__ == '__main__':
+
+ extract_pics_from_h5(
+ 'data/pcamv1/camelyonpatch_level_2_split_train_x.h5',
+ h5_key='x',
+ save_dir=tgt_img_train_dir)
+
+ extract_pics_from_h5(
+ 'data/pcamv1/camelyonpatch_level_2_split_valid_x.h5',
+ h5_key='x',
+ save_dir=tgt_img_val_dir)
+
+ extract_pics_from_h5(
+ 'data/pcamv1/camelyonpatch_level_2_split_test_x.h5',
+ h5_key='x',
+ save_dir=tgt_img_test_dir)
+
+ extract_pics_from_h5(
+ 'data/pcamv1/camelyonpatch_level_2_split_train_mask.h5',
+ h5_key='mask',
+ save_dir=tgt_mask_train_dir)
diff --git a/projects/medical/2d_image/infrared_reflectance_imaging/ravir/README.md b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/README.md
new file mode 100644
index 00000000000..ca95921ba36
--- /dev/null
+++ b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/README.md
@@ -0,0 +1,167 @@
+# RAVIR: A Dataset and Methodology for the Semantic Segmentation and Quantitative Analysis of Retinal Arteries and Veins in Infrared Reflectance Imaging
+
+## Description
+
+This project supports **`RAVIR: A Dataset and Methodology for the Semantic Segmentation and Quantitative Analysis of Retinal Arteries and Veins in Infrared Reflectance Imaging`**, and the dataset used in this project can be downloaded from [here](https://ravir.grand-challenge.org/).
+
+### Dataset Overview
+
+The retinal vasculature provides important clues in the diagnosis and monitoring of systemic diseases including hypertension and diabetes. The microvascular system is of primary involvement in such conditions, and the retina is the only anatomical site where the microvasculature can be directly observed. The objective assessment of retinal vessels has long been considered a surrogate biomarker for systemic vascular diseases, and with recent advancements in retinal imaging and computer vision technologies, this topic has become the subject of renewed attention. RAVIR is a novel dataset for the semantic segmentation of Retinal Arteries and Veins in Infrared Reflectance (IR) imaging. It enables the creation of deep learning-based models that distinguish extracted vessel type without extensive post-processing.
+
+### Original Statistical Information
+
+| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| ------------------------------------------- | ----------------- | ------------ | ---------------------------- | ------------ | --------------------- | ---------------------- | ------------ | --------------------------------------------------------------- |
+| [RAVIR](https://ravir.grand-challenge.org/) | eye | segmentation | infrared reflectance imaging | 3 | 23/-/19 | yes/-/- | 2022 | [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 23 | 87.22 | - | - | - | - |
+| artery | 23 | 5.45 | - | - | - | - |
+| vein | 23 | 7.33 | - | - | - | - |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
+
+### Visualization
+
+![ravir](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/infrared_reflectance_imaging/ravir/ravir_dataset.png)
+
+## Dataset Citation
+
+```bibtex
+@article{hatamizadeh2022ravir,
+ title={RAVIR: A dataset and methodology for the semantic segmentation and quantitative analysis of retinal arteries and veins in infrared reflectance imaging},
+ author={Hatamizadeh, Ali and Hosseini, Hamid and Patel, Niraj and Choi, Jinseo and Pole, Cameron C and Hoeferlin, Cory M and Schwartz, Steven D and Terzopoulos, Demetri},
+ journal={IEEE Journal of Biomedical and Health Informatics},
+ volume={26},
+ number={7},
+ pages={3272--3283},
+ year={2022},
+ publisher={IEEE}
+}
+```
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow(PIL) v9.3.0
+- scikit-learn(sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `ravir/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset Preparing
+
+- download dataset from [here](https://ravir.grand-challenge.org/) and decompress the data to path `'data/ravir/'`.
+- run script `"python tools/prepare_dataset.py"` to format data and change folder structure as below.
+- run script `"python ../../tools/split_seg_dataset.py --data_root data/ravir"` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. Since the labels of the official validation set and test set are not available, we randomly generate `train.txt` and `val.txt` from the training set.
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── infrared_reflectance_imaging
+ │ │ │ │ ├── ravir
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── val.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+ │ │ │ │ │ │ │ ├── test
+  │   │   │   │   │   │   │   │   ├── yyy.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── yyy.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+```
+
+### Divided Dataset Information
+
+***Note: The table below describes the dataset split generated by ourselves.***
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 18 | 87.41 | 5 | 86.53 | - | - |
+| artery | 18 | 5.44 | 5 | 5.50 | - | - |
+| vein | 18 | 7.15 | 5 | 7.97 | - | - |
+
+### Training commands
+
+To train models on a single server with one GPU. (default)
+
+```shell
+mim train mmseg ./configs/${CONFIG_PATH}
+```
+
+### Testing commands
+
+To test models on a single server with one GPU. (default)
+
+```shell
+mim test mmseg ./configs/${CONFIG_PATH} --checkpoint ${CHECKPOINT_PATH}
+```
+
+
+
+## Results
+
+### RAVIR
+
+| Method | Backbone | Crop Size | lr | config |
+| :-------------: | :------: | :-------: | :----: | :------------------------------------------------------------------------: |
+| fcn_unet_s5-d16 | unet | 512x512 | 0.01 | [config](./configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_ravir-512x512.py) |
+| fcn_unet_s5-d16 | unet | 512x512 | 0.001 | [config](./configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_ravir-512x512.py) |
+| fcn_unet_s5-d16 | unet | 512x512 | 0.0001 | [config](./configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_ravir-512x512.py) |
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+
+ - [x] Basic docstrings & proper citation
+
+ - [x] Test-time correctness
+
+ - [x] A full README
+
+- [x] Milestone 2: Indicates a successful model implementation.
+
+ - [x] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+
+ - [ ] Unit tests
+
+ - [ ] Code polishing
+
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/infrared_reflectance_imaging/ravir/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_ravir-512x512.py b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_ravir-512x512.py
new file mode 100755
index 00000000000..375ad5abf27
--- /dev/null
+++ b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_ravir-512x512.py
@@ -0,0 +1,19 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './ravir_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.ravir_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ pretrained=None,
+ decode_head=dict(num_classes=3),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/infrared_reflectance_imaging/ravir/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_ravir-512x512.py b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_ravir-512x512.py
new file mode 100755
index 00000000000..a7ecf6dd453
--- /dev/null
+++ b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_ravir-512x512.py
@@ -0,0 +1,19 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './ravir_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.ravir_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ type='EncoderDecoder',
+    data_preprocessor=data_preprocessor,
+ pretrained=None,
+ decode_head=dict(num_classes=3),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/infrared_reflectance_imaging/ravir/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_ravir-512x512.py b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_ravir-512x512.py
new file mode 100755
index 00000000000..28556df53d1
--- /dev/null
+++ b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_ravir-512x512.py
@@ -0,0 +1,19 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './ravir_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.ravir_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ pretrained=None,
+ decode_head=dict(num_classes=3),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/infrared_reflectance_imaging/ravir/configs/ravir_512x512.py b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/configs/ravir_512x512.py
new file mode 100755
index 00000000000..cb4c292d1f7
--- /dev/null
+++ b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/configs/ravir_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'RAVIRDataset'
+data_root = 'data/ravir'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/infrared_reflectance_imaging/ravir/datasets/__init__.py b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/datasets/__init__.py
new file mode 100755
index 00000000000..6f1d051bcff
--- /dev/null
+++ b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/datasets/__init__.py
@@ -0,0 +1,3 @@
+from .ravir_dataset import RAVIRDataset
+
+__all__ = ['RAVIRDataset']
diff --git a/projects/medical/2d_image/infrared_reflectance_imaging/ravir/datasets/ravir_dataset.py b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/datasets/ravir_dataset.py
new file mode 100755
index 00000000000..c9e0a8ed21e
--- /dev/null
+++ b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/datasets/ravir_dataset.py
@@ -0,0 +1,28 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class RAVIRDataset(BaseSegDataset):
+ """RAVIRDataset dataset.
+
+ In segmentation map annotation for RAVIRDataset, 0 stands for background,
+ which is included in 3 categories. ``reduce_zero_label`` is fixed to
+ False. The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is
+    fixed to '.png'.
+
+    Args:
+        img_suffix (str): Suffix of images. Default: '.png'
+        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+        reduce_zero_label (bool): Whether to mark label zero as ignored.
+            Default to False.
+    """
+ METAINFO = dict(classes=('background', 'artery', 'vein'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/infrared_reflectance_imaging/ravir/tools/prepare_dataset.py b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/tools/prepare_dataset.py
new file mode 100644
index 00000000000..068dcad8145
--- /dev/null
+++ b/projects/medical/2d_image/infrared_reflectance_imaging/ravir/tools/prepare_dataset.py
@@ -0,0 +1,33 @@
+import glob
+import os
+
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+# map = {255:2, 128:1, 0:0}
+
+os.makedirs('data/ravir/images/train', exist_ok=True)
+os.makedirs('data/ravir/images/test', exist_ok=True)
+os.makedirs('data/ravir/masks/train', exist_ok=True)
+
+os.system(
+ r'cp data/ravir/RAVIR\ Dataset/train/training_images/* data/ravir/images/train' # noqa
+)
+os.system(
+ r'cp data/ravir/RAVIR\ Dataset/train/training_masks/* data/ravir/masks/train' # noqa
+)
+os.system(r'cp data/ravir/RAVIR\ Dataset/test/* data/ravir/images/test')
+
+os.system(r'rm -rf data/ravir/RAVIR\ Dataset')
+
+imgs = glob.glob(os.path.join('data/ravir/masks/train', '*.png'))
+
+for im_path in tqdm(imgs):
+ im = Image.open(im_path)
+ imn = np.array(im)
+    # remap mask gray values to contiguous class indices
+    # (255 -> 2, 128 -> 1; 0 already denotes background)
+    imn[imn == 255] = 2
+    imn[imn == 128] = 1
+ new_im = Image.fromarray(imn)
+ new_im.save(im_path)
diff --git a/projects/medical/2d_image/microscopy_images/2pm_vessel/README.md b/projects/medical/2d_image/microscopy_images/2pm_vessel/README.md
new file mode 100644
index 00000000000..1feb433a319
--- /dev/null
+++ b/projects/medical/2d_image/microscopy_images/2pm_vessel/README.md
@@ -0,0 +1,153 @@
+# 2-PM Vessel Dataset
+
+## Description
+
+This project supports **`2-PM Vessel Dataset`**, which can be downloaded from [here](https://opendatalab.org.cn/2-PM_Vessel_Dataset).
+
+### Dataset Overview
+
+An open-source volumetric brain vasculature dataset obtained with two-photon microscopy at the Focused Ultrasound Lab, Sunnybrook Research Institute (affiliated with the University of Toronto), by Dr. Alison Burgess, Charissa Poon, and Marc Santos.
+
+The dataset contains a total of 12 volumetric stacks consisting of images of mouse brain vasculature and tumor vasculature.
+
+### Information Statistics
+
+| Dataset Name | Anatomical Region | Task Type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| ------------------------------------------------------------ | ----------------- | ------------ | ----------------- | ------------ | --------------------- | ---------------------- | ------------ | ------------------------------------------------------------- |
+| [2pm_vessel](https://opendatalab.org.cn/2-PM_Vessel_Dataset) | vessel | segmentation | microscopy_images | 2 | 216/-/- | yes/-/- | 2021 | [CC0 1.0](https://creativecommons.org/publicdomain/zero/1.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 216 | 85.78 | - | - | - | - |
+| vessel | 180 | 14.22 | - | - | - | - |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
+
+### Visualization
+
+![2pmv](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/histopathology/2pm_vessel/2pm_vessel_dataset.png?raw=true)
+
+### Dataset Citation
+
+```bibtex
+@article{teikari2016deep,
+ title={Deep learning convolutional networks for multiphoton microscopy vasculature segmentation},
+ author={Teikari, Petteri and Santos, Marc and Poon, Charissa and Hynynen, Kullervo},
+ journal={arXiv preprint arXiv:1606.02382},
+ year={2016}
+}
+```
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow(PIL) v9.3.0
+- scikit-learn(sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `2pm_vessel/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset Preparing
+
+- download dataset from [here](https://opendatalab.org.cn/2-PM_Vessel_Dataset) and decompress data to path `'data/'`.
+- run script `"python tools/prepare_dataset.py"` to format data and change folder structure as below.
+- run script `"python ../../tools/split_seg_dataset.py"` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. Since the labels of the official validation set and test set are not available, we randomly generate `train.txt` and `val.txt` from the training set.
+
+```shell
+mkdir data && cd data
+pip install opendatalab
+odl get 2-PM_Vessel_Dataset
+cd ..
+python tools/prepare_dataset.py
+python ../../tools/split_seg_dataset.py
+```
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── microscopy_images
+ │ │ │ │ ├── 2pm_vessel
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── val.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ │ │ │ │ ├── xxx.png
+ │ │ │ │ │ │ │ │ ├── ...
+ │ │ │ │ │ │ │ │ └── xxx.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ │ │ │ │ ├── xxx.png
+ │ │ │ │ │ │ │ │ ├── ...
+ │ │ │ │ │ │ │ │ └── xxx.png
+
+```
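The core of `tools/prepare_dataset.py` is slicing each multi-page `.tif` volume into per-slice PNGs and mapping the `{0, 255}` label values to `{0, 1}`. Below is a minimal sketch of that conversion on a synthetic array (the shapes and file names are illustrative, not the real vesselNN data):

```python
import os
import tempfile

import numpy as np
from PIL import Image

# Synthetic stand-in for one multi-slice .tif volume: 4 slices of 8x8
# intensities, with labels stored as {0, 255} like the original masks.
images = np.random.randint(0, 256, size=(4, 8, 8), dtype=np.uint8)
labels = (images > 127).astype(np.uint8) * 255

out_dir = tempfile.mkdtemp()
for i in range(labels.shape[0]):
    slice_name = 'example_' + str(i).rjust(3, '0') + '.png'
    # map {0, 255} -> {0, 1} so pixel values lie in [0, num_classes - 1]
    Image.fromarray(labels[i] // 255).save(os.path.join(out_dir, slice_name))
    Image.fromarray(images[i]).convert('RGB').save(
        os.path.join(out_dir, 'img_' + slice_name))

mask_files = sorted(f for f in os.listdir(out_dir) if not f.startswith('img_'))
print(mask_files)
```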
+
+### Divided Dataset Information
+
+***Note: the statistics in the table below are based on our own dataset split.***
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 172 | 85.88 | 44 | 85.4 | - | - |
+| vessel | 142 | 14.12 | 38 | 14.6 | - | - |
+
+### Training commands
+
+Train models on a single server with one GPU (default).
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+Test models on a single server with one GPU (default).
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+ - [x] Basic docstrings & proper citation
+ - [ ] Test-time correctness
+ - [x] A full README
+
+- [ ] Milestone 2: Indicates a successful model implementation.
+
+ - [ ] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+ - [ ] Unit tests
+ - [ ] Code polishing
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/microscopy_images/2pm_vessel/configs/2pm-vessel_512x512.py b/projects/medical/2d_image/microscopy_images/2pm_vessel/configs/2pm-vessel_512x512.py
new file mode 100644
index 00000000000..124403fa973
--- /dev/null
+++ b/projects/medical/2d_image/microscopy_images/2pm_vessel/configs/2pm-vessel_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'TwoPMVesselDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/microscopy_images/2pm_vessel/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_2pm-vessel-512x512.py b/projects/medical/2d_image/microscopy_images/2pm_vessel/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_2pm-vessel-512x512.py
new file mode 100644
index 00000000000..2a429e9068e
--- /dev/null
+++ b/projects/medical/2d_image/microscopy_images/2pm_vessel/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_2pm-vessel-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './2pm-vessel_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.2pm-vessel_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/microscopy_images/2pm_vessel/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_2pm-vessel-512x512.py b/projects/medical/2d_image/microscopy_images/2pm_vessel/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_2pm-vessel-512x512.py
new file mode 100644
index 00000000000..10d9bb82f25
--- /dev/null
+++ b/projects/medical/2d_image/microscopy_images/2pm_vessel/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_2pm-vessel-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './2pm-vessel_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.2pm-vessel_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/microscopy_images/2pm_vessel/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_2pm-vessel-512x512.py b/projects/medical/2d_image/microscopy_images/2pm_vessel/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_2pm-vessel-512x512.py
new file mode 100644
index 00000000000..65c1579ec71
--- /dev/null
+++ b/projects/medical/2d_image/microscopy_images/2pm_vessel/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_2pm-vessel-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './2pm-vessel_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.2pm-vessel_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/microscopy_images/2pm_vessel/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_bactteria-detection-512x512.py b/projects/medical/2d_image/microscopy_images/2pm_vessel/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_bactteria-detection-512x512.py
new file mode 100644
index 00000000000..91ed6ada3f2
--- /dev/null
+++ b/projects/medical/2d_image/microscopy_images/2pm_vessel/configs/fcn-unet-s5-d16_unet_1xb16-0.01lr-sigmoid-20k_bactteria-detection-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './2pm-vessel_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.2pm-vessel_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(
+ num_classes=2, loss_decode=dict(use_sigmoid=True), out_channels=1),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/microscopy_images/2pm_vessel/datasets/2pm-vessel_dataset.py b/projects/medical/2d_image/microscopy_images/2pm_vessel/datasets/2pm-vessel_dataset.py
new file mode 100644
index 00000000000..984b5a1361d
--- /dev/null
+++ b/projects/medical/2d_image/microscopy_images/2pm_vessel/datasets/2pm-vessel_dataset.py
@@ -0,0 +1,31 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class TwoPMVesselDataset(BaseSegDataset):
+ """TwoPMVesselDataset dataset.
+
+ In segmentation map annotation for TwoPMVesselDataset,
+ 0 stands for background, which is included in 2 categories.
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default to False.
+ """
+ METAINFO = dict(classes=('background', 'vessel'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/microscopy_images/2pm_vessel/tools/prepare_dataset.py b/projects/medical/2d_image/microscopy_images/2pm_vessel/tools/prepare_dataset.py
new file mode 100644
index 00000000000..1b46af2cad0
--- /dev/null
+++ b/projects/medical/2d_image/microscopy_images/2pm_vessel/tools/prepare_dataset.py
@@ -0,0 +1,46 @@
+import os
+
+import tifffile as tiff
+from PIL import Image
+
+root_path = 'data/'
+
+image_dir = os.path.join(root_path,
+ '2-PM_Vessel_Dataset/raw/vesselNN_dataset/denoised')
+label_dir = os.path.join(root_path,
+ '2-PM_Vessel_Dataset/raw/vesselNN_dataset/labels')
+tgt_img_train_dir = os.path.join(root_path, 'images/train/')
+tgt_mask_train_dir = os.path.join(root_path, 'masks/train/')
+os.makedirs(tgt_img_train_dir, exist_ok=True)
+os.makedirs(tgt_mask_train_dir, exist_ok=True)
+
+
+def filter_suffix(src_dir, suffix):
+ suffix = '.' + suffix if '.' not in suffix else suffix
+ file_names = [_ for _ in os.listdir(src_dir) if _.endswith(suffix)]
+ file_paths = [os.path.join(src_dir, _) for _ in file_names]
+ return sorted(file_paths), sorted(file_names)
+
+
+if __name__ == '__main__':
+
+ image_path_list, _ = filter_suffix(image_dir, suffix='tif')
+ label_path_list, _ = filter_suffix(label_dir, suffix='.tif')
+
+ for img_path, label_path in zip(image_path_list, label_path_list):
+ labels = tiff.imread(label_path)
+ images = tiff.imread(img_path)
+ assert labels.ndim == 3
+ assert images.shape == labels.shape
+ name = img_path.split('/')[-1].replace('.tif', '')
+        # a single .tif file contains multiple slices; tifffile
+        # reads each file as a 3-D (slice, H, W) array.
+ for i in range(labels.shape[0]):
+ slice_name = name + '_' + str(i).rjust(3, '0') + '.png'
+ image = images[i]
+ label = labels[i] // 255
+
+ save_path_label = os.path.join(tgt_mask_train_dir, slice_name)
+ Image.fromarray(label).save(save_path_label)
+ save_path_image = os.path.join(tgt_img_train_dir, slice_name)
+ Image.fromarray(image).convert('RGB').save(save_path_image)
diff --git a/projects/medical/2d_image/microscopy_images/bactteria_detection/README.md b/projects/medical/2d_image/microscopy_images/bactteria_detection/README.md
new file mode 100644
index 00000000000..1cedda715a7
--- /dev/null
+++ b/projects/medical/2d_image/microscopy_images/bactteria_detection/README.md
@@ -0,0 +1,160 @@
+# Bactteria detection with darkfield microscopy
+
+## Description
+
+This project supports **`Bactteria detection with darkfield microscopy`**, which can be downloaded from [here](https://tianchi.aliyun.com/dataset/94411).
+
+### Dataset Overview
+
+Spirochaeta is a genus of bacteria classified within the phylum Spirochaetes. Included in this dataset are 366 darkfield microscopy images and manually annotated masks which can be used for classification and segmentation purposes. Detecting bacteria in blood could have a huge significance for research in both the medical and computer science field.
+
+- It was gathered and annotated by students (hands-on experience).
+- It has more than just one target class (both blood cells and bacteria are annotated).
+- It is highly imbalanced, so naive loss functions work less well.
+
+### Original Statistic Information
+
+| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| --------------------------------------------------------------- | ----------------- | ------------ | ---------- | ------------ | --------------------- | ---------------------- | ------------ | --------------------------------------------------------------- |
+| [Bactteria detection](https://tianchi.aliyun.com/dataset/94411) | bacteria | segmentation | microscopy | 3 | 366/-/- | yes/-/- | 2017 | [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-sa/4.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :----------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 366 | 85.9 | - | - | - | - |
+| erythrocytes | 345 | 13.03 | - | - | - | - |
+| spirochaete | 288 | 1.07 | - | - | - | - |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
+
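For reference, the `Pct` values can be recomputed by counting pixels per class id across all masks. A small sketch with two hypothetical 4x4 annotation maps (the real numbers come from the PNG masks under `masks/`):

```python
import numpy as np

# Two hypothetical 4x4 annotation maps with class ids
# 0 = background, 1 = erythrocytes, 2 = spirochaete.
masks = [
    np.array([[0, 0, 0, 0],
              [0, 1, 1, 0],
              [0, 1, 2, 0],
              [0, 0, 0, 0]]),
    np.zeros((4, 4), dtype=int),
]

counts = np.zeros(3, dtype=np.int64)
for mask in masks:
    counts += np.bincount(mask.ravel(), minlength=3)
pct = 100 * counts / counts.sum()
print(counts.tolist(), [round(p, 2) for p in pct])
```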
+### Visualization
+
+![bac](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/microscopy_images/bactteria_detection/bactteria_detection_dataset.png)
+
+## Usage
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow (PIL) v9.3.0
+- scikit-learn (sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `bactteria_detection/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset Preparing
+
+- Download the dataset from [here](https://tianchi.aliyun.com/dataset/94411) and save it to the `data/` directory.
+- Decompress the data to `data/`. This creates a new folder, `data/Bacteria_detection_with_darkfield_microscopy_datasets/`, which contains the original image data.
+- Run `python tools/prepare_dataset.py` to format the data and arrange the folder structure as below.
+- Run `python ../../tools/split_seg_dataset.py` to split the dataset. Since the Bacteria_detection dataset has no official validation or test set, we randomly sample 20% of the images as the validation set and use the remaining 80% for training, writing the filename lists `train.txt` and `val.txt`. Because the random seed is hard-coded, the split is deterministic and reproducible.
+
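The split step above boils down to one `train_test_split` call with a fixed `random_state`; the sketch below shows, on hypothetical file names, why the resulting 80/20 split is reproducible across runs:

```python
from sklearn.model_selection import train_test_split

# Hypothetical file list standing in for data/images/train/*.png.
all_imgs = ['images/train/img_%03d.png' % i for i in range(10)]

# Same call as in split_seg_dataset.py: the hard-coded random_state=0
# makes the 80/20 split identical on every run.
x_train, x_val = train_test_split(all_imgs, test_size=0.2, random_state=0)
x_train2, x_val2 = train_test_split(all_imgs, test_size=0.2, random_state=0)
print(len(x_train), len(x_val), (x_train, x_val) == (x_train2, x_val2))
```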
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── microscopy_images
+ │ │ │ │ ├── bactteria_detection
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── val.txt
+ │ │ │ │ │ │ ├── Bacteria_detection_with_darkfield_microscopy_datasets
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ │ │ │ │ ├── xxx.png
+ │ │ │ │ │ │ │ │ ├── ...
+ │ │ │ │ │ │ │ │ └── xxx.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ │ │ │ │ ├── xxx.png
+ │ │ │ │ │ │ │ │ ├── ...
+ │ │ │ │ │ │ │ │ └── xxx.png
+```
+
+### Divided Dataset Information
+
+***Note: the statistics in the table below are based on our own dataset split.***
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :----------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 292 | 85.66 | 74 | 86.7 | - | - |
+| erythrocytes | 274 | 13.25 | 71 | 12.29 | - | - |
+| spirochaete | 231 | 1.09 | 57 | 1.01 | - | - |
+
+### Training commands
+
+Train models on a single server with one GPU.
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+Test models on a single server with one GPU.
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+
+
+## Results
+
+### Bactteria detection with darkfield microscopy
+
+***Note: the following experimental results are based on the random data partition described in the Dataset Preparing section above.***
+
+| Method | Backbone | Crop Size | lr | mIoU | mDice | config | download |
+| :-------------: | :------: | :-------: | :----: | :---: | :---: | :--------------------------------------------------------------------------------------: | :----------------------: |
+| fcn_unet_s5-d16 | unet | 512x512 | 0.01 | 76.48 | 84.68 | [config](./configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_bactteria-detection-512x512.py) | [model](<>) \| [log](<>) |
+| fcn_unet_s5-d16 | unet | 512x512 | 0.001 | 61.06 | 63.69 | [config](./configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_bactteria-detection-512x512.py) | [model](<>) \| [log](<>) |
+| fcn_unet_s5-d16 | unet | 512x512 | 0.0001 | 58.87 | 62.42 | [config](./configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_bactteria-detection-512x512.py) | [model](<>) \| [log](<>) |
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+
+ - [x] Basic docstrings & proper citation
+
+ - [x] Test-time correctness
+
+ - [x] A full README
+
+- [x] Milestone 2: Indicates a successful model implementation.
+
+ - [x] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+
+ - [ ] Unit tests
+
+ - [ ] Code polishing
+
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/microscopy_images/bactteria_detection/configs/bactteria-detection_512x512.py b/projects/medical/2d_image/microscopy_images/bactteria_detection/configs/bactteria-detection_512x512.py
new file mode 100644
index 00000000000..e3eab4e3868
--- /dev/null
+++ b/projects/medical/2d_image/microscopy_images/bactteria_detection/configs/bactteria-detection_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'BactteriaDetectionDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/microscopy_images/bactteria_detection/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_bactteria-detection-512x512.py b/projects/medical/2d_image/microscopy_images/bactteria_detection/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_bactteria-detection-512x512.py
new file mode 100644
index 00000000000..ede58d785c1
--- /dev/null
+++ b/projects/medical/2d_image/microscopy_images/bactteria_detection/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_bactteria-detection-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ './bactteria-detection_512x512.py',
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.bactteria-detection_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=3),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/microscopy_images/bactteria_detection/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_bactteria-detection-512x512.py b/projects/medical/2d_image/microscopy_images/bactteria_detection/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_bactteria-detection-512x512.py
new file mode 100644
index 00000000000..bde3fa14acd
--- /dev/null
+++ b/projects/medical/2d_image/microscopy_images/bactteria_detection/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_bactteria-detection-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ './bactteria-detection_512x512.py',
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.bactteria-detection_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=3),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/microscopy_images/bactteria_detection/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_bactteria-detection-512x512.py b/projects/medical/2d_image/microscopy_images/bactteria_detection/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_bactteria-detection-512x512.py
new file mode 100644
index 00000000000..08e204f3809
--- /dev/null
+++ b/projects/medical/2d_image/microscopy_images/bactteria_detection/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_bactteria-detection-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ './bactteria-detection_512x512.py',
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.bactteria-detection_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=3),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/microscopy_images/bactteria_detection/datasets/bactteria-detection_dataset.py b/projects/medical/2d_image/microscopy_images/bactteria_detection/datasets/bactteria-detection_dataset.py
new file mode 100644
index 00000000000..c95097b1acd
--- /dev/null
+++ b/projects/medical/2d_image/microscopy_images/bactteria_detection/datasets/bactteria-detection_dataset.py
@@ -0,0 +1,27 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class BactteriaDetectionDataset(BaseSegDataset):
+ """BactteriaDetectionDataset dataset.
+
+ In segmentation map annotation for BactteriaDetectionDataset,
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ """
+ METAINFO = dict(classes=('background', 'erythrocytes', 'spirochaete'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=False,
+ **kwargs)
diff --git a/projects/medical/2d_image/microscopy_images/bactteria_detection/tools/prepare_dataset.py b/projects/medical/2d_image/microscopy_images/bactteria_detection/tools/prepare_dataset.py
new file mode 100755
index 00000000000..8dcc719e262
--- /dev/null
+++ b/projects/medical/2d_image/microscopy_images/bactteria_detection/tools/prepare_dataset.py
@@ -0,0 +1,33 @@
+import glob
+import os
+import shutil
+
+from PIL import Image
+
+root_path = 'data/'
+img_suffix = '.png'
+seg_map_suffix = '.png'
+save_img_suffix = '.png'
+save_seg_map_suffix = '.png'
+
+x_train = glob.glob(
+ 'data/Bacteria_detection_with_darkfield_microscopy_datasets/images/*' +
+ img_suffix) # noqa
+
+os.makedirs(os.path.join(root_path, 'images/train/'), exist_ok=True)
+os.makedirs(os.path.join(root_path, 'masks/train/'), exist_ok=True)
+
+part_dir_dict = {0: 'train/'}
+for ith, part in enumerate([x_train]):
+ part_dir = part_dir_dict[ith]
+ for img in part:
+ basename = os.path.basename(img)
+ img_save_path = os.path.join(root_path, 'images', part_dir,
+ basename.split('.')[0] + save_img_suffix)
+ shutil.copy(img, img_save_path)
+ mask_path = 'data/Bacteria_detection_with_darkfield_microscopy_datasets/masks/' + basename # noqa
+ mask = Image.open(mask_path).convert('L')
+ mask_save_path = os.path.join(
+ root_path, 'masks', part_dir,
+ basename.split('.')[0] + save_seg_map_suffix)
+ mask.save(mask_save_path)
diff --git a/projects/medical/2d_image/tools/split_seg_dataset.py b/projects/medical/2d_image/tools/split_seg_dataset.py
new file mode 100644
index 00000000000..9ab2e9282fa
--- /dev/null
+++ b/projects/medical/2d_image/tools/split_seg_dataset.py
@@ -0,0 +1,42 @@
+import argparse
+import glob
+import os
+
+from sklearn.model_selection import train_test_split
+
+
+def save_anno(img_list, file_path, remove_suffix=True):
+ if remove_suffix:
+ img_list = [
+ '/'.join(img_path.split('/')[-2:]) for img_path in img_list
+ ]
+ img_list = [
+ '.'.join(img_path.split('.')[:-1]) for img_path in img_list
+ ]
+ with open(file_path, 'w') as file_:
+ for x in list(img_list):
+ file_.write(x + '\n')
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data_root', default='data/')
+ args = parser.parse_args()
+ data_root = args.data_root
+ if os.path.exists(os.path.join(data_root, 'masks/val')):
+ x_val = sorted(glob.glob(data_root + '/images/val/*.png'))
+ save_anno(x_val, data_root + '/val.txt')
+ if os.path.exists(os.path.join(data_root, 'masks/test')):
+ x_test = sorted(glob.glob(data_root + '/images/test/*.png'))
+ save_anno(x_test, data_root + '/test.txt')
+ if not os.path.exists(os.path.join(
+ data_root, 'masks/val')) and not os.path.exists(
+ os.path.join(data_root, 'masks/test')):
+ all_imgs = sorted(glob.glob(data_root + '/images/train/*.png'))
+ x_train, x_val = train_test_split(
+ all_imgs, test_size=0.2, random_state=0)
+ save_anno(x_train, data_root + '/train.txt')
+ save_anno(x_val, data_root + '/val.txt')
+ else:
+ x_train = sorted(glob.glob(data_root + '/images/train/*.png'))
+ save_anno(x_train, data_root + '/train.txt')
diff --git a/projects/medical/2d_image/x_ray/chest_image_pneum/README.md b/projects/medical/2d_image/x_ray/chest_image_pneum/README.md
new file mode 100644
index 00000000000..a1cd27ba455
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/chest_image_pneum/README.md
@@ -0,0 +1,147 @@
+# Chest Image Dataset for Pneumothorax Segmentation
+
+## Description
+
+This project supports **`Chest Image Dataset for Pneumothorax Segmentation`**, which can be downloaded from [here](https://tianchi.aliyun.com/dataset/83075).
+
+### Dataset Overview
+
+Pneumothorax can be caused by a blunt chest injury, damage from underlying lung disease, or, most alarmingly, it may occur for no obvious reason at all. On some occasions, a collapsed lung can be a life-threatening event.
+Pneumothorax is usually diagnosed by a radiologist on a chest x-ray, and can sometimes be very difficult to confirm. An accurate AI algorithm to detect pneumothorax would be useful in many clinical scenarios: it could be used to triage chest radiographs for priority interpretation, or to provide a more confident diagnosis for non-radiologists.
+
+The dataset is provided by the Society for Imaging Informatics in Medicine (SIIM), the American College of Radiology (ACR), the Society of Thoracic Radiology (STR) and MD.ai. You can develop a model to classify (and, if present, segment) pneumothorax from a set of chest radiographic images. If successful, you could aid in the early recognition of pneumothoraces and save lives.
+
+### Original Statistic Information
+
+| Dataset name | Anatomical region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| --------------------------------------------------------------------- | ----------------- | ------------ | -------- | ------------ | --------------------- | ---------------------- | ------------ | ------------------------------------------------------------------ |
+| [pneumothorax segmentation](https://tianchi.aliyun.com/dataset/83075) | thorax | segmentation | x_ray | 2 | 12089/-/3205 | yes/-/no | - | [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :---------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| normal | 12089 | 99.75 | - | - | - | - |
+| pneumothorax area | 2669 | 0.25 | - | - | - | - |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
+
+### Visualization
+
+![bac](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/x_ray/chest_image_pneum/chest_image_pneum_dataset.png)
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `chest_image_pneum/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset Preparing
+
+- Download the dataset from [here](https://tianchi.aliyun.com/dataset/83075) and decompress it to `data/`.
+- Run `python tools/prepare_dataset.py` to format the data and arrange the folder structure as below.
+- Run `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. If the labels of the official validation and test sets are unavailable, `train.txt` and `val.txt` are generated randomly from the training set.
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── x_ray
+ │ │ │ │ ├── chest_image_pneum
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── test.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ │ │ │ │ ├── xxx.png
+ │ │ │ │ │ │ │ │ ├── ...
+ │ │ │ │ │ │ │ │ └── xxx.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ │ │ │ │ ├── xxx.png
+ │ │ │ │ │ │ │ │ ├── ...
+ │ │ │ │ │ │ │ │ └── xxx.png
+```
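Conceptually, the split step above boils down to shuffling the image ids and carving off a validation subset. The sketch below is a minimal illustration of that idea, not the actual `../../tools/split_seg_dataset.py`; the ratio, seed, and function name are assumptions.

```python
import random


def split_ids(ids, val_ratio=0.2, seed=42):
    """Shuffle image ids and split them into train/val lists (illustrative)."""
    rng = random.Random(seed)
    ids = sorted(ids)  # fix a deterministic base order before shuffling
    rng.shuffle(ids)
    n_val = int(len(ids) * val_ratio)
    # train.txt / val.txt each hold one file stem per line
    return ids[n_val:], ids[:n_val]


train_ids, val_ids = split_ids([f'img_{i:03d}' for i in range(10)])
```

Writing each list to `train.txt` / `val.txt`, one file stem per line, then matches the layout the configs expect via `ann_file`.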
+
+### Divided Dataset Information
+
+***Note: The split shown in the table below was made by ourselves.***
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :---------------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| normal | 9637 | 99.75 | 2410 | 99.74 | - | - |
+| pneumothorax area | 2137 | 0.25 | 532 | 0.26 | - | - |
+
+### Training commands
+
+Train models on a single server with one GPU.
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+Test models on a single server with one GPU.
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+## Results
+
+### Chest image pneum
+
+| Method | Backbone | Crop Size | lr | mIoU | mDice | config | download |
+| :-------------: | :------: | :-------: | :----: | :--: | :---: | :------------------------------------------------------------------------------------: | :----------------------: |
+| fcn_unet_s5-d16 | unet | 512x512 | 0.01 | - | - | [config](./configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_chest-image-pneum-512x512.py) | [model](<>) \| [log](<>) |
+| fcn_unet_s5-d16 | unet | 512x512 | 0.001 | - | - | [config](./configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_chest-image-pneum-512x512.py) | [model](<>) \| [log](<>) |
+| fcn_unet_s5-d16 | unet | 512x512 | 0.0001 | - | - | [config](./configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_chest-image-pneum-512x512.py) | [model](<>) \| [log](<>) |
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+
+ - [x] Basic docstrings & proper citation
+
+ - [x] Test-time correctness
+
+ - [x] A full README
+
+- [x] Milestone 2: Indicates a successful model implementation.
+
+ - [x] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+
+ - [ ] Unit tests
+
+ - [ ] Code polishing
+
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/x_ray/chest_image_pneum/configs/chest-image-pneum_512x512.py b/projects/medical/2d_image/x_ray/chest_image_pneum/configs/chest-image-pneum_512x512.py
new file mode 100644
index 00000000000..411229bd413
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/chest_image_pneum/configs/chest-image-pneum_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'ChestImagePneumDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/x_ray/chest_image_pneum/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_chest-image-pneum-512x512.py b/projects/medical/2d_image/x_ray/chest_image_pneum/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_chest-image-pneum-512x512.py
new file mode 100644
index 00000000000..0f26459467c
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/chest_image_pneum/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_chest-image-pneum-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ './chest-image-pneum_512x512.py',
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.chest-image-pneum_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/x_ray/chest_image_pneum/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_chest-image-pneum-512x512.py b/projects/medical/2d_image/x_ray/chest_image_pneum/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_chest-image-pneum-512x512.py
new file mode 100644
index 00000000000..37b91889d8a
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/chest_image_pneum/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_chest-image-pneum-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ './chest-image-pneum_512x512.py',
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.chest-image-pneum_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/x_ray/chest_image_pneum/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_chest-image-pneum-512x512.py b/projects/medical/2d_image/x_ray/chest_image_pneum/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_chest-image-pneum-512x512.py
new file mode 100644
index 00000000000..379e8181f3d
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/chest_image_pneum/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_chest-image-pneum-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ './chest-image-pneum_512x512.py',
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.chest-image-pneum_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/x_ray/chest_image_pneum/datasets/chest-image-pneum_dataset.py b/projects/medical/2d_image/x_ray/chest_image_pneum/datasets/chest-image-pneum_dataset.py
new file mode 100644
index 00000000000..aeee60ae92e
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/chest_image_pneum/datasets/chest-image-pneum_dataset.py
@@ -0,0 +1,27 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class ChestImagePneumDataset(BaseSegDataset):
+ """ChestImagePneumDataset dataset.
+
+ In segmentation map annotation for ChestImagePneumDataset,
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ """
+ METAINFO = dict(classes=('normal', 'pneumothorax area'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=False,
+ **kwargs)
diff --git a/projects/medical/2d_image/x_ray/chest_image_pneum/tools/prepare_dataset.py b/projects/medical/2d_image/x_ray/chest_image_pneum/tools/prepare_dataset.py
new file mode 100755
index 00000000000..47eddc96dc5
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/chest_image_pneum/tools/prepare_dataset.py
@@ -0,0 +1,73 @@
+import os
+
+import numpy as np
+import pandas as pd
+import pydicom
+from PIL import Image
+
+root_path = 'data/'
+img_suffix = '.dcm'
+seg_map_suffix = '.png'
+save_img_suffix = '.png'
+save_seg_map_suffix = '.png'
+
+x_train = []
+for fpath, dirname, fnames in os.walk('data/chestimage_train_datasets'):
+ for fname in fnames:
+ if fname.endswith('.dcm'):
+ x_train.append(os.path.join(fpath, fname))
+x_test = []
+for fpath, dirname, fnames in os.walk('data/chestimage_test_datasets/'):
+ for fname in fnames:
+ if fname.endswith('.dcm'):
+ x_test.append(os.path.join(fpath, fname))
+
+os.system('mkdir -p ' + root_path + 'images/train/')
+os.system('mkdir -p ' + root_path + 'images/test/')
+os.system('mkdir -p ' + root_path + 'masks/train/')
+
+
+def rle_decode(rle, width, height):
+    """Decode a run-length string into a binary mask.
+
+    ``rle`` alternates relative start offsets (counted from the end of the
+    previous run) and run lengths over the image flattened in column-major
+    (Fortran) order, e.g. '0 2 4 3' marks flat pixels 0-1 and 6-8.
+    """
+    mask = np.zeros(width * height, dtype=np.uint8)
+ array = np.asarray([int(x) for x in rle.split()])
+ starts = array[0::2]
+ lengths = array[1::2]
+
+ current_position = 0
+ for index, start in enumerate(starts):
+ current_position += start
+ mask[current_position:current_position + lengths[index]] = 1
+ current_position += lengths[index]
+
+ return mask.reshape(width, height, order='F')
+
+
+part_dir_dict = {0: 'train/', 1: 'test/'}
+dict_from_csv = pd.read_csv(
+ root_path + 'chestimage_train-rle_datasets.csv', sep=',',
+ index_col=0).to_dict()[' EncodedPixels']
+
+for ith, part in enumerate([x_train, x_test]):
+ part_dir = part_dir_dict[ith]
+ for img in part:
+ basename = os.path.basename(img)
+ img_id = '.'.join(basename.split('.')[:-1])
+ if ith == 0 and (img_id not in dict_from_csv.keys()):
+ continue
+        # pydicom.dcmread replaces the deprecated pydicom.read_file
+        image = pydicom.dcmread(img).pixel_array
+ save_img_path = root_path + 'images/' + part_dir + '.'.join(
+ basename.split('.')[:-1]) + save_img_suffix
+ print(save_img_path)
+ img_h, img_w = image.shape[:2]
+ image = Image.fromarray(image)
+ image.save(save_img_path)
+ if ith == 1:
+ continue
+ if dict_from_csv[img_id] == '-1':
+ mask = np.zeros((img_h, img_w), dtype=np.uint8)
+ else:
+ mask = rle_decode(dict_from_csv[img_id], img_h, img_w)
+ save_mask_path = root_path + 'masks/' + part_dir + '.'.join(
+ basename.split('.')[:-1]) + save_seg_map_suffix
+ mask = Image.fromarray(mask)
+ mask.save(save_mask_path)
diff --git a/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/README.md b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/README.md
new file mode 100644
index 00000000000..7cb099c8a43
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/README.md
@@ -0,0 +1,119 @@
+# Chest X-ray Images with Pneumothorax Masks
+
+## Description
+
+This project supports **`Chest X-ray Images with Pneumothorax Masks`**, and the dataset used in this project can be downloaded from [here](https://www.kaggle.com/datasets/vbookshelf/pneumothorax-chest-xray-images-and-masks).
+
+### Dataset Overview
+
+A pneumothorax (noo-moe-THOR-aks) is a collapsed lung. A pneumothorax occurs when air leaks into the space between your lung and chest wall. This air pushes on the outside of your lung and makes it collapse. Pneumothorax can be a complete lung collapse or a collapse of only a portion of the lung.
+
+A pneumothorax can be caused by a blunt or penetrating chest injury, certain medical procedures, or damage from underlying lung disease. Or it may occur for no obvious reason. Symptoms usually include sudden chest pain and shortness of breath. On some occasions, a collapsed lung can be a life-threatening event.
+
+Treatment for a pneumothorax usually involves inserting a needle or chest tube between the ribs to remove the excess air. However, a small pneumothorax may heal on its own.
+
+### Statistic Information
+
+| Dataset Name | Anatomical Region | Task type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release date | License |
+| --------------------------------------------------------------------------------------------------------------------------------- | ----------------- | ------------ | -------- | ------------ | --------------------- | ---------------------- | ------------ | --------------------------------------------------------------- |
+| [Chest-x-ray-images-with-pneumothorax-masks](https://www.kaggle.com/datasets/vbookshelf/pneumothorax-chest-xray-images-and-masks) | thorax | segmentation | x_ray | 2 | 10675/-/1372 | yes/-/yes | 2020 | [CC-BY-NC 4.0](https://creativecommons.org/licenses/by-sa/4.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :----------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 10675 | 99.7 | - | - | 1372 | 99.71 |
+| pneumothorax | 2379 | 0.3 | - | - | 290 | 0.29 |
+
+### Visualization
+
+![chest_x_ray_images_with_pneumothorax_masks](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/x_ray/chest_x_ray_images_with_pneumothorax_masks/chest_x_ray_images_with_pneumothorax_masks_dataset.png?raw=true)
+
+### Prerequisites
+
+- Python 3.8
+- PyTorch 1.10.0
+- pillow (PIL) 9.3.0
+- scikit-learn (sklearn) 1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `chest_x_ray_images_with_pneumothorax_masks/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset preparing
+
+- download dataset from [here](https://www.kaggle.com/datasets/vbookshelf/pneumothorax-chest-xray-images-and-masks) and decompress data to path `'data/'`.
+- run script `python tools/prepare_dataset.py` to format the data and change the folder structure as below.
+- run script `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. If the labels of the official validation and test sets cannot be obtained, we randomly generate `train.txt` and `val.txt` from the training set.
+
+```none
+ mmsegmentation
+ ├── mmseg
+ ├── projects
+ │ ├── medical
+ │ │ ├── 2d_image
+ │ │ │ ├── x_ray
+ │ │ │ │ ├── chest_x_ray_images_with_pneumothorax_masks
+ │ │ │ │ │ ├── configs
+ │ │ │ │ │ ├── datasets
+ │ │ │ │ │ ├── tools
+ │ │ │ │ │ ├── data
+ │ │ │ │ │ │ ├── train.txt
+ │ │ │ │ │ │ ├── val.txt
+ │ │ │ │ │ │ ├── images
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ | │ │ │ ├── xxx.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── xxx.png
+ │ │ │ │ │ │ ├── masks
+ │ │ │ │ │ │ │ ├── train
+ │ │ │ │ | │ │ │ ├── xxx.png
+ │ │ │ │ | │ │ │ ├── ...
+ │ │ │ │ | │ │ │ └── xxx.png
+```
+
+### Training commands
+
+```shell
+mim train mmseg ./configs/${CONFIG_PATH}
+```
+
+To train on multiple GPUs, e.g. 8 GPUs, run the following command:
+
+```shell
+mim train mmseg ./configs/${CONFIG_PATH} --launcher pytorch --gpus 8
+```
+
+### Testing commands
+
+```shell
+mim test mmseg ./configs/${CONFIG_PATH} --checkpoint ${CHECKPOINT_PATH}
+```
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+ - [x] Basic docstrings & proper citation
+ - [x] Test-time correctness
+ - [x] A full README
+
+- [x] Milestone 2: Indicates a successful model implementation.
+
+ - [x] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+ - [ ] Unit tests
+ - [ ] Code polishing
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/configs/chest-x-ray-images-with-pneumothorax-masks_512x512.py b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/configs/chest-x-ray-images-with-pneumothorax-masks_512x512.py
new file mode 100644
index 00000000000..96676de8616
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/configs/chest-x-ray-images-with-pneumothorax-masks_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'ChestPenumoMaskDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_chest-x-ray-images-with-pneumothorax-masks-512x512.py b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_chest-x-ray-images-with-pneumothorax-masks-512x512.py
new file mode 100644
index 00000000000..76c214d04cb
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_chest-x-ray-images-with-pneumothorax-masks-512x512.py
@@ -0,0 +1,20 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ './chest-x-ray-images-with-pneumothorax-masks_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(
+ imports='datasets.chest-x-ray-images-with-pneumothorax-masks_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(
+ num_classes=2, loss_decode=dict(use_sigmoid=True), out_channels=1),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_chest-x-ray-images-with-pneumothorax-masks-512x512.py b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_chest-x-ray-images-with-pneumothorax-masks-512x512.py
new file mode 100644
index 00000000000..066996dda94
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_chest-x-ray-images-with-pneumothorax-masks-512x512.py
@@ -0,0 +1,19 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ './chest-x-ray-images-with-pneumothorax-masks_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(
+ imports='datasets.chest-x-ray-images-with-pneumothorax-masks_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_chest-x-ray-images-with-pneumothorax-masks-512x512.py b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_chest-x-ray-images-with-pneumothorax-masks-512x512.py
new file mode 100644
index 00000000000..a7065b82315
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_chest-x-ray-images-with-pneumothorax-masks-512x512.py
@@ -0,0 +1,19 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ './chest-x-ray-images-with-pneumothorax-masks_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(
+ imports='datasets.chest-x-ray-images-with-pneumothorax-masks_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_chest-x-ray-images-with-pneumothorax-masks-512x512.py b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_chest-x-ray-images-with-pneumothorax-masks-512x512.py
new file mode 100644
index 00000000000..e5682ee76b7
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_chest-x-ray-images-with-pneumothorax-masks-512x512.py
@@ -0,0 +1,19 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py',
+ './chest-x-ray-images-with-pneumothorax-masks_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(
+ imports='datasets.chest-x-ray-images-with-pneumothorax-masks_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/datasets/chest-x-ray-images-with-pneumothorax-masks_dataset.py b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/datasets/chest-x-ray-images-with-pneumothorax-masks_dataset.py
new file mode 100644
index 00000000000..d32f597a5a0
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/datasets/chest-x-ray-images-with-pneumothorax-masks_dataset.py
@@ -0,0 +1,31 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class ChestPenumoMaskDataset(BaseSegDataset):
+ """ChestPenumoMaskDataset dataset.
+
+ In segmentation map annotation for ChestPenumoMaskDataset,
+ 0 stands for background, which is included in 2 categories.
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default to False.
+ """
+    METAINFO = dict(classes=('background', 'pneumothorax'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/tools/prepare_dataset.py b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/tools/prepare_dataset.py
new file mode 100644
index 00000000000..c7de1f19042
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/chest_x_ray_images_with_pneumothorax_masks/tools/prepare_dataset.py
@@ -0,0 +1,36 @@
+import glob
+import os
+import shutil
+
+from PIL import Image
+from sklearn.model_selection import train_test_split
+
+root_path = 'data/'
+img_suffix = '.png'
+seg_map_suffix = '.png'
+save_img_suffix = '.png'
+save_seg_map_suffix = '.png'
+
+all_imgs = glob.glob('data/siim-acr-pneumothorax/png_images/*' + img_suffix)
+x_train, x_test = train_test_split(all_imgs, test_size=0.2, random_state=0)
+
+print(len(x_train), len(x_test))
+os.system('mkdir -p ' + root_path + 'images/train/')
+os.system('mkdir -p ' + root_path + 'images/val/')
+os.system('mkdir -p ' + root_path + 'masks/train/')
+os.system('mkdir -p ' + root_path + 'masks/val/')
+
+part_dir_dict = {0: 'train/', 1: 'val/'}
+for ith, part in enumerate([x_train, x_test]):
+ part_dir = part_dir_dict[ith]
+ for img in part:
+ basename = os.path.basename(img)
+ img_save_path = os.path.join(root_path, 'images', part_dir,
+ basename.split('.')[0] + save_img_suffix)
+ shutil.copy(img, img_save_path)
+ mask_path = 'data/siim-acr-pneumothorax/png_masks/' + basename
+ mask = Image.open(mask_path).convert('L')
+ mask_save_path = os.path.join(
+ root_path, 'masks', part_dir,
+ basename.split('.')[0] + save_seg_map_suffix)
+ mask.save(mask_save_path)
diff --git a/projects/medical/2d_image/x_ray/covid_19_ct_cxr/README.md b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/README.md
new file mode 100644
index 00000000000..8469219effa
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/README.md
@@ -0,0 +1,158 @@
+# Covid-19 CT Chest X-ray Dataset
+
+## Description
+
+This project supports **`Covid-19 CT Chest X-ray Dataset`**, which can be downloaded from [here](https://github.com/ieee8023/covid-chestxray-dataset).
+
+### Dataset Overview
+
+In the context of a COVID-19 pandemic, we want to improve prognostic predictions to triage and manage patient care. Data is the first step to developing any diagnostic/prognostic tool. While there exist large public datasets of more typical chest X-rays from the NIH \[Wang 2017\], Spain \[Bustos 2019\], Stanford \[Irvin 2019\], MIT \[Johnson 2019\] and Indiana University \[Demner-Fushman 2016\], there is no collection of COVID-19 chest X-rays or CT scans designed to be used for computational analysis.
+
+The 2019 novel coronavirus (COVID-19) presents several unique features [Fang, 2020](https://pubs.rsna.org/doi/10.1148/radiol.2020200432) and [Ai 2020](https://pubs.rsna.org/doi/10.1148/radiol.2020200642). While the diagnosis is confirmed using polymerase chain reaction (PCR), infected patients with pneumonia may present on chest X-ray and computed tomography (CT) images with a pattern that is only moderately characteristic for the human eye [Ng, 2020](https://pubs.rsna.org/doi/10.1148/ryct.2020200034). In late January, a Chinese team published a paper detailing the clinical and paraclinical features of COVID-19. They reported that patients present abnormalities in chest CT images with most having bilateral involvement \[Huang 2020\]. Bilateral multiple lobular and subsegmental areas of consolidation constitute the typical findings in chest CT images of intensive care unit (ICU) patients on admission \[Huang 2020\]. In comparison, non-ICU patients show bilateral ground-glass opacity and subsegmental areas of consolidation in their chest CT images \[Huang 2020\]. In these patients, later chest CT images display bilateral ground-glass opacity with resolved consolidation \[Huang 2020\].
+
+### Statistic Information
+
+| Dataset Name | Anatomical Region | Task Type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release date | License |
+| ---------------------------------------------------------------------- | ----------------- | ------------ | -------- | ------------- | --------------------- | ---------------------- | ------------ | --------------------------------------------------------------------- |
+| [Covid-19-ct-cxr](https://github.com/ieee8023/covid-chestxray-dataset) | thorax | segmentation | x_ray | 2 | 205/-/714 | yes/-/no | 2021 | [CC-BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 205 | 72.84 | - | - | - | - |
+| lung | 205 | 27.16 | - | - | - | - |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
+
+### Visualization
+
+![cov19ctcxr](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/x_ray/covid_19_ct_cxr/covid_19_ct_cxr_dataset.png?raw=true)
+
+### Dataset Citation
+
+```
+@article{cohen2020covidProspective,
+ title={{COVID-19} Image Data Collection: Prospective Predictions Are the Future},
+ author={Joseph Paul Cohen and Paul Morrison and Lan Dao and Karsten Roth and Tim Q Duong and Marzyeh Ghassemi},
+ journal={arXiv 2006.11988},
+ year={2020}
+}
+
+@article{cohen2020covid,
+ title={COVID-19 image data collection},
+ author={Joseph Paul Cohen and Paul Morrison and Lan Dao},
+ journal={arXiv 2003.11597},
+ year={2020}
+}
+```
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow (PIL) v9.3.0
+- scikit-learn (sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `covid_19_ct_cxr/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset Preparing
+
+- download dataset from [here](https://github.com/ieee8023/covid-chestxray-dataset) and decompress data to path `'data/'`.
+- run script `python tools/prepare_dataset.py` to format the data and change the folder structure as below.
+- run script `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. If the labels of the official validation and test sets cannot be obtained, we randomly generate `train.txt` and `val.txt` from the training set.
+
+```shell
+mkdir data && cd data
+git clone git@github.com:ieee8023/covid-chestxray-dataset.git
+cd ..
+python tools/prepare_dataset.py
+python ../../tools/split_seg_dataset.py
+```
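
The behavior of the split step can be sketched as follows. This is a simplified illustration of a random 4:1 train/val split; the actual `../../tools/split_seg_dataset.py` may differ in ratio and details, and `split_dataset` is a hypothetical helper:

```python
import os
import random


def split_dataset(img_dir, out_dir, val_ratio=0.2, seed=0):
    """Randomly split image names into train/val lists and write them out.

    Hypothetical sketch of a split script; the real tool may differ.
    """
    names = sorted(os.listdir(img_dir))
    random.Random(seed).shuffle(names)  # deterministic shuffle
    n_val = int(len(names) * val_ratio)
    val, train = names[:n_val], names[n_val:]
    with open(os.path.join(out_dir, 'train.txt'), 'w') as f:
        f.write('\n'.join(train))
    with open(os.path.join(out_dir, 'val.txt'), 'w') as f:
        f.write('\n'.join(val))
    return train, val
```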
+
+```none
+  mmsegmentation
+  ├── mmseg
+  ├── projects
+  │   ├── medical
+  │   │   ├── 2d_image
+  │   │   │   ├── x_ray
+  │   │   │   │   ├── covid_19_ct_cxr
+  │   │   │   │   │   ├── configs
+  │   │   │   │   │   ├── datasets
+  │   │   │   │   │   ├── tools
+  │   │   │   │   │   ├── data
+  │   │   │   │   │   │   ├── train.txt
+  │   │   │   │   │   │   ├── val.txt
+  │   │   │   │   │   │   ├── images
+  │   │   │   │   │   │   │   ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+  │   │   │   │   │   │   ├── masks
+  │   │   │   │   │   │   │   ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+```
+
+### Divided Dataset Information
+
+***Note: The table below describes the train/val split made by ourselves.***
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 164 | 72.88 | 41 | 72.69 | - | - |
+| lung | 164 | 27.12 | 41 | 27.31 | - | - |
+
+### Training commands
+
+To train models on a single server with one GPU (default):
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+To test models on a single server with one GPU (default):
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+ - [x] Basic docstrings & proper citation
+ - [x] Test-time correctness
+ - [x] A full README
+
+- [x] Milestone 2: Indicates a successful model implementation.
+
+ - [x] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+ - [ ] Unit tests
+ - [ ] Code polishing
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/x_ray/covid_19_ct_cxr/configs/covid-19-ct-cxr_512x512.py b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/configs/covid-19-ct-cxr_512x512.py
new file mode 100644
index 00000000000..5242d06c374
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/configs/covid-19-ct-cxr_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'Covid19CXRDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/x_ray/covid_19_ct_cxr/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_covid-19-ct-cxr-512x512.py b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_covid-19-ct-cxr-512x512.py
new file mode 100644
index 00000000000..59a7bedaa06
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/configs/fcn-unet-s5-d16_unet-{use-sigmoid}_1xb16-0.01-20k_covid-19-ct-cxr-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './covid-19-ct-cxr_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.covid-19-ct-cxr_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(
+ num_classes=2, loss_decode=dict(use_sigmoid=True), out_channels=1),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/x_ray/covid_19_ct_cxr/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_covid-19-ct-cxr-512x512.py b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_covid-19-ct-cxr-512x512.py
new file mode 100644
index 00000000000..83b8527d46d
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_covid-19-ct-cxr-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './covid-19-ct-cxr_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.covid-19-ct-cxr_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/x_ray/covid_19_ct_cxr/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_covid-19-ct-cxr-512x512.py b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_covid-19-ct-cxr-512x512.py
new file mode 100644
index 00000000000..10cfcbda6ec
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_covid-19-ct-cxr-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './covid-19-ct-cxr_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.covid-19-ct-cxr_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/x_ray/covid_19_ct_cxr/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_covid-19-ct-cxr-512x512.py b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_covid-19-ct-cxr-512x512.py
new file mode 100644
index 00000000000..aaccc8fd8dd
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_covid-19-ct-cxr-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './covid-19-ct-cxr_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.covid-19-ct-cxr_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/x_ray/covid_19_ct_cxr/datasets/covid-19-ct-cxr_dataset.py b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/datasets/covid-19-ct-cxr_dataset.py
new file mode 100644
index 00000000000..68a1bb331f7
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/datasets/covid-19-ct-cxr_dataset.py
@@ -0,0 +1,31 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class Covid19CXRDataset(BaseSegDataset):
+ """Covid19CXRDataset dataset.
+
+ In segmentation map annotation for Covid19CXRDataset,
+ 0 stands for background, which is included in 2 categories.
+ ``reduce_zero_label`` is fixed to False. The ``img_suffix``
+ is fixed to '.png' and ``seg_map_suffix`` is fixed to '.png'.
+
+ Args:
+ img_suffix (str): Suffix of images. Default: '.png'
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default to False.
+ """
+ METAINFO = dict(classes=('background', 'lung'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/x_ray/covid_19_ct_cxr/tools/prepare_dataset.py b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/tools/prepare_dataset.py
new file mode 100644
index 00000000000..72f6435389d
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/covid_19_ct_cxr/tools/prepare_dataset.py
@@ -0,0 +1,52 @@
+import os
+
+import numpy as np
+from PIL import Image
+
+root_path = 'data/'
+src_img_dir = os.path.join(root_path, 'covid-chestxray-dataset', 'images')
+src_mask_dir = os.path.join(root_path, 'covid-chestxray-dataset',
+ 'annotations/lungVAE-masks')
+tgt_img_train_dir = os.path.join(root_path, 'images/train/')
+tgt_mask_train_dir = os.path.join(root_path, 'masks/train/')
+tgt_img_test_dir = os.path.join(root_path, 'images/test/')
+os.makedirs(tgt_img_train_dir, exist_ok=True)
+os.makedirs(tgt_mask_train_dir, exist_ok=True)
+os.makedirs(tgt_img_test_dir, exist_ok=True)
+
+
+def convert_label(img, convert_dict):
+ arr = np.zeros_like(img, dtype=np.uint8)
+ for c, i in convert_dict.items():
+ arr[img == c] = i
+ return arr
+
+
+if __name__ == '__main__':
+
+ all_img_names = os.listdir(src_img_dir)
+ all_mask_names = os.listdir(src_mask_dir)
+
+ for img_name in all_img_names:
+ base_name = img_name.replace('.png', '')
+ base_name = base_name.replace('.jpg', '')
+ base_name = base_name.replace('.jpeg', '')
+ mask_name_orig = base_name + '_mask.png'
+ if mask_name_orig in all_mask_names:
+ mask_name = base_name + '.png'
+ src_img_path = os.path.join(src_img_dir, img_name)
+ src_mask_path = os.path.join(src_mask_dir, mask_name_orig)
+ tgt_img_path = os.path.join(tgt_img_train_dir, img_name)
+ tgt_mask_path = os.path.join(tgt_mask_train_dir, mask_name)
+
+ img = Image.open(src_img_path).convert('RGB')
+ img.save(tgt_img_path)
+ mask = np.array(Image.open(src_mask_path))
+ mask = convert_label(mask, {0: 0, 255: 1})
+ mask = Image.fromarray(mask)
+ mask.save(tgt_mask_path)
+ else:
+ src_img_path = os.path.join(src_img_dir, img_name)
+ tgt_img_path = os.path.join(tgt_img_test_dir, img_name)
+ img = Image.open(src_img_path).convert('RGB')
+ img.save(tgt_img_path)
diff --git a/projects/medical/2d_image/x_ray/crass/README.md b/projects/medical/2d_image/x_ray/crass/README.md
new file mode 100644
index 00000000000..0621205be81
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/crass/README.md
@@ -0,0 +1,144 @@
+# Chest Radiograph Anatomical Structure Segmentation (CRASS)
+
+## Description
+
+This project supports **`Chest Radiograph Anatomical Structure Segmentation (CRASS)`**, which can be downloaded from [here](https://crass.grand-challenge.org/).
+
+### Dataset Overview
+
+A set of consecutively obtained posterior-anterior chest radiographs was selected from a database containing images acquired at two sites in sub-Saharan Africa with a high tuberculosis incidence. All subjects were 15 years or older. Images from digital chest radiography units (Delft Imaging Systems, The Netherlands) of varying resolutions were used, with a typical resolution of 1800-2000 pixels and an isotropic pixel size of 250 μm. Of the total set of images, 225 were considered normal by an expert radiologist, while 333 contained abnormalities. Of the abnormal images, 220 contained abnormalities in the upper area of the lung, where the clavicle is located. The data was divided into a training set of 299 images and a test set of 249 images.
+The dataset is currently incomplete; the remaining data is to be added later.
+
+### Information Statistics
+
+| Dataset Name | Anatomical Region | Task Type | Modality | Num. Classes | Train/Val/Test Images | Train/Val/Test Labeled | Release Date | License |
+| ------------------------------------------- | ----------------- | ------------ | -------- | ------------ | --------------------- | ---------------------- | ------------ | ------------------------------------------------------------- |
+| [crass](https://crass.grand-challenge.org/) | pulmonary | segmentation | x_ray | 2 | 299/-/234 | yes/-/no | 2021 | [CC0 1.0](https://creativecommons.org/publicdomain/zero/1.0/) |
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 299 | 98.38 | - | - | - | - |
+| clavicles | 299 | 1.62 | - | - | - | - |
+
+Note:
+
+- `Pct` means percentage of pixels in this category in all pixels.
+
+### Visualization
+
+![crass](https://raw.githubusercontent.com/uni-medical/medical-datasets-visualization/main/2d/semantic_seg/x_ray/crass/crass_dataset.png?raw=true)
+
+### Dataset Citation
+
+```bibtex
+@article{HOGEWEG20121490,
+ title={Clavicle segmentation in chest radiographs},
+ journal={Medical Image Analysis},
+ volume={16},
+ number={8},
+ pages={1490-1502},
+ year={2012}
+}
+```
+
+### Prerequisites
+
+- Python v3.8
+- PyTorch v1.10.0
+- pillow(PIL) v9.3.0
+- scikit-learn(sklearn) v1.2.0
+- [MIM](https://github.com/open-mmlab/mim) v0.3.4
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.2.0 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc5
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `crass/` root directory, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Dataset Preparing
+
+- download the dataset from [here](https://crass.grand-challenge.org/) and decompress it to the path `'data/'`.
+- run the script `python tools/prepare_dataset.py` to format the data and arrange the folder structure as below.
+- run the script `python ../../tools/split_seg_dataset.py` to split the dataset and generate `train.txt`, `val.txt` and `test.txt`. If the labels of the official validation and test sets cannot be obtained, `train.txt` and `val.txt` are generated randomly from the training set.
+
+```none
+  mmsegmentation
+  ├── mmseg
+  ├── projects
+  │   ├── medical
+  │   │   ├── 2d_image
+  │   │   │   ├── x_ray
+  │   │   │   │   ├── crass
+  │   │   │   │   │   ├── configs
+  │   │   │   │   │   ├── datasets
+  │   │   │   │   │   ├── tools
+  │   │   │   │   │   ├── data
+  │   │   │   │   │   │   ├── train.txt
+  │   │   │   │   │   │   ├── val.txt
+  │   │   │   │   │   │   ├── images
+  │   │   │   │   │   │   │   ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+  │   │   │   │   │   │   ├── masks
+  │   │   │   │   │   │   │   ├── train
+  │   │   │   │   │   │   │   │   ├── xxx.png
+  │   │   │   │   │   │   │   │   ├── ...
+  │   │   │   │   │   │   │   │   └── xxx.png
+```
+
+### Divided Dataset Information
+
+***Note: The table below describes the train/val split made by ourselves.***
+
+| Class Name | Num. Train | Pct. Train | Num. Val | Pct. Val | Num. Test | Pct. Test |
+| :--------: | :--------: | :--------: | :------: | :------: | :-------: | :-------: |
+| background | 227 | 98.38 | 57 | 98.39 | - | - |
+| clavicles | 227 | 1.62 | 57 | 1.61 | - | - |
+
+### Training commands
+
+To train models on a single server with one GPU (default):
+
+```shell
+mim train mmseg ./configs/${CONFIG_FILE}
+```
+
+### Testing commands
+
+To test models on a single server with one GPU (default):
+
+```shell
+mim test mmseg ./configs/${CONFIG_FILE} --checkpoint ${CHECKPOINT_PATH}
+```
+
+
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+ - [x] Basic docstrings & proper citation
+ - [ ] Test-time correctness
+ - [x] A full README
+
+- [x] Milestone 2: Indicates a successful model implementation.
+
+ - [x] Training-time correctness
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+ - [ ] Unit tests
+ - [ ] Code polishing
+ - [ ] Metafile.yml
+
+- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
+
+- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/medical/2d_image/x_ray/crass/configs/crass_512x512.py b/projects/medical/2d_image/x_ray/crass/configs/crass_512x512.py
new file mode 100644
index 00000000000..1425f50cc46
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/crass/configs/crass_512x512.py
@@ -0,0 +1,42 @@
+dataset_type = 'CRASSDataset'
+data_root = 'data/'
+img_scale = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=img_scale, keep_ratio=False),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+]
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='train.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+        ann_file='val.txt',
+ data_prefix=dict(img_path='images/', seg_map_path='masks/'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
+test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice'])
diff --git a/projects/medical/2d_image/x_ray/crass/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_crass-512x512.py b/projects/medical/2d_image/x_ray/crass/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_crass-512x512.py
new file mode 100644
index 00000000000..b52bc78f790
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/crass/configs/fcn-unet-s5-d16_unet_1xb16-0.0001-20k_crass-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './crass_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.crass_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.0001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/x_ray/crass/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_crass-512x512.py b/projects/medical/2d_image/x_ray/crass/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_crass-512x512.py
new file mode 100644
index 00000000000..45242c65b48
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/crass/configs/fcn-unet-s5-d16_unet_1xb16-0.001-20k_crass-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './crass_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.crass_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.001)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/x_ray/crass/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_crass-512x512.py b/projects/medical/2d_image/x_ray/crass/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_crass-512x512.py
new file mode 100644
index 00000000000..bcf9d0a5ca5
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/crass/configs/fcn-unet-s5-d16_unet_1xb16-0.01-20k_crass-512x512.py
@@ -0,0 +1,17 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './crass_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.crass_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(num_classes=2),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/x_ray/crass/configs/fcn-unet-s5-d16_unet_1xb16-lr0.01-sigmoid-20k_crass-512x512.py b/projects/medical/2d_image/x_ray/crass/configs/fcn-unet-s5-d16_unet_1xb16-lr0.01-sigmoid-20k_crass-512x512.py
new file mode 100644
index 00000000000..0dde736bf76
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/crass/configs/fcn-unet-s5-d16_unet_1xb16-lr0.01-sigmoid-20k_crass-512x512.py
@@ -0,0 +1,18 @@
+_base_ = [
+ 'mmseg::_base_/models/fcn_unet_s5-d16.py', './crass_512x512.py',
+ 'mmseg::_base_/default_runtime.py',
+ 'mmseg::_base_/schedules/schedule_20k.py'
+]
+custom_imports = dict(imports='datasets.crass_dataset')
+img_scale = (512, 512)
+data_preprocessor = dict(size=img_scale)
+optimizer = dict(lr=0.01)
+optim_wrapper = dict(optimizer=optimizer)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ decode_head=dict(
+ num_classes=2, loss_decode=dict(use_sigmoid=True), out_channels=1),
+ auxiliary_head=None,
+ test_cfg=dict(mode='whole', _delete_=True))
+vis_backends = None
+visualizer = dict(vis_backends=vis_backends)
diff --git a/projects/medical/2d_image/x_ray/crass/datasets/crass_dataset.py b/projects/medical/2d_image/x_ray/crass/datasets/crass_dataset.py
new file mode 100644
index 00000000000..f6b5c5228bc
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/crass/datasets/crass_dataset.py
@@ -0,0 +1,30 @@
+from mmseg.datasets import BaseSegDataset
+from mmseg.registry import DATASETS
+
+
+@DATASETS.register_module()
+class CRASSDataset(BaseSegDataset):
+ """CRASSDataset dataset.
+
+ In segmentation map annotation for CRASSDataset, 0 stands for background,
+ which is included in 2 categories. ``reduce_zero_label`` is fixed to
+ False. The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is
+    fixed to '.png'.
+
+    Args:
+        img_suffix (str): Suffix of images. Default: '.png'
+        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+        reduce_zero_label (bool): Whether to mark label zero as ignored.
+            Default to False.
+ """
+ METAINFO = dict(classes=('background', 'clavicles'))
+
+ def __init__(self,
+ img_suffix='.png',
+ seg_map_suffix='.png',
+ reduce_zero_label=False,
+ **kwargs) -> None:
+ super().__init__(
+ img_suffix=img_suffix,
+ seg_map_suffix=seg_map_suffix,
+ reduce_zero_label=reduce_zero_label,
+ **kwargs)
diff --git a/projects/medical/2d_image/x_ray/crass/tools/prepare_dataset.py b/projects/medical/2d_image/x_ray/crass/tools/prepare_dataset.py
new file mode 100644
index 00000000000..bbd5d8891d4
--- /dev/null
+++ b/projects/medical/2d_image/x_ray/crass/tools/prepare_dataset.py
@@ -0,0 +1,84 @@
+import glob
+import os
+
+import cv2
+import SimpleITK as sitk
+from PIL import Image
+
+root_path = 'data/'
+img_suffix = '.tif'
+seg_map_suffix = '.png'
+save_img_suffix = '.png'
+save_seg_map_suffix = '.png'
+
+src_img_train_dir = os.path.join(root_path, 'CRASS/data_train')
+src_mask_train_dir = os.path.join(root_path, 'CRASS/mask_mhd')
+src_img_test_dir = os.path.join(root_path, 'CRASS/data_test')
+
+tgt_img_train_dir = os.path.join(root_path, 'images/train/')
+tgt_mask_train_dir = os.path.join(root_path, 'masks/train/')
+tgt_img_test_dir = os.path.join(root_path, 'images/test/')
+os.makedirs(tgt_img_train_dir, exist_ok=True)
+os.makedirs(tgt_mask_train_dir, exist_ok=True)
+os.makedirs(tgt_img_test_dir, exist_ok=True)
+
+
+def filter_suffix_recursive(src_dir, suffix):
+    """Recursively collect paths and names of files ending with ``suffix``."""
+    suffix = '.' + suffix if '.' not in suffix else suffix
+    file_paths = glob.glob(
+        os.path.join(src_dir, '**', '*' + suffix), recursive=True)
+    file_names = [_.split('/')[-1] for _ in file_paths]
+    return sorted(file_paths), sorted(file_names)
+
+
+def read_single_array_from_med(path):
+ return sitk.GetArrayFromImage(sitk.ReadImage(path)).squeeze()
+
+
+def convert_meds_into_pngs(src_dir,
+ tgt_dir,
+ suffix='.dcm',
+ norm_min=0,
+ norm_max=255,
+ convert='RGB'):
+ if not os.path.exists(tgt_dir):
+ os.makedirs(tgt_dir)
+
+ src_paths, src_names = filter_suffix_recursive(src_dir, suffix=suffix)
+ num = len(src_paths)
+ for i, (src_name, src_path) in enumerate(zip(src_names, src_paths)):
+ tgt_name = src_name.replace(suffix, '.png')
+ tgt_path = os.path.join(tgt_dir, tgt_name)
+
+ img = read_single_array_from_med(src_path)
+ if norm_min is not None and norm_max is not None:
+ img = cv2.normalize(img, None, norm_min, norm_max, cv2.NORM_MINMAX,
+ cv2.CV_8U)
+ pil = Image.fromarray(img).convert(convert)
+ pil.save(tgt_path)
+ print(f'processed {i+1}/{num}.')
+
+
+convert_meds_into_pngs(
+ src_img_train_dir,
+ tgt_img_train_dir,
+ suffix='.mhd',
+ norm_min=0,
+ norm_max=255,
+ convert='RGB')
+
+convert_meds_into_pngs(
+ src_img_test_dir,
+ tgt_img_test_dir,
+ suffix='.mhd',
+ norm_min=0,
+ norm_max=255,
+ convert='RGB')
+
+convert_meds_into_pngs(
+ src_mask_train_dir,
+ tgt_mask_train_dir,
+ suffix='.mhd',
+ norm_min=0,
+ norm_max=1,
+ convert='L')
diff --git a/projects/pp_mobileseg/README.md b/projects/pp_mobileseg/README.md
new file mode 100644
index 00000000000..c9f9c128e74
--- /dev/null
+++ b/projects/pp_mobileseg/README.md
@@ -0,0 +1,123 @@
+# PP-MobileSeg: Exploring Transformer Blocks for Efficient Mobile Segmentation
+
+## Reference
+
+> [PP-MobileSeg: Explore the Fast and Accurate Semantic Segmentation Model on Mobile Devices](https://arxiv.org/abs/2304.05152)
+
+## Introduction
+
+[Official Repo](https://github.com/PaddlePaddle/PaddleSeg)
+
+## Abstract
+
+With the success of transformers in computer vision, several attempts have been made to adapt transformers to mobile devices. However, their performance is not satisfactory for some real-world applications. Therefore, we propose PP-MobileSeg, a SOTA semantic segmentation model for mobile devices.
+
+It is composed of three newly proposed parts: the StrideFormer backbone, the Aggregated Attention Module (AAM), and the Valid Interpolate Module (VIM):
+
+- With the four-stage MobileNetV3 block as the feature extractor, we extract rich local features with different receptive fields at little parameter overhead. We further empower the features from the last two stages with a global view using strided SEA attention.
+- To effectively fuse the features, we use AAM to filter the detail features with ensemble voting and add the semantic feature to them to enhance the semantic information to the greatest extent.
+- Finally, we use VIM to upsample the downsampled feature map to the original resolution, interpolating only the classes present in the final prediction, which typically amount to around 10% of all classes on the ADE20K dataset. This is a common scenario for datasets with a large number of classes, so VIM significantly reduces the latency of the final upsampling step, which accounts for the largest share of the model's overall inference latency.
+
+Extensive experiments show that PP-MobileSeg achieves a superior params-accuracy-latency tradeoff compared to other SOTA methods.
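
The idea behind VIM can be illustrated with a small NumPy sketch: only the channels of classes that actually appear in the low-resolution prediction are upsampled, and the argmax over that reduced set is mapped back to the original class ids. This is a conceptual sketch only (nearest-neighbor upsampling, hypothetical `valid_interpolate` helper); the real module operates on tensors and differs in details:

```python
import numpy as np


def valid_interpolate(logits, scale):
    """Upsample only the channels of classes present in the low-res prediction.

    logits: (C, h, w) class-score map; returns an (h*scale, w*scale) label map.
    Conceptual sketch of the VIM idea, not the actual implementation.
    """
    present = np.unique(logits.argmax(axis=0))  # class ids in the prediction
    small = logits[present]                     # keep only those channels
    up = small.repeat(scale, axis=1).repeat(scale, axis=2)  # nearest upsample
    return present[up.argmax(axis=0)]           # map back to original ids
```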
+
+
+
+
+
+## Performance
+
+### ADE20K
+
+| Model | Backbone | Training Iters | Batchsize | Train Resolution | mIoU(%) | latency(ms)\* | params(M) | config | Links |
+| ----------------- | ----------------- | -------------- | --------- | ---------------- | ------- | ------------- | --------- | ------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| PP-MobileSeg-Base | StrideFormer-Base | 80000 | 32 | 512x512 | 41.57% | 265.5 | 5.62 | [config](https://github.com/Yang-Changhui/mmsegmentation/tree/add_ppmobileseg/projects/pp_mobileseg/configs/pp_mobileseg) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pp_mobileseg/pp_mobileseg_mobilenetv3_2xb16_3rdparty-base_512x512-ade20k-f12b44f3.pth)\|[log](https://bj.bcebos.com/paddleseg/dygraph/ade20k/pp_mobileseg_base/train.log) |
+| PP-MobileSeg-Tiny | StrideFormer-Tiny | 80000 | 32 | 512x512 | 36.39% | 215.3 | 1.61 | [config](https://github.com/Yang-Changhui/mmsegmentation/tree/add_ppmobileseg/projects/pp_mobileseg/configs/pp_mobileseg) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pp_mobileseg/pp_mobileseg_mobilenetv3_2xb16_3rdparty-tiny_512x512-ade20k-a351ebf5.pth)\|[log](https://bj.bcebos.com/paddleseg/dygraph/ade20k/pp_mobileseg_tiny/train.log) |
+
+## Usage
+
+As with other models in MMSegmentation, you can run the following command from `${MMSEG_ROOT}` to test the model:
+
+```shell
+./tools/dist_test.sh projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_base.py checkpoints/pp_mobileseg_mobilenetv3_2xb16_3rdparty-base_512x512-ade20k-f12b44f3.pth 8
+```
+
+## Inference with ONNXRuntime
+
+### Prerequisites
+
+**1. Install onnxruntime inference engine.**
+
+Choose one of the following ways to install onnxruntime.
+
+- CPU version
+
+```shell
+pip install onnxruntime==1.15.1
+wget https://github.com/microsoft/onnxruntime/releases/download/v1.15.1/onnxruntime-linux-x64-1.15.1.tgz
+tar -zxvf onnxruntime-linux-x64-1.15.1.tgz
+export ONNXRUNTIME_DIR=$(pwd)/onnxruntime-linux-x64-1.15.1
+export LD_LIBRARY_PATH=$ONNXRUNTIME_DIR/lib:$LD_LIBRARY_PATH
+```
+
+**2. Convert model to onnx file**
+
+- Install `mim` and `mmdeploy`.
+
+```shell
+pip install openmim
+mim install mmdeploy
+git clone https://github.com/open-mmlab/mmdeploy.git
+```
+
+- Download pp_mobileseg model.
+
+```shell
+wget https://download.openmmlab.com/mmsegmentation/v0.5/pp_mobileseg/pp_mobileseg_mobilenetv3_2xb16_3rdparty-tiny_512x512-ade20k-a351ebf5.pth
+```
+
+- Convert model to onnx files.
+
+```shell
+python mmdeploy/tools/deploy.py mmdeploy/configs/mmseg/segmentation_onnxruntime_dynamic.py \
+ configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny.py \
+ pp_mobileseg_mobilenetv3_2xb16_3rdparty-tiny_512x512-ade20k-a351ebf5.pth \
+ ../../demo/demo.png \
+ --work-dir mmdeploy_model/mmseg/ort \
+ --show
+```
+
+**3. Run demo**
+
+```shell
+python inference_onnx.py ${ONNX_FILE_PATH} ${IMAGE_PATH} [${MODEL_INPUT_SIZE} ${DEVICE} ${OUTPUT_IMAGE_PATH}]
+```
+
+Example:
+
+```shell
+python inference_onnx.py mmdeploy_model/mmseg/ort/end2end.onnx ../../demo/demo.png
+```
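The demo script essentially wraps an `onnxruntime` session. The sketch below illustrates the core steps (an illustration only, assuming the `end2end.onnx` exported above takes a normalized 1x3x512x512 float32 input and directly outputs an argmax label map; `preprocess` and `segment` are hypothetical helpers, and the mean/std values come from the config's `data_preprocessor` — the real `inference_onnx.py` handles resizing and color conversion more carefully):

```python
import numpy as np

# ImageNet mean/std from the config's data_preprocessor (RGB order)
MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)

def preprocess(img, size=(512, 512)):
    """Nearest-neighbor resize + normalize: HWC uint8 -> 1x3xHxW float32."""
    h, w = img.shape[:2]
    ys = (np.arange(size[0]) * h // size[0]).clip(0, h - 1)
    xs = (np.arange(size[1]) * w // size[1]).clip(0, w - 1)
    resized = img[ys][:, xs].astype(np.float32)
    return ((resized - MEAN) / STD).transpose(2, 0, 1)[None]  # NCHW

def segment(onnx_path, img):
    """Run the exported model on an RGB image array; returns the label map."""
    import onnxruntime as ort  # deferred so preprocess() works without it
    sess = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    inp = preprocess(img)
    # output name and shape depend on the deploy config; the dynamic
    # onnxruntime config above exports a model that emits the label map
    return sess.run(None, {sess.get_inputs()[0].name: inp})[0]
```

`segment('mmdeploy_model/mmseg/ort/end2end.onnx', img)` would then return a per-pixel label map to colorize or save.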
+
+## Citation
+
+If you find our project useful in your research, please consider citing:
+
+```bibtex
+@misc{liu2021paddleseg,
+ title={PaddleSeg: A High-Efficient Development Toolkit for Image Segmentation},
+ author={Yi Liu and Lutao Chu and Guowei Chen and Zewu Wu and Zeyu Chen and Baohua Lai and Yuying Hao},
+ year={2021},
+ eprint={2101.06175},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+
+@misc{paddleseg2019,
+ title={PaddleSeg, End-to-end image segmentation kit based on PaddlePaddle},
+ author={PaddlePaddle Contributors},
+ howpublished = {\url{https://github.com/PaddlePaddle/PaddleSeg}},
+ year={2019}
+}
+```
diff --git a/projects/pp_mobileseg/backbones/__init__.py b/projects/pp_mobileseg/backbones/__init__.py
new file mode 100644
index 00000000000..244b33d37a8
--- /dev/null
+++ b/projects/pp_mobileseg/backbones/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .strideformer import StrideFormer
+
+__all__ = ['StrideFormer']
diff --git a/projects/pp_mobileseg/backbones/strideformer.py b/projects/pp_mobileseg/backbones/strideformer.py
new file mode 100644
index 00000000000..3f09be5225d
--- /dev/null
+++ b/projects/pp_mobileseg/backbones/strideformer.py
@@ -0,0 +1,958 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, build_activation_layer
+from mmcv.cnn.bricks.transformer import build_dropout
+from mmengine.logging import print_log
+from mmengine.model import BaseModule
+from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
+
+from mmseg.registry import MODELS
+
+
+@MODELS.register_module()
+class StrideFormer(BaseModule):
+ """The StrideFormer implementation based on torch.
+
+    The original article refers to: https://arxiv.org/abs/2304.05152
+
+    Args:
+        mobileV3_cfg(list): Each sublist describes the config for a
+            MobileNetV3 block.
+        channels(list): The input channels for each MobileNetV3 block.
+        embed_dims(list): The channels of the features input to the sea
+            attention block.
+        key_dims(list, optional): The embedding dims for each head in
+            attention.
+        depths(list, optional): The depth of each attention block,
+            i.e. [M, N].
+        num_heads(int, optional): The number of heads of the attention
+            blocks.
+        attn_ratios(int, optional): The expand ratio of V.
+        mlp_ratios(list, optional): The ratio of mlp blocks.
+        drop_path_rate(float, optional): The drop path rate in the
+            attention block.
+        act_cfg(dict, optional): The activation layer of the AAM
+            (Aggregate Attention Module).
+        inj_type(str, optional): The type of injection/AAM.
+        out_channels(int, optional): The output channels of the AAM.
+        dims(list, optional): The dimension of the fusion block.
+        out_feat_chs(list, optional): The input channels of the AAM.
+        stride_attention(bool, optional): Whether to use strided
+            attention in each attention layer.
+        pretrained(str, optional): The path of the pretrained model.
+ """
+
+ def __init__(
+ self,
+ mobileV3_cfg,
+ channels,
+ embed_dims,
+ key_dims=[16, 24],
+ depths=[2, 2],
+ num_heads=8,
+ attn_ratios=2,
+ mlp_ratios=[2, 4],
+ drop_path_rate=0.1,
+ act_cfg=dict(type='ReLU'),
+ inj_type='AAM',
+ out_channels=256,
+ dims=(128, 160),
+ out_feat_chs=None,
+ stride_attention=True,
+ pretrained=None,
+ init_cfg=None,
+ ):
+ super().__init__(init_cfg=init_cfg)
+ assert not (init_cfg and pretrained
+ ), 'init_cfg and pretrained cannot be set at the same time'
+ if isinstance(pretrained, str):
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ elif pretrained is not None:
+ raise TypeError('pretrained must be a str or None')
+
+ self.depths = depths
+ self.cfgs = mobileV3_cfg
+ self.dims = dims
+ for i in range(len(self.cfgs)):
+ smb = StackedMV3Block(
+ cfgs=self.cfgs[i],
+ stem=True if i == 0 else False,
+ in_channels=channels[i],
+ )
+ setattr(self, f'smb{i + 1}', smb)
+ for i in range(len(depths)):
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, depths[i])
+ ]
+ trans = BasicLayer(
+ block_num=depths[i],
+ embedding_dim=embed_dims[i],
+ key_dim=key_dims[i],
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratios[i],
+ attn_ratio=attn_ratios,
+ drop=0,
+ attn_drop=0.0,
+ drop_path=dpr,
+ act_cfg=act_cfg,
+ stride_attention=stride_attention,
+ )
+ setattr(self, f'trans{i + 1}', trans)
+
+ self.inj_type = inj_type
+ if self.inj_type == 'AAM':
+ self.inj_module = InjectionMultiSumallmultiallsum(
+ in_channels=out_feat_chs, out_channels=out_channels)
+ self.feat_channels = [
+ out_channels,
+ ]
+ elif self.inj_type == 'AAMSx8':
+ self.inj_module = InjectionMultiSumallmultiallsumSimpx8(
+ in_channels=out_feat_chs, out_channels=out_channels)
+ self.feat_channels = [
+ out_channels,
+ ]
+ elif self.inj_type == 'origin':
+ for i in range(len(dims)):
+ fuse = FusionBlock(
+ out_feat_chs[0] if i == 0 else dims[i - 1],
+ out_feat_chs[i + 1],
+ embed_dim=dims[i],
+ act_cfg=None,
+ )
+ setattr(self, f'fuse{i + 1}', fuse)
+ self.feat_channels = [
+ dims[i],
+ ]
+ else:
+            raise NotImplementedError(self.inj_type + ' is not implemented')
+
+ self.pretrained = pretrained
+
+ def init_weights(self):
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg.get('type') == 'Pretrained'):
+ checkpoint = CheckpointLoader.load_checkpoint(
+ self.init_cfg['checkpoint'], logger=None, map_location='cpu')
+
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ else:
+ state_dict = checkpoint
+
+ if 'pos_embed' in state_dict.keys():
+ if self.pos_embed.shape != state_dict['pos_embed'].shape:
+ print_log(msg=f'Resize the pos_embed shape from '
+ f'{state_dict["pos_embed"].shape} to '
+ f'{self.pos_embed.shape}')
+ h, w = self.img_size
+ pos_size = int(
+ math.sqrt(state_dict['pos_embed'].shape[1] - 1))
+ state_dict['pos_embed'] = self.resize_pos_embed(
+ state_dict['pos_embed'],
+ (h // self.patch_size, w // self.patch_size),
+ (pos_size, pos_size),
+ self.interpolate_mode,
+ )
+
+ load_state_dict(self, state_dict, strict=False, logger=None)
+
+ def forward(self, x):
+ x_hw = x.shape[2:]
+ outputs = []
+ num_smb_stage = len(self.cfgs)
+ num_trans_stage = len(self.depths)
+
+ for i in range(num_smb_stage):
+ smb = getattr(self, f'smb{i + 1}')
+ x = smb(x)
+
+ # 1/8 shared feat
+ if i == 1:
+ outputs.append(x)
+ if num_trans_stage + i >= num_smb_stage:
+ trans = getattr(
+ self, f'trans{i + num_trans_stage - num_smb_stage + 1}')
+ x = trans(x)
+ outputs.append(x)
+ if self.inj_type == 'origin':
+ x_detail = outputs[0]
+ for i in range(len(self.dims)):
+ fuse = getattr(self, f'fuse{i + 1}')
+
+ x_detail = fuse(x_detail, outputs[i + 1])
+ output = x_detail
+ else:
+ output = self.inj_module(outputs)
+
+ return [output, x_hw]
+
+
+class StackedMV3Block(nn.Module):
+ """The MobileNetV3 block.
+
+ Args:
+        cfgs (list): The MobileNetV3 config list of a stage.
+        stem (bool): Whether this is the first stage.
+        in_channels (int, optional): The channels of the input image.
+            Default: 3.
+        scale (float, optional): The coefficient that controls the size
+            of network parameters. Default: 1.0.
+
+    Returns:
+        nn.Module: A stage of a specific MobileNetV3 model depending
+            on args.
+ """
+
+ def __init__(self,
+ cfgs,
+ stem,
+ in_channels,
+ scale=1.0,
+ norm_cfg=dict(type='BN')):
+ super().__init__()
+
+ self.scale = scale
+ self.stem = stem
+
+ if self.stem:
+ self.conv = ConvModule(
+ in_channels=3,
+ out_channels=_make_divisible(in_channels * self.scale),
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ groups=1,
+ bias=False,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='HSwish'),
+ )
+
+ self.blocks = nn.ModuleList()
+ for i, (k, exp, c, se, act, s) in enumerate(cfgs):
+ self.blocks.append(
+ ResidualUnit(
+ in_channel=_make_divisible(in_channels * self.scale),
+ mid_channel=_make_divisible(self.scale * exp),
+ out_channel=_make_divisible(self.scale * c),
+ kernel_size=k,
+ stride=s,
+ use_se=se,
+ act=act,
+ dilation=1,
+ ))
+ in_channels = _make_divisible(self.scale * c)
+
+ def forward(self, x):
+ if self.stem:
+ x = self.conv(x)
+ for i, block in enumerate(self.blocks):
+ x = block(x)
+
+ return x
+
+
+class ResidualUnit(nn.Module):
+ """The Residual module.
+
+ Args:
+ in_channel (int, optional): The channels of input feature.
+ mid_channel (int, optional): The channels of middle process.
+ out_channel (int, optional): The channels of output feature.
+ kernel_size (int, optional): The size of the convolving kernel.
+ stride (int, optional): The stride size.
+        use_se (bool, optional): Whether to use the SEModule.
+        act (str, optional): The type of the activation layer.
+ dilation (int, optional): The dilation size.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN', requires_grad=True).
+ """
+
+ def __init__(
+ self,
+ in_channel,
+ mid_channel,
+ out_channel,
+ kernel_size,
+ stride,
+ use_se,
+ act=None,
+ dilation=1,
+ norm_cfg=dict(type='BN'),
+ ):
+ super().__init__()
+ self.if_shortcut = stride == 1 and in_channel == out_channel
+ self.if_se = use_se
+ self.expand_conv = ConvModule(
+ in_channels=in_channel,
+ out_channels=mid_channel,
+ kernel_size=1,
+ bias=False,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type=act) if act is not None else None,
+ )
+ self.bottleneck_conv = ConvModule(
+ in_channels=mid_channel,
+ out_channels=mid_channel,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=int((kernel_size - 1) // 2) * dilation,
+ bias=False,
+ groups=mid_channel,
+ dilation=dilation,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type=act) if act is not None else None,
+ )
+ if self.if_se:
+ self.mid_se = SEModule(mid_channel)
+ self.linear_conv = ConvModule(
+ in_channels=mid_channel,
+ out_channels=out_channel,
+ kernel_size=1,
+ bias=False,
+ norm_cfg=norm_cfg,
+ act_cfg=None,
+ )
+
+ def forward(self, x):
+ identity = x
+ x = self.expand_conv(x)
+ x = self.bottleneck_conv(x)
+ if self.if_se:
+ x = self.mid_se(x)
+ x = self.linear_conv(x)
+ if self.if_shortcut:
+ x = torch.add(identity, x)
+ return x
+
+
+class SEModule(nn.Module):
+ """SE Module.
+
+ Args:
+ channel (int, optional): The channels of input feature.
+ reduction (int, optional): The channel reduction rate.
+ act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU').
+ """
+
+ def __init__(self, channel, reduction=4, act_cfg=dict(type='ReLU')):
+ super().__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.conv_act1 = ConvModule(
+ in_channels=channel,
+ out_channels=channel // reduction,
+ kernel_size=1,
+ norm_cfg=None,
+ act_cfg=act_cfg,
+ )
+
+ self.conv_act2 = ConvModule(
+ in_channels=channel // reduction,
+ out_channels=channel,
+ kernel_size=1,
+ norm_cfg=None,
+ act_cfg=dict(type='Hardsigmoid', slope=0.2, offset=0.5),
+ )
+
+ def forward(self, x):
+ identity = x
+ x = self.avg_pool(x)
+ x = self.conv_act1(x)
+ x = self.conv_act2(x)
+ return torch.mul(identity, x)
+
+
+class BasicLayer(nn.Module):
+ """The transformer basic layer.
+
+ Args:
+ block_num (int): the block nums of the transformer basic layer.
+ embedding_dim (int): The feature dimension.
+ key_dim (int): the key dim.
+ num_heads (int): Parallel attention heads.
+ mlp_ratio (float): the mlp ratio.
+ attn_ratio (float): the attention ratio.
+        drop (float): Probability of an element to be zeroed
+            after the feed forward layer. Default: 0.0.
+        attn_drop (float): The dropout rate of the attention layer.
+            Default: 0.0.
+        drop_path (float): Stochastic depth rate. Default: 0.0.
+        act_cfg (dict): Config dict for activation layer.
+            Default: None.
+        stride_attention (bool, optional): Whether to use strided
+            attention in each attention layer.
+ """
+
+ def __init__(
+ self,
+ block_num,
+ embedding_dim,
+ key_dim,
+ num_heads,
+ mlp_ratio=4.0,
+ attn_ratio=2.0,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=None,
+ act_cfg=None,
+ stride_attention=None,
+ ):
+ super().__init__()
+ self.block_num = block_num
+
+ self.transformer_blocks = nn.ModuleList()
+ for i in range(self.block_num):
+ self.transformer_blocks.append(
+ Block(
+ embedding_dim,
+ key_dim=key_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_ratio=attn_ratio,
+ drop=drop,
+ drop_path=drop_path[i]
+ if isinstance(drop_path, list) else drop_path,
+ act_cfg=act_cfg,
+ stride_attention=stride_attention,
+ ))
+
+ def forward(self, x):
+ for i in range(self.block_num):
+ x = self.transformer_blocks[i](x)
+ return x
+
+
+class Block(nn.Module):
+    """The block of the transformer basic layer.
+
+ Args:
+ dim (int): The feature dimension.
+ key_dim (int): The key dimension.
+ num_heads (int): Parallel attention heads.
+ mlp_ratio (float): the mlp ratio.
+ attn_ratio (float): the attention ratio.
+        drop (float): Probability of an element to be zeroed
+            after the feed forward layer. Default: 0.0.
+        drop_path (float): Stochastic depth rate. Default: 0.0.
+        act_cfg (dict): Config dict for activation layer.
+            Default: None.
+        stride_attention (bool, optional): Whether to use strided
+            attention in each attention layer.
+ """
+
+ def __init__(
+ self,
+ dim,
+ key_dim,
+ num_heads,
+ mlp_ratio=4.0,
+ attn_ratio=2.0,
+ drop=0.0,
+ drop_path=0.0,
+ act_cfg=None,
+ stride_attention=None,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.mlp_ratio = mlp_ratio
+ self.attn = SeaAttention(
+ dim,
+ key_dim=key_dim,
+ num_heads=num_heads,
+ attn_ratio=attn_ratio,
+ act_cfg=act_cfg,
+ stride_attention=stride_attention,
+ )
+ self.drop_path = (
+ build_dropout(dict(type='DropPath', drop_prob=drop_path))
+ if drop_path > 0.0 else nn.Identity())
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = MLP(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_cfg=act_cfg,
+ drop=drop,
+ )
+
+ def forward(self, x1):
+ x1 = x1 + self.drop_path(self.attn(x1))
+ x1 = x1 + self.drop_path(self.mlp(x1))
+
+ return x1
+
+
+class SqueezeAxialPositionalEmbedding(nn.Module):
+    """The squeeze axial positional embedding.
+
+    Args:
+        dim (int): The feature dimension.
+        shape (int): The length of the axial positional embedding.
+ """
+
+ def __init__(self, dim, shape):
+ super().__init__()
+ self.pos_embed = nn.init.normal_(
+ nn.Parameter(torch.zeros(1, dim, shape)))
+
+ def forward(self, x):
+ B, C, N = x.shape
+ x = x + F.interpolate(
+ self.pos_embed, size=(N, ), mode='linear', align_corners=False)
+ return x
+
+
+class SeaAttention(nn.Module):
+ """The sea attention.
+
+ Args:
+ dim (int): The feature dimension.
+ key_dim (int): The key dimension.
+ num_heads (int): number of attention heads.
+ attn_ratio (float): the attention ratio.
+        act_cfg (dict): Config dict for activation layer.
+            Default: None.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        stride_attention (bool, optional): Whether to use strided
+            attention in each attention layer.
+ """
+
+ def __init__(
+ self,
+ dim,
+ key_dim,
+ num_heads,
+ attn_ratio=4.0,
+ act_cfg=None,
+ norm_cfg=dict(type='BN'),
+ stride_attention=False,
+ ):
+
+ super().__init__()
+ self.num_heads = num_heads
+ self.scale = key_dim**-0.5
+ self.nh_kd = nh_kd = key_dim * num_heads
+ self.d = int(attn_ratio * key_dim)
+ self.dh = int(attn_ratio * key_dim) * num_heads
+ self.attn_ratio = attn_ratio
+
+ self.to_q = ConvModule(
+ dim, nh_kd, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None)
+ self.to_k = ConvModule(
+ dim, nh_kd, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None)
+
+ self.to_v = ConvModule(
+ dim, self.dh, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None)
+ self.stride_attention = stride_attention
+ if self.stride_attention:
+ self.stride_conv = ConvModule(
+ dim,
+ dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=True,
+ groups=dim,
+ norm_cfg=norm_cfg,
+ act_cfg=None,
+ )
+
+ self.proj = ConvModule(
+ self.dh,
+ dim,
+ 1,
+ bias=False,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ order=('act', 'conv', 'norm'),
+ )
+ self.proj_encode_row = ConvModule(
+ self.dh,
+ self.dh,
+ 1,
+ bias=False,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ order=('act', 'conv', 'norm'),
+ )
+ self.pos_emb_rowq = SqueezeAxialPositionalEmbedding(nh_kd, 16)
+ self.pos_emb_rowk = SqueezeAxialPositionalEmbedding(nh_kd, 16)
+ self.proj_encode_column = ConvModule(
+ self.dh,
+ self.dh,
+ 1,
+ bias=False,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ order=('act', 'conv', 'norm'),
+ )
+ self.pos_emb_columnq = SqueezeAxialPositionalEmbedding(nh_kd, 16)
+ self.pos_emb_columnk = SqueezeAxialPositionalEmbedding(nh_kd, 16)
+ self.dwconv = ConvModule(
+ 2 * self.dh,
+ 2 * self.dh,
+ 3,
+ padding=1,
+ groups=2 * self.dh,
+ bias=False,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ )
+ self.pwconv = ConvModule(
+ 2 * self.dh, dim, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None)
+ self.sigmoid = build_activation_layer(dict(type='HSigmoid'))
+
+ def forward(self, x):
+ B, C, H_ori, W_ori = x.shape
+ if self.stride_attention:
+ x = self.stride_conv(x)
+ B, C, H, W = x.shape
+
+ q = self.to_q(x) # [B, nhead*dim, H, W]
+ k = self.to_k(x)
+ v = self.to_v(x)
+
+ qkv = torch.cat([q, k, v], dim=1)
+ qkv = self.dwconv(qkv)
+ qkv = self.pwconv(qkv)
+
+ qrow = (self.pos_emb_rowq(q.mean(-1)).reshape(
+ [B, self.num_heads, -1, H]).permute(
+ (0, 1, 3, 2))) # [B, nhead, H, dim]
+ krow = self.pos_emb_rowk(k.mean(-1)).reshape(
+ [B, self.num_heads, -1, H]) # [B, nhead, dim, H]
+ vrow = (v.mean(-1).reshape([B, self.num_heads, -1,
+ H]).permute([0, 1, 3, 2])
+ ) # [B, nhead, H, dim*attn_ratio]
+
+ attn_row = torch.matmul(qrow, krow) * self.scale # [B, nhead, H, H]
+ attn_row = nn.functional.softmax(attn_row, dim=-1)
+
+ xx_row = torch.matmul(attn_row, vrow) # [B, nhead, H, dim*attn_ratio]
+ xx_row = self.proj_encode_row(
+ xx_row.permute([0, 1, 3, 2]).reshape([B, self.dh, H, 1]))
+
+ # squeeze column
+ qcolumn = (
+ self.pos_emb_columnq(q.mean(-2)).reshape(
+ [B, self.num_heads, -1, W]).permute([0, 1, 3, 2]))
+ kcolumn = self.pos_emb_columnk(k.mean(-2)).reshape(
+ [B, self.num_heads, -1, W])
+ vcolumn = (
+ torch.mean(v, -2).reshape([B, self.num_heads, -1,
+ W]).permute([0, 1, 3, 2]))
+
+ attn_column = torch.matmul(qcolumn, kcolumn) * self.scale
+ attn_column = nn.functional.softmax(attn_column, dim=-1)
+
+ xx_column = torch.matmul(attn_column, vcolumn) # B nH W C
+ xx_column = self.proj_encode_column(
+ xx_column.permute([0, 1, 3, 2]).reshape([B, self.dh, 1, W]))
+
+ xx = torch.add(xx_row, xx_column) # [B, self.dh, H, W]
+ xx = torch.add(v, xx)
+
+ xx = self.proj(xx)
+ xx = self.sigmoid(xx) * qkv
+ if self.stride_attention:
+ xx = F.interpolate(xx, size=(H_ori, W_ori), mode='bilinear')
+
+ return xx
+
+
+class MLP(nn.Module):
+    """The multilayer perceptron.
+
+ Args:
+ in_features (int): the input feature.
+ hidden_features (int): the hidden feature.
+ out_features (int): the output feature.
+ act_cfg (dict): Config dict for activation layer.
+            Default: None.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN')
+ drop (float): Probability of an element to be zeroed.
+ Default 0.0
+ """
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_cfg=None,
+ norm_cfg=dict(type='BN'),
+ drop=0.0,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = ConvModule(
+ in_features,
+ hidden_features,
+ kernel_size=1,
+ bias=False,
+ norm_cfg=norm_cfg,
+ act_cfg=None,
+ )
+ self.dwconv = ConvModule(
+ hidden_features,
+ hidden_features,
+ kernel_size=3,
+ padding=1,
+ groups=hidden_features,
+ norm_cfg=None,
+ act_cfg=act_cfg,
+ )
+
+ self.fc2 = ConvModule(
+ hidden_features,
+ out_features,
+ 1,
+ bias=False,
+ norm_cfg=norm_cfg,
+ act_cfg=None,
+ )
+ self.drop = build_dropout(dict(type='Dropout', drop_prob=drop))
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.dwconv(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class FusionBlock(nn.Module):
+ """The feature fusion block.
+
+ Args:
+ in_channel (int): the input channel.
+ out_channel (int): the output channel.
+ embed_dim (int): embedding dimension.
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN')
+ """
+
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ embed_dim,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ ) -> None:
+ super().__init__()
+ self.local_embedding = ConvModule(
+ in_channels=in_channel,
+ out_channels=embed_dim,
+ kernel_size=1,
+ bias=False,
+ norm_cfg=norm_cfg,
+ act_cfg=None,
+ )
+
+ self.global_act = ConvModule(
+ in_channels=out_channel,
+ out_channels=embed_dim,
+ kernel_size=1,
+ bias=False,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg if act_cfg is not None else None,
+ )
+
+ def forward(self, x_l, x_g):
+        """Fuse local and global features.
+
+        Args:
+            x_l: local features.
+            x_g: global features.
+        """
+ B, C, H, W = x_l.shape
+
+ local_feat = self.local_embedding(x_l)
+ global_act = self.global_act(x_g)
+ sig_act = F.interpolate(
+ global_act, size=(H, W), mode='bilinear', align_corners=False)
+
+ out = local_feat * sig_act
+
+ return out
+
+
+class InjectionMultiSumallmultiallsum(nn.Module):
+    """The Aggregate Attention Module.
+
+ Args:
+ in_channels (tuple): the input channel.
+ out_channels (int): the output channel.
+ act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='Sigmoid').
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN')
+ """
+
+ def __init__(
+ self,
+ in_channels=(64, 128, 256, 384),
+ out_channels=256,
+ act_cfg=dict(type='Sigmoid'),
+ norm_cfg=dict(type='BN'),
+ ):
+ super().__init__()
+ self.embedding_list = nn.ModuleList()
+ self.act_embedding_list = nn.ModuleList()
+ self.act_list = nn.ModuleList()
+ for i in range(len(in_channels)):
+ self.embedding_list.append(
+ ConvModule(
+ in_channels=in_channels[i],
+ out_channels=out_channels,
+ kernel_size=1,
+ bias=False,
+ norm_cfg=norm_cfg,
+ act_cfg=None,
+ ))
+ self.act_embedding_list.append(
+ ConvModule(
+ in_channels=in_channels[i],
+ out_channels=out_channels,
+ kernel_size=1,
+ bias=False,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ ))
+
+ def forward(self, inputs): # x_x8, x_x16, x_x32, x_x64
+ low_feat1 = F.interpolate(inputs[0], scale_factor=0.5, mode='bilinear')
+ low_feat1_act = self.act_embedding_list[0](low_feat1)
+ low_feat1 = self.embedding_list[0](low_feat1)
+
+ low_feat2 = F.interpolate(
+ inputs[1], size=low_feat1.shape[-2:], mode='bilinear')
+ low_feat2_act = self.act_embedding_list[1](low_feat2) # x16
+ low_feat2 = self.embedding_list[1](low_feat2)
+
+ high_feat_act = F.interpolate(
+ self.act_embedding_list[2](inputs[2]),
+ size=low_feat2.shape[2:],
+ mode='bilinear',
+ )
+ high_feat = F.interpolate(
+ self.embedding_list[2](inputs[2]),
+ size=low_feat2.shape[2:],
+ mode='bilinear')
+
+ res = (
+ low_feat1_act * low_feat2_act * high_feat_act *
+ (low_feat1 + low_feat2) + high_feat)
+
+ return res
+
+
+class InjectionMultiSumallmultiallsumSimpx8(nn.Module):
+    """The Aggregate Attention Module (simplified, x8 output stride).
+
+ Args:
+ in_channels (tuple): the input channel.
+ out_channels (int): the output channel.
+ act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='Sigmoid').
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN')
+ """
+
+ def __init__(
+ self,
+ in_channels=(64, 128, 256, 384),
+ out_channels=256,
+ act_cfg=dict(type='Sigmoid'),
+ norm_cfg=dict(type='BN'),
+ ):
+ super().__init__()
+ self.embedding_list = nn.ModuleList()
+ self.act_embedding_list = nn.ModuleList()
+ self.act_list = nn.ModuleList()
+ for i in range(len(in_channels)):
+ if i != 1:
+ self.embedding_list.append(
+ ConvModule(
+ in_channels=in_channels[i],
+ out_channels=out_channels,
+ kernel_size=1,
+ bias=False,
+ norm_cfg=norm_cfg,
+ act_cfg=None,
+ ))
+ if i != 0:
+ self.act_embedding_list.append(
+ ConvModule(
+ in_channels=in_channels[i],
+ out_channels=out_channels,
+ kernel_size=1,
+ bias=False,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ ))
+
+ def forward(self, inputs):
+ # x_x8, x_x16, x_x32
+ low_feat1 = self.embedding_list[0](inputs[0])
+
+ low_feat2 = F.interpolate(
+ inputs[1], size=low_feat1.shape[-2:], mode='bilinear')
+ low_feat2_act = self.act_embedding_list[0](low_feat2)
+
+ high_feat_act = F.interpolate(
+ self.act_embedding_list[1](inputs[2]),
+ size=low_feat2.shape[2:],
+ mode='bilinear',
+ )
+ high_feat = F.interpolate(
+ self.embedding_list[1](inputs[2]),
+ size=low_feat2.shape[2:],
+ mode='bilinear')
+
+ res = low_feat2_act * high_feat_act * low_feat1 + high_feat
+
+ return res
+
+
+def _make_divisible(v, divisor=8, min_value=None):
+    """Round ``v`` to the nearest multiple of ``divisor``, without going
+    below ``min_value``."""
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # make sure that rounding down does not go more than 10% below v
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+@MODELS.register_module()
+class Hardsigmoid(nn.Module):
+    """The hardsigmoid activation.
+
+    Args:
+        slope (float, optional): The slope of the hardsigmoid function.
+            Default: 0.1666667.
+        offset (float, optional): The offset of the hardsigmoid function.
+            Default: 0.5.
+        inplace (bool): Kept for interface compatibility; the operation
+            is not performed in-place. Default: ``False``.
+ """
+
+ def __init__(self, slope=0.1666667, offset=0.5, inplace=False):
+ super().__init__()
+ self.slope = slope
+ self.offset = offset
+
+ def forward(self, x):
+ return (x * self.slope + self.offset).clamp(0, 1)
diff --git a/projects/pp_mobileseg/configs/_base_/datasets/ade20k.py b/projects/pp_mobileseg/configs/_base_/datasets/ade20k.py
new file mode 100644
index 00000000000..48340d11eea
--- /dev/null
+++ b/projects/pp_mobileseg/configs/_base_/datasets/ade20k.py
@@ -0,0 +1,68 @@
+# dataset settings
+dataset_type = 'ADE20KDataset'
+data_root = 'data/ade/ADEChallengeData2016'
+crop_size = (512, 512)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(
+ type='RandomResize',
+ scale=(2048, 512),
+ ratio_range=(0.5, 2.0),
+ keep_ratio=True),
+ dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='PhotoMetricDistortion'),
+ dict(type='PackSegInputs')
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=(2048, 512), keep_ratio=True),
+    # add loading annotation after ``Resize`` because the ground truth
+    # does not need the resize data transform
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='PackSegInputs')
+]
+img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+tta_pipeline = [
+ dict(type='LoadImageFromFile', backend_args=None),
+ dict(
+ type='TestTimeAug',
+ transforms=[
+ [
+ dict(type='Resize', scale_factor=r, keep_ratio=True)
+ for r in img_ratios
+ ],
+ [
+ dict(type='RandomFlip', prob=0., direction='horizontal'),
+ dict(type='RandomFlip', prob=1., direction='horizontal')
+ ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
+ ])
+]
+train_dataloader = dict(
+ batch_size=4,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='InfiniteSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='images/training', seg_map_path='annotations/training'),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_prefix=dict(
+ img_path='images/validation',
+ seg_map_path='annotations/validation'),
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
diff --git a/projects/pp_mobileseg/configs/_base_/default_runtime.py b/projects/pp_mobileseg/configs/_base_/default_runtime.py
new file mode 100644
index 00000000000..272b4d24679
--- /dev/null
+++ b/projects/pp_mobileseg/configs/_base_/default_runtime.py
@@ -0,0 +1,15 @@
+default_scope = 'mmseg'
+env_cfg = dict(
+ cudnn_benchmark=True,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+vis_backends = [dict(type='LocalVisBackend')]
+visualizer = dict(
+ type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+log_processor = dict(by_epoch=False)
+log_level = 'INFO'
+load_from = None
+resume = False
+
+tta_model = dict(type='SegTTAModel')
diff --git a/projects/pp_mobileseg/configs/_base_/models/pp_mobile.py b/projects/pp_mobileseg/configs/_base_/models/pp_mobile.py
new file mode 100644
index 00000000000..0c7695636f6
--- /dev/null
+++ b/projects/pp_mobileseg/configs/_base_/models/pp_mobile.py
@@ -0,0 +1,47 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255)
+
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ type='StrideFormer',
+ mobileV3_cfg=[
+        # [kernel, expand_channels, out_channels, use_se, act, stride]
+ [[3, 16, 16, True, 'ReLU', 1], [3, 64, 32, False, 'ReLU', 2],
+ [3, 96, 32, False, 'ReLU', 1]], # cfg1
+ [[5, 128, 64, True, 'HSwish', 2], [5, 240, 64, True, 'HSwish',
+ 1]], # cfg2
+ [[5, 384, 128, True, 'HSwish', 2],
+ [5, 384, 128, True, 'HSwish', 1]], # cfg3
+ [[5, 768, 192, True, 'HSwish', 2],
+ [5, 768, 192, True, 'HSwish', 1]], # cfg4
+ ],
+ channels=[16, 32, 64, 128, 192],
+ depths=[3, 3],
+ embed_dims=[128, 192],
+ num_heads=8,
+ inj_type='AAMSx8',
+ out_feat_chs=[64, 128, 192],
+ act_cfg=dict(type='ReLU6'),
+ ),
+ decode_head=dict(
+ type='PPMobileSegHead',
+ num_classes=150,
+ in_channels=256,
+ dropout_ratio=0.1,
+ use_dw=True,
+ act_cfg=dict(type='ReLU'),
+ align_corners=False),
+
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/projects/pp_mobileseg/configs/_base_/schedules/schedule_80k.py b/projects/pp_mobileseg/configs/_base_/schedules/schedule_80k.py
new file mode 100644
index 00000000000..0dcd6c4d1bc
--- /dev/null
+++ b/projects/pp_mobileseg/configs/_base_/schedules/schedule_80k.py
@@ -0,0 +1,24 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
+# learning policy
+param_scheduler = [
+ dict(
+ type='PolyLR',
+ eta_min=1e-4,
+ power=0.9,
+ begin=0,
+ end=80000,
+ by_epoch=False)
+]
+# training schedule for 80k
+train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=8000)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=8000),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'))
diff --git a/projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_base.py b/projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_base.py
new file mode 100644
index 00000000000..4b68a927e20
--- /dev/null
+++ b/projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_base.py
@@ -0,0 +1,18 @@
+_base_ = [
+ '../_base_/models/pp_mobile.py', '../_base_/datasets/ade20k.py',
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
+]
+# the custom import path is determined by your workspace path (i.e., where you run the command from) # noqa
+custom_imports = dict(
+ imports=[
+ 'projects.pp_mobileseg.backbones', 'projects.pp_mobileseg.decode_head'
+ ],
+ allow_failed_imports=False)
+checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pp_mobileseg/pp_mobileseg_mobilenetv3_3rdparty-base-ed0be681.pth' # noqa
+crop_size = (512, 512)
+data_preprocessor = dict(size=crop_size, test_cfg=dict(size_divisor=32))
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ backbone=dict(init_cfg=dict(type='Pretrained', checkpoint=checkpoint)),
+ decode_head=dict(num_classes=150, upsample='intepolate'))
diff --git a/projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny.py b/projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny.py
new file mode 100644
index 00000000000..b78869e517a
--- /dev/null
+++ b/projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny.py
@@ -0,0 +1,45 @@
+_base_ = [
+ '../_base_/models/pp_mobile.py', '../_base_/datasets/ade20k.py',
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
+]
+# the custom import path is determined by your workspace path (i.e., where you run the command from) # noqa
+custom_imports = dict(
+ imports=[
+ 'projects.pp_mobileseg.backbones', 'projects.pp_mobileseg.decode_head'
+ ],
+ allow_failed_imports=False)
+checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pp_mobileseg/pp_mobileseg_mobilenetv3_3rdparty-tiny-e4b35e96.pth' # noqa
+crop_size = (512, 512)
+data_preprocessor = dict(size=crop_size, test_cfg=dict(size_divisor=32))
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ init_cfg=dict(type='Pretrained', checkpoint=checkpoint),
+ type='StrideFormer',
+ mobileV3_cfg=[
+                # [kernel, expand, out_ch, use_se, act, stride]
+ [[3, 16, 16, True, 'ReLU', 1], [3, 64, 32, False, 'ReLU', 2],
+ [3, 48, 24, False, 'ReLU', 1]], # cfg1
+ [[5, 96, 32, True, 'HSwish', 2], [5, 96, 32, True, 'HSwish',
+ 1]], # cfg2
+ [[5, 160, 64, True, 'HSwish', 2], [5, 160, 64, True, 'HSwish',
+ 1]], # cfg3
+ [[3, 384, 128, True, 'HSwish', 2],
+ [3, 384, 128, True, 'HSwish', 1]], # cfg4
+ ],
+ channels=[16, 24, 32, 64, 128],
+ depths=[2, 2],
+ embed_dims=[64, 128],
+ num_heads=4,
+ inj_type='AAM',
+ out_feat_chs=[32, 64, 128],
+ act_cfg=dict(type='ReLU6'),
+ ),
+ decode_head=dict(
+ num_classes=150,
+ in_channels=256,
+ use_dw=True,
+ act_cfg=dict(type='ReLU'),
+ upsample='intepolate'),
+)
diff --git a/projects/pp_mobileseg/decode_head/__init__.py b/projects/pp_mobileseg/decode_head/__init__.py
new file mode 100644
index 00000000000..6f71b784e15
--- /dev/null
+++ b/projects/pp_mobileseg/decode_head/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .pp_mobileseg_head import PPMobileSegHead
+
+__all__ = [
+ 'PPMobileSegHead',
+]
diff --git a/projects/pp_mobileseg/decode_head/pp_mobileseg_head.py b/projects/pp_mobileseg/decode_head/pp_mobileseg_head.py
new file mode 100644
index 00000000000..243f0263729
--- /dev/null
+++ b/projects/pp_mobileseg/decode_head/pp_mobileseg_head.py
@@ -0,0 +1,94 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, build_conv_layer
+from torch import Tensor
+
+from mmseg.registry import MODELS
+
+
+@MODELS.register_module()
+class PPMobileSegHead(nn.Module):
+    """The segmentation head of PP-MobileSeg.
+
+    Args:
+        num_classes (int): Number of classes.
+        in_channels (int): Number of input channels.
+        use_dw (bool): Whether to use a depthwise convolution in the fuse
+            layer. Default: True.
+        dropout_ratio (float): Probability of an element to be zeroed.
+            Default: 0.1.
+        align_corners (bool): Geometrically, we consider the pixels of the
+            input and output as squares rather than points. Default: False.
+        upsample (str): Upsample method, one of 'intepolate' or 'vim'.
+            Default: 'intepolate'.
+        out_channels (int): Number of output channels.
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: dict(type='Conv').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU').
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+    """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ use_dw=True,
+ dropout_ratio=0.1,
+ align_corners=False,
+ upsample='intepolate',
+ out_channels=None,
+ conv_cfg=dict(type='Conv'),
+ act_cfg=dict(type='ReLU'),
+ norm_cfg=dict(type='BN')):
+
+ super().__init__()
+ self.align_corners = align_corners
+ self.last_channels = in_channels
+ self.upsample = upsample
+ self.num_classes = num_classes
+ self.out_channels = out_channels
+ self.linear_fuse = ConvModule(
+ in_channels=self.last_channels,
+ out_channels=self.last_channels,
+ kernel_size=1,
+ bias=False,
+ groups=self.last_channels if use_dw else 1,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.dropout = nn.Dropout2d(dropout_ratio)
+ self.conv_seg = build_conv_layer(
+ conv_cfg, self.last_channels, self.num_classes, kernel_size=1)
+
+ def forward(self, x):
+ x, x_hw = x[0], x[1]
+ x = self.linear_fuse(x)
+ x = self.dropout(x)
+ x = self.conv_seg(x)
+ if self.upsample == 'intepolate' or self.training or \
+ self.num_classes < 30:
+ x = F.interpolate(
+ x, x_hw, mode='bilinear', align_corners=self.align_corners)
+ elif self.upsample == 'vim':
+ labelset = torch.unique(torch.argmax(x, 1))
+            x = torch.index_select(x, 1, labelset)
+ x = F.interpolate(
+ x, x_hw, mode='bilinear', align_corners=self.align_corners)
+
+ pred = torch.argmax(x, 1)
+            pred_retrieve = torch.zeros(
+                pred.shape, dtype=torch.int32, device=pred.device)
+            for i, val in enumerate(labelset):
+                pred_retrieve[pred == i] = val.to(torch.int32)
+
+ x = pred_retrieve
+ else:
+            raise NotImplementedError(
+                f'{self.upsample} is not implemented')
+
+ return [x]
+
+ def predict(self, inputs, batch_img_metas: List[dict], test_cfg,
+ **kwargs) -> List[Tensor]:
+        """Forward function for testing."""
+ seg_logits = self.forward(inputs)[0]
+ return seg_logits
diff --git a/projects/pp_mobileseg/inference_onnx.py b/projects/pp_mobileseg/inference_onnx.py
new file mode 100644
index 00000000000..139d1b13243
--- /dev/null
+++ b/projects/pp_mobileseg/inference_onnx.py
@@ -0,0 +1,203 @@
+import argparse
+import time
+from typing import List, Tuple
+
+import cv2
+import loguru
+import numpy as np
+import onnxruntime as ort
+
+logger = loguru.logger
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description='PP_Mobileseg ONNX inference demo.')
+ parser.add_argument('onnx_file', help='ONNX file path')
+ parser.add_argument('image_file', help='Input image file path')
+ parser.add_argument(
+ '--input-size',
+ type=int,
+ nargs='+',
+ default=[512, 512],
+ help='input image size')
+ parser.add_argument(
+ '--device', help='device type for inference', default='cpu')
+ parser.add_argument(
+ '--save-path',
+ help='path to save the output image',
+ default='output.jpg')
+ args = parser.parse_args()
+ return args
+
+
+def preprocess(
+ img: np.ndarray, input_size: Tuple[int, int] = (512, 512)
+) -> Tuple[np.ndarray, np.ndarray]:
+ """Preprocess image for inference."""
+ img_shape = img.shape[:2]
+ # Resize
+ resized_img = cv2.resize(img, input_size)
+
+ # Normalize
+    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
+ std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
+ resized_img = (resized_img - mean) / std
+
+ return resized_img, img_shape
+
+
+def build_session(onnx_file: str, device: str = 'cpu') -> ort.InferenceSession:
+ """Build onnxruntime session.
+
+ Args:
+ onnx_file (str): ONNX file path.
+ device (str): Device type for inference.
+
+ Returns:
+ sess (ort.InferenceSession): ONNXRuntime session.
+ """
+ providers = ['CPUExecutionProvider'
+ ] if device == 'cpu' else ['CUDAExecutionProvider']
+ sess = ort.InferenceSession(path_or_bytes=onnx_file, providers=providers)
+
+ return sess
+
+
+def inference(sess: ort.InferenceSession,
+              img: np.ndarray) -> List[np.ndarray]:
+    """Run inference with the PP_MobileSeg ONNX model.
+
+    Args:
+        sess (ort.InferenceSession): ONNXRuntime session.
+        img (np.ndarray): Preprocessed input image in (H, W, C) layout.
+
+    Returns:
+        outputs (List[np.ndarray]): Outputs of the PP_MobileSeg model.
+    """
+ # build input
+ input_img = [img.transpose(2, 0, 1).astype(np.float32)]
+
+ # build output
+ sess_input = {sess.get_inputs()[0].name: input_img}
+ sess_output = []
+ for out in sess.get_outputs():
+ sess_output.append(out.name)
+
+ # inference
+ outputs = sess.run(output_names=sess_output, input_feed=sess_input)
+
+ return outputs
+
+
+def postprocess(outputs: List[np.ndarray],
+ origin_shape: Tuple[int, int]) -> np.ndarray:
+ """Postprocess outputs of PP_Mobileseg model.
+
+ Args:
+ outputs (List[np.ndarray]): Outputs of PP_Mobileseg model.
+        origin_shape (Tuple[int, int]): Size (w, h) of the original image.
+
+ Returns:
+ seg_map (np.ndarray): Segmentation map.
+ """
+ seg_map = outputs[0][0][0]
+    seg_map = cv2.resize(
+        seg_map.astype(np.float32),
+        origin_shape,
+        interpolation=cv2.INTER_NEAREST)
+ return seg_map
+
+
+def visualize(img: np.ndarray,
+ seg_map: np.ndarray,
+ filename: str = 'output.jpg',
+ opacity: float = 0.8) -> np.ndarray:
+ assert 0.0 <= opacity <= 1.0, 'opacity should be in range [0, 1]'
+ palette = np.array(PALETTE)
+ color_seg = np.zeros((seg_map.shape[0], seg_map.shape[1], 3),
+ dtype=np.uint8)
+ for label, color in enumerate(palette):
+ color_seg[seg_map == label, :] = color
+ # convert to BGR
+ color_seg = color_seg[..., ::-1]
+
+ img = img * (1 - opacity) + color_seg * opacity
+ cv2.imwrite(filename, img)
+
+ return img
+
+
+def main():
+ args = parse_args()
+ logger.info('Start running model inference...')
+
+ # read image from file
+ logger.info(f'1. Read image from file {args.image_file}...')
+ img = cv2.imread(args.image_file)
+
+ # build onnx model
+ logger.info(f'2. Build onnx model from {args.onnx_file}...')
+ sess = build_session(args.onnx_file, args.device)
+
+ # preprocess
+ logger.info('3. Preprocess image...')
+ model_input_size = tuple(args.input_size)
+ assert len(model_input_size) == 2
+ resized_img, origin_shape = preprocess(img, model_input_size)
+
+ # inference
+ logger.info('4. Inference...')
+ start = time.time()
+ outputs = inference(sess, resized_img)
+ logger.info(f'Inference time: {time.time() - start:.4f}s')
+
+ # postprocess
+ logger.info('5. Postprocess...')
+ h, w = origin_shape
+ seg_map = postprocess(outputs, (w, h))
+
+ # visualize
+ logger.info('6. Visualize...')
+ visualize(img, seg_map, args.save_path)
+
+ logger.info('Done...')
+
+
+PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+ [102, 255, 0], [92, 0, 255]]
+
+if __name__ == '__main__':
+ main()
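The per-label loop in `visualize` can also be expressed with numpy fancy indexing, since a palette lookup is just an integer index into the palette array. A small sketch with a two-color toy palette (not the ADE20K palette above):

```python
import numpy as np

palette = np.array([[120, 120, 120], [180, 120, 120]], dtype=np.uint8)
seg_map = np.array([[0, 1],
                    [1, 0]])
color_seg = palette[seg_map]  # (H, W, 3); replaces the per-label loop
```

`palette[seg_map]` broadcasts each label to its RGB triple in one vectorized operation.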
diff --git a/projects/sam_inference_demo/README.md b/projects/sam_inference_demo/README.md
new file mode 100644
index 00000000000..f8077b8729f
--- /dev/null
+++ b/projects/sam_inference_demo/README.md
@@ -0,0 +1,40 @@
+# Introducing the Segment Anything Model (SAM) Inference Demo!
+
+Welcome to the Segment Anything (SA) Inference Demo, a user-friendly implementation based on the original Segment Anything project. Our demo allows you to experience the power and versatility of the Segment Anything Model (SAM) through an easy-to-use API.
+
+With this inference demo, you can explore the capabilities of the Segment Anything Model and witness its effectiveness in various tasks and image distributions. For more information on the original project, dataset, and model, please visit the official website at https://segment-anything.com.
+
+### Prerequisites
+
+- Python 3.10
+- PyTorch 1.13
+- MMEngine >= v0.7.2
+- MMCV >= v2.0.0
+
+### Installation
+
+We assume that you have already installed PyTorch. If not, please follow the instructions on the [PyTorch website](https://pytorch.org/).
+
+**1. Install MMEngine & MMCV**
+
+```shell
+pip install openmim
+mim install mmengine
+mim install 'mmcv>=2.0.0'
+```
+
+**2. Install MMPretrain**
+
+```shell
+pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
+```
+
+**3. Install MMSegmentation**
+
+```shell
+pip install mmsegmentation
+```
+
+### Usage
+
+Open the `sam_image_demo.ipynb` notebook and follow the instructions to run the demo.
diff --git a/projects/sam_inference_demo/sam/__init__.py b/projects/sam_inference_demo/sam/__init__.py
new file mode 100644
index 00000000000..82b6b78469c
--- /dev/null
+++ b/projects/sam_inference_demo/sam/__init__.py
@@ -0,0 +1,2 @@
+from .modeling import * # noqa
+from .utils import * # noqa
diff --git a/projects/sam_inference_demo/sam/modeling/__init__.py b/projects/sam_inference_demo/sam/modeling/__init__.py
new file mode 100644
index 00000000000..9892a6b085d
--- /dev/null
+++ b/projects/sam_inference_demo/sam/modeling/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .mask_decoder import MaskDecoder
+from .prompt_encoder import PromptEncoder
+from .sam import SAM
+from .transformer import TwoWayTransformer
+
+__all__ = ['SAM', 'MaskDecoder', 'PromptEncoder', 'TwoWayTransformer']
diff --git a/projects/sam_inference_demo/sam/modeling/common.py b/projects/sam_inference_demo/sam/modeling/common.py
new file mode 100644
index 00000000000..d2892761122
--- /dev/null
+++ b/projects/sam_inference_demo/sam/modeling/common.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Type
+
+import torch
+import torch.nn as nn
+
+
+class MLPBlock(nn.Module):
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ mlp_dim: int,
+ act: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ super().__init__()
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+ self.act = act()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.lin2(self.act(self.lin1(x)))
+
+
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
+class LayerNorm2d(nn.Module):
+
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(num_channels))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
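`LayerNorm2d` above normalizes over the channel dimension of an NCHW tensor. The same arithmetic in numpy, with `weight` and `bias` left at their init values of 1 and 0:

```python
import numpy as np

x = np.arange(24, dtype=np.float64).reshape(1, 4, 3, 2)  # N, C, H, W
eps = 1e-6

u = x.mean(axis=1, keepdims=True)             # per-pixel channel mean
s = ((x - u) ** 2).mean(axis=1, keepdims=True)  # per-pixel channel variance
y = (x - u) / np.sqrt(s + eps)
```

After normalization, every spatial position has zero mean and unit variance across channels.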
diff --git a/projects/sam_inference_demo/sam/modeling/mask_decoder.py b/projects/sam_inference_demo/sam/modeling/mask_decoder.py
new file mode 100644
index 00000000000..9ad616b5899
--- /dev/null
+++ b/projects/sam_inference_demo/sam/modeling/mask_decoder.py
@@ -0,0 +1,196 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Borrowed from https://github.com/facebookresearch/segment-anything
+
+from typing import List, Tuple
+
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+from mmseg.registry import MODELS
+from .common import LayerNorm2d
+
+
+@MODELS.register_module()
+class MaskDecoder(nn.Module):
+
+ def __init__(
+ self,
+ *,
+ transformer_dim: int,
+ transformer: dict,
+ num_multimask_outputs: int = 3,
+ act_cfg: dict = dict(type='GELU'),
+ iou_head_depth: int = 3,
+ iou_head_hidden_dim: int = 256,
+ ) -> None:
+ """Predicts masks given an image and prompt embeddings, using a
+        transformer architecture.
+
+ Borrowed from https://github.com/facebookresearch/segment-anything
+
+ Arguments:
+            transformer_dim (int): the channel dimension of the transformer
+            transformer (dict): config of the transformer used to predict
+                masks
+            num_multimask_outputs (int): the number of masks to predict
+                when disambiguating masks
+            act_cfg (dict): config of the activation to use when
+                upscaling masks
+ iou_head_depth (int): the depth of the MLP used to predict
+ mask quality
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
+ used to predict mask quality
+ """
+ super().__init__()
+ self.transformer_dim = transformer_dim
+ self.transformer = MODELS.build(transformer)
+
+ self.num_multimask_outputs = num_multimask_outputs
+
+ self.iou_token = nn.Embedding(1, transformer_dim)
+ self.num_mask_tokens = num_multimask_outputs + 1
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+ activation = MODELS.build(act_cfg)
+ self.output_upscaling = nn.Sequential(
+ nn.ConvTranspose2d(
+ transformer_dim, transformer_dim // 4, kernel_size=2,
+ stride=2),
+ LayerNorm2d(transformer_dim // 4),
+ activation,
+ nn.ConvTranspose2d(
+ transformer_dim // 4,
+ transformer_dim // 8,
+ kernel_size=2,
+ stride=2),
+ activation,
+ )
+ self.output_hypernetworks_mlps = nn.ModuleList([
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+ for i in range(self.num_mask_tokens)
+ ])
+
+ self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim,
+ self.num_mask_tokens, iou_head_depth)
+
+ def forward(
+ self,
+ image_embeddings: Tensor,
+ image_pe: Tensor,
+ sparse_prompt_embeddings: Tensor,
+ dense_prompt_embeddings: Tensor,
+ multimask_output: bool,
+ ) -> Tuple[Tensor, Tensor]:
+ """Predict masks given image and prompt embeddings.
+
+ Borrowed from https://github.com/facebookresearch/segment-anything
+
+ Arguments:
+ image_embeddings (Tensor): the embeddings from the image encoder
+ image_pe (Tensor): positional encoding with the shape of
+ image_embeddings
+ sparse_prompt_embeddings (Tensor): the embeddings of
+ the points and boxes
+ dense_prompt_embeddings (Tensor): the embeddings of the mask inputs
+ multimask_output (bool): Whether to return multiple masks or a single
+ mask.
+
+ Returns:
+ Tensor: batched predicted masks
+ Tensor: batched predictions of mask quality
+ """
+ masks, iou_pred = self.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=image_pe,
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
+ dense_prompt_embeddings=dense_prompt_embeddings,
+ )
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ mask_slice = slice(1, None)
+ else:
+ mask_slice = slice(0, 1)
+ masks = masks[:, mask_slice, :, :]
+ iou_pred = iou_pred[:, mask_slice]
+
+ # Prepare output
+ return masks, iou_pred
+
+ def predict_masks(
+ self,
+ image_embeddings: Tensor,
+ image_pe: Tensor,
+ sparse_prompt_embeddings: Tensor,
+ dense_prompt_embeddings: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """Predicts masks.
+
+ See 'forward' for more details.
+ """
+ # Concatenate output tokens
+ output_tokens = torch.cat(
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0)
+ output_tokens = output_tokens.unsqueeze(0).expand(
+ sparse_prompt_embeddings.size(0), -1, -1)
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+ # Expand per-image data in batch direction to be per-mask
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+ src = src + dense_prompt_embeddings
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+ b, c, h, w = src.shape
+
+ # Run the transformer
+ hs, src = self.transformer(src, pos_src, tokens)
+ iou_token_out = hs[:, 0, :]
+ mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ src = src.transpose(1, 2).view(b, c, h, w)
+ upscaled_embedding = self.output_upscaling(src)
+ hyper_in_list: List[Tensor] = []
+ for i in range(self.num_mask_tokens):
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](
+ mask_tokens_out[:, i, :]))
+ hyper_in = torch.stack(hyper_in_list, dim=1)
+ b, c, h, w = upscaled_embedding.shape
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(
+ b, -1, h, w)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+
+ return masks, iou_pred
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
+
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ sigmoid_output: bool = False,
+ ) -> None:
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ self.sigmoid_output = sigmoid_output
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ if self.sigmoid_output:
+ x = F.sigmoid(x)
+ return x
diff --git a/projects/sam_inference_demo/sam/modeling/prompt_encoder.py b/projects/sam_inference_demo/sam/modeling/prompt_encoder.py
new file mode 100644
index 00000000000..6b7c0833871
--- /dev/null
+++ b/projects/sam_inference_demo/sam/modeling/prompt_encoder.py
@@ -0,0 +1,227 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Borrowed from https://github.com/facebookresearch/segment-anything
+
+from typing import Any, Optional, Tuple, Type
+
+import numpy as np
+import torch
+from torch import nn
+
+from mmseg.registry import MODELS
+from .common import LayerNorm2d
+
+
+@MODELS.register_module()
+class PromptEncoder(nn.Module):
+
+ def __init__(
+ self,
+ embed_dim: int,
+ image_embedding_size: Tuple[int, int],
+ input_image_size: Tuple[int, int],
+ mask_in_chans: int,
+ activation: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ """Encodes prompts for input to SAM's mask decoder.
+
+ Arguments:
+ embed_dim (int): The prompts' embedding dimension
+ image_embedding_size (tuple(int, int)): The spatial size of the
+ image embedding, as (H, W).
+            input_image_size (tuple(int, int)): The padded size of the image
+                as input to the image encoder, as (H, W).
+ mask_in_chans (int): The number of hidden channels used for
+ encoding input masks.
+ activation (nn.Module): The activation to use when encoding
+ input masks.
+ """
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.input_image_size = input_image_size
+ self.image_embedding_size = image_embedding_size
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
+ point_embeddings = [
+ nn.Embedding(1, embed_dim)
+ for i in range(self.num_point_embeddings)
+ ]
+ self.point_embeddings = nn.ModuleList(point_embeddings)
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
+
+ self.mask_input_size = (4 * image_embedding_size[0],
+ 4 * image_embedding_size[1])
+ self.mask_downscaling = nn.Sequential(
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans // 4),
+ activation(),
+ nn.Conv2d(
+ mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans),
+ activation(),
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
+ )
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
+
+ def get_dense_pe(self) -> torch.Tensor:
+ """Returns the positional encoding used to encode point prompts,
+ applied to a dense set of points the shape of the image encoding.
+
+ Returns:
+ torch.Tensor: Positional encoding with shape
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
+ """
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+ def _embed_points(
+ self,
+ points: torch.Tensor,
+ labels: torch.Tensor,
+ pad: bool,
+ ) -> torch.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ padding_point = torch.zeros((points.shape[0], 1, 2),
+ device=points.device)
+ padding_label = -torch.ones(
+ (labels.shape[0], 1), device=labels.device)
+ points = torch.cat([points, padding_point], dim=1)
+ labels = torch.cat([labels, padding_label], dim=1)
+ point_embedding = self.pe_layer.forward_with_coords(
+ points, self.input_image_size)
+ point_embedding[labels == -1] = 0.0
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
+ return point_embedding
+
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """Embeds box prompts."""
+ boxes = boxes + 0.5 # Shift to center of pixel
+ coords = boxes.reshape(-1, 2, 2)
+ corner_embedding = self.pe_layer.forward_with_coords(
+ coords, self.input_image_size)
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
+ return corner_embedding
+
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
+ """Embeds mask inputs."""
+ mask_embedding = self.mask_downscaling(masks)
+ return mask_embedding
+
+ def _get_batch_size(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> int:
+ """Gets the batch size of the output given the batch size of the input
+ prompts."""
+ if points is not None:
+ return points[0].shape[0]
+ elif boxes is not None:
+ return boxes.shape[0]
+ elif masks is not None:
+ return masks.shape[0]
+ else:
+ return 1
+
+ def _get_device(self) -> torch.device:
+ return self.point_embeddings[0].weight.device
+
+ def forward(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Embeds different types of prompts, returning both sparse and dense
+ embeddings.
+
+ Arguments:
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
+ and labels to embed.
+ boxes (torch.Tensor or none): boxes to embed
+ masks (torch.Tensor or none): masks to embed
+
+ Returns:
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
+ BxNx(embed_dim), where N is determined by the number of input points
+ and boxes.
+ torch.Tensor: dense embeddings for the masks, in the shape
+ Bx(embed_dim)x(embed_H)x(embed_W)
+ """ # noqa
+ bs = self._get_batch_size(points, boxes, masks)
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim),
+ device=self._get_device())
+ if points is not None:
+ coords, labels = points
+ point_embeddings = self._embed_points(
+ coords, labels, pad=(boxes is None))
+ sparse_embeddings = torch.cat(
+ [sparse_embeddings, point_embeddings], dim=1)
+ if boxes is not None:
+ box_embeddings = self._embed_boxes(boxes)
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings],
+ dim=1)
+
+ if masks is not None:
+ dense_embeddings = self._embed_masks(masks)
+ else:
+ dense_embeddings = self.no_mask_embed.weight.reshape(
+ 1, -1, 1, 1).expand(bs, -1, self.image_embedding_size[0],
+ self.image_embedding_size[1])
+
+ return sparse_embeddings, dense_embeddings
+
+
+class PositionEmbeddingRandom(nn.Module):
+ """Positional encoding using random spatial frequencies."""
+
+ def __init__(self,
+ num_pos_feats: int = 64,
+ scale: Optional[float] = None) -> None:
+ super().__init__()
+ if scale is None or scale <= 0.0:
+ scale = 1.0
+ self.register_buffer(
+ 'positional_encoding_gaussian_matrix',
+ scale * torch.randn((2, num_pos_feats)),
+ )
+
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+ """Positionally encode points that are normalized to [0,1]."""
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape # noqa
+ coords = 2 * coords - 1
+ coords = coords @ self.positional_encoding_gaussian_matrix
+ coords = 2 * np.pi * coords
+ # outputs d_1 x ... x d_n x C shape
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+ """Generate positional encoding for a grid of the specified size."""
+ h, w = size
+ device: Any = self.positional_encoding_gaussian_matrix.device
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / h
+ x_embed = x_embed / w
+
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+ return pe.permute(2, 0, 1) # C x H x W
+
+ def forward_with_coords(self, coords_input: torch.Tensor,
+ image_size: Tuple[int, int]) -> torch.Tensor:
+ """Positionally encode points that are not normalized to [0,1]."""
+ coords = coords_input.clone()
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
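`PositionEmbeddingRandom` maps normalized 2-D coordinates through a fixed Gaussian matrix and takes sin/cos, i.e. random Fourier features. A numpy sketch of the `_pe_encoding` arithmetic:

```python
import numpy as np

num_pos_feats = 64
rng = np.random.default_rng(0)
gaussian = rng.standard_normal((2, num_pos_feats))  # fixed at init time

coords = np.array([[0.25, 0.75]])       # one point in [0, 1]^2
coords = 2 * coords - 1                 # shift to [-1, 1]
proj = 2 * np.pi * (coords @ gaussian)  # project onto random frequencies
pe = np.concatenate([np.sin(proj), np.cos(proj)], axis=-1)
```

The encoding dimension is `2 * num_pos_feats`, matching `embed_dim // 2` frequencies in the module above.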
diff --git a/projects/sam_inference_demo/sam/modeling/sam.py b/projects/sam_inference_demo/sam/modeling/sam.py
new file mode 100644
index 00000000000..c61c1eca4e7
--- /dev/null
+++ b/projects/sam_inference_demo/sam/modeling/sam.py
@@ -0,0 +1,188 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Borrowed from https://github.com/facebookresearch/segment-anything
+
+from typing import Any, Dict, List, Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from mmseg.registry import MODELS
+from .mask_decoder import MaskDecoder
+from .prompt_encoder import PromptEncoder
+
+
+@MODELS.register_module()
+class SAM(nn.Module):
+ mask_threshold: float = 0.0
+ image_format: str = 'RGB'
+
+ def __init__(
+ self,
+ image_encoder_cfg: dict,
+ prompt_encoder_cfg: dict,
+ mask_decoder_cfg: dict,
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
+ ) -> None:
+ """SAM predicts object masks from an image and input prompts. Borrowed
+ from https://github.com/facebookresearch/segment-anything.
+
+ Arguments:
+ image_encoder (ViTSAM): The backbone used to encode the
+ image into image embeddings that allow for efficient mask
+ prediction.
+ prompt_encoder (PromptEncoder): Encodes various types of input
+ prompts.
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
+ and encoded prompts.
+ pixel_mean (list(float)): Mean values for normalizing pixels in the
+ input image.
+ pixel_std (list(float)): Std values for normalizing pixels in the
+ input image.
+ """
+ super().__init__()
+ self.image_encoder = MODELS.build(image_encoder_cfg)
+ self.prompt_encoder: PromptEncoder = MODELS.build(prompt_encoder_cfg)
+ self.mask_decoder: MaskDecoder = MODELS.build(mask_decoder_cfg)
+ self.register_buffer('pixel_mean',
+ torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer('pixel_std',
+ torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+ @property
+ def device(self) -> Any:
+ return self.pixel_mean.device
+
+ @torch.no_grad()
+ def forward(
+ self,
+ batched_input: List[Dict[str, Any]],
+ multimask_output: bool,
+ ) -> List[Dict[str, torch.Tensor]]:
+ """Predicts masks end-to-end from provided images and prompts. If
+ prompts are not known in advance, using SamPredictor is recommended
+ over calling the model directly.
+
+ Borrowed from https://github.com/facebookresearch/segment-anything
+
+ Arguments:
+ batched_input (list(dict)): A list over input images, each a
+ dictionary with the following keys. A prompt key can be
+ excluded if it is not present.
+ 'image': The image as a torch tensor in 3xHxW format,
+ already transformed for input to the model.
+ 'original_size': (tuple(int, int)) The original size of
+ the image before transformation, as (H, W).
+ 'point_coords': (torch.Tensor) Batched point prompts for
+ this image, with shape BxNx2. Already transformed to the
+ input frame of the model.
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
+ with shape BxN.
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
+ Already transformed to the input frame of the model.
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
+ in the form Bx1xHxW.
+ multimask_output (bool): Whether the model should predict multiple
+ disambiguating masks, or return a single mask.
+
+ Returns:
+ (list(dict)): A list over input images, where each element is
+            a dictionary with the following keys.
+ 'masks': (torch.Tensor) Batched binary mask predictions,
+ with shape BxCxHxW, where B is the number of input prompts,
+                C is determined by multimask_output, and (H, W) is the
+ original size of the image.
+ 'iou_predictions': (torch.Tensor) The model's predictions
+ of mask quality, in shape BxC.
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
+ to subsequent iterations of prediction.
+ """
+ input_images = torch.stack(
+ [self.preprocess(x['image']) for x in batched_input], dim=0)
+ image_embeddings = self.image_encoder(input_images)
+
+ outputs = []
+ for image_record, curr_embedding in zip(batched_input,
+ image_embeddings):
+ if 'point_coords' in image_record:
+ points = (image_record['point_coords'],
+ image_record['point_labels'])
+ else:
+ points = None
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
+ points=points,
+ boxes=image_record.get('boxes', None),
+ masks=image_record.get('mask_inputs', None),
+ )
+ low_res_masks, iou_predictions = self.mask_decoder(
+ image_embeddings=curr_embedding.unsqueeze(0),
+ image_pe=self.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ )
+ masks = self.postprocess_masks(
+ low_res_masks,
+ input_size=image_record['image'].shape[-2:],
+ original_size=image_record['original_size'],
+ )
+ masks = masks > self.mask_threshold
+ outputs.append({
+ 'masks': masks,
+ 'iou_predictions': iou_predictions,
+ 'low_res_logits': low_res_masks,
+ })
+ return outputs
+
+ def postprocess_masks(
+ self,
+ masks: torch.Tensor,
+ input_size: Tuple[int, ...],
+ original_size: Tuple[int, ...],
+ ) -> torch.Tensor:
+ """Remove padding and upscale masks to the original image size.
+
+ Borrowed from https://github.com/facebookresearch/segment-anything
+
+ Arguments:
+ masks (torch.Tensor): Batched masks from the mask_decoder,
+ in BxCxHxW format.
+ input_size (tuple(int, int)): The size of the image input to the
+ model, in (H, W) format. Used to remove padding.
+ original_size (tuple(int, int)): The original size of the image
+ before resizing for input to the model, in (H, W) format.
+
+ Returns:
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
+ is given by original_size.
+ """
+ masks = F.interpolate(
+ masks,
+ self.image_encoder.img_size,
+ mode='bilinear',
+ align_corners=False,
+ )
+ masks = masks[..., :input_size[0], :input_size[1]]
+ masks = F.interpolate(
+ masks, original_size, mode='bilinear', align_corners=False)
+ return masks
+
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+ """Normalize pixel values and pad to a square input."""
+ # Normalize colors
+ x = (x - self.pixel_mean) / self.pixel_std
+
+ # Pad
+ h, w = x.shape[-2:]
+ img_size = max(self.image_encoder.img_size)
+ padh = img_size - h
+ padw = img_size - w
+ x = F.pad(x, (0, padw, 0, padh))
+ return x
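`preprocess` above normalizes with per-channel mean/std and then zero-pads the bottom and right edges up to a square canvas. A self-contained sketch of that behavior (the mean/std values mirror the defaults above; `img_size` is assumed to be 1024 as in the demo config):

```python
import torch
import torch.nn.functional as F

# Normalize, then zero-pad to a square img_size x img_size canvas.
img_size = 1024
pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

def preprocess(x: torch.Tensor) -> torch.Tensor:
    x = (x - pixel_mean) / pixel_std
    h, w = x.shape[-2:]
    # F.pad pads the last dim first: (left, right, top, bottom)
    return F.pad(x, (0, img_size - w, 0, img_size - h))

img = torch.rand(3, 768, 1024) * 255  # long side already at img_size
out = preprocess(img)
print(out.shape)  # torch.Size([3, 1024, 1024])
```

Note that padding only on the right/bottom keeps the original pixels at the top-left, which is why `postprocess_masks` can crop with `masks[..., :input_size[0], :input_size[1]]`.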
diff --git a/projects/sam_inference_demo/sam/modeling/transformer.py b/projects/sam_inference_demo/sam/modeling/transformer.py
new file mode 100644
index 00000000000..c56f602487d
--- /dev/null
+++ b/projects/sam_inference_demo/sam/modeling/transformer.py
@@ -0,0 +1,241 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Tuple, Type
+
+import torch
+from torch import Tensor, nn
+
+from mmseg.registry import MODELS
+from .common import MLPBlock
+
+
+@MODELS.register_module()
+class TwoWayTransformer(nn.Module):
+
+ def __init__(
+ self,
+ depth: int,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ ) -> None:
+ """A transformer decoder that attends to an input image using queries
+ whose positional embedding is supplied.
+
+ Args:
+ depth (int): number of layers in the transformer
+ embedding_dim (int): the channel dimension for the input embeddings
+ num_heads (int): the number of heads for multihead attention. Must
+ divide embedding_dim
+ mlp_dim (int): the channel dimension internal to the MLP block
+ activation (nn.Module): the activation to use in the MLP block
+ """
+ super().__init__()
+ self.depth = depth
+ self.embedding_dim = embedding_dim
+ self.num_heads = num_heads
+ self.mlp_dim = mlp_dim
+ self.layers = nn.ModuleList()
+
+ for i in range(depth):
+ self.layers.append(
+ TwoWayAttentionBlock(
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ mlp_dim=mlp_dim,
+ activation=activation,
+ attention_downsample_rate=attention_downsample_rate,
+ skip_first_layer_pe=(i == 0),
+ ))
+
+ self.final_attn_token_to_image = Attention(
+ embedding_dim,
+ num_heads,
+ downsample_rate=attention_downsample_rate)
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+ def forward(
+ self,
+ image_embedding: Tensor,
+ image_pe: Tensor,
+ point_embedding: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ image_embedding (torch.Tensor): image to attend to. Should be shape
+ B x embedding_dim x h x w for any h and w.
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
+ have the same shape as image_embedding.
+ point_embedding (torch.Tensor): the embedding to add to the query points.
+ Must have shape B x N_points x embedding_dim for any N_points.
+
+ Returns:
+ torch.Tensor: the processed point_embedding
+ torch.Tensor: the processed image_embedding
+ """ # noqa E501
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+ bs, c, h, w = image_embedding.shape
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+ # Prepare queries
+ queries = point_embedding
+ keys = image_embedding
+
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ queries, keys = layer(
+ queries=queries,
+ keys=keys,
+ query_pe=point_embedding,
+ key_pe=image_pe,
+ )
+
+        # Apply the final attention layer from the points to the image
+ q = queries + point_embedding
+ k = keys + image_pe
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm_final_attn(queries)
+
+ return queries, keys
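The first step of `forward` turns the spatial image embedding into a token sequence via flatten + permute. A tiny sketch of that shape manipulation (dimensions are illustrative):

```python
import torch

# BxCxHxW -> B x HW x C: image embedding as a sequence of HW tokens.
b, c, h, w = 2, 256, 8, 8
emb = torch.rand(b, c, h, w)
tokens = emb.flatten(2).permute(0, 2, 1)
print(tokens.shape)  # torch.Size([2, 64, 256])
```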
+
+
+class TwoWayAttentionBlock(nn.Module):
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int = 2048,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ skip_first_layer_pe: bool = False,
+ ) -> None:
+ """A transformer block with four layers: (1) self-attention of sparse
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+ block on sparse inputs, and (4) cross attention of dense inputs to
+ sparse inputs.
+
+ Arguments:
+ embedding_dim (int): the channel dimension of the embeddings
+ num_heads (int): the number of heads in the attention layers
+ mlp_dim (int): the hidden dimension of the mlp block
+ activation (nn.Module): the activation of the mlp block
+ skip_first_layer_pe (bool): skip the PE on the first layer
+ """
+ super().__init__()
+ self.self_attn = Attention(embedding_dim, num_heads)
+ self.norm1 = nn.LayerNorm(embedding_dim)
+
+ self.cross_attn_token_to_image = Attention(
+ embedding_dim,
+ num_heads,
+ downsample_rate=attention_downsample_rate)
+ self.norm2 = nn.LayerNorm(embedding_dim)
+
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
+ self.norm3 = nn.LayerNorm(embedding_dim)
+
+ self.norm4 = nn.LayerNorm(embedding_dim)
+ self.cross_attn_image_to_token = Attention(
+ embedding_dim,
+ num_heads,
+ downsample_rate=attention_downsample_rate)
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor,
+ key_pe: Tensor) -> Tuple[Tensor, Tensor]:
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries = self.self_attn(q=queries, k=queries, v=queries)
+ else:
+ q = queries + query_pe
+ attn_out = self.self_attn(q=q, k=q, v=queries)
+ queries = queries + attn_out
+ queries = self.norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+ keys = keys + attn_out
+ keys = self.norm4(keys)
+
+ return queries, keys
+
+
+class Attention(nn.Module):
+ """An attention layer that allows for downscaling the size of the embedding
+ after projection to queries, keys, and values."""
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ downsample_rate: int = 1,
+ ) -> None:
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.internal_dim = embedding_dim // downsample_rate
+ self.num_heads = num_heads
+        assert self.internal_dim % num_heads == 0, 'num_heads must divide embedding_dim // downsample_rate.'  # noqa E501
+
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+ b, n, c = x.shape
+ x = x.reshape(b, n, num_heads, c // num_heads)
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
+
+ def _recombine_heads(self, x: Tensor) -> Tensor:
+ b, n_heads, n_tokens, c_per_head = x.shape
+ x = x.transpose(1, 2)
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+ # Input projections
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ # Attention
+ _, _, _, c_per_head = q.shape
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
+ attn = attn / math.sqrt(c_per_head)
+ attn = torch.softmax(attn, dim=-1)
+
+ # Get output
+ out = attn @ v
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+
+ return out
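The `Attention` module above is standard multi-head scaled dot-product attention with an optional channel downsample. The core split-heads / attend / recombine sequence can be sketched standalone (shapes are illustrative; this is not the module itself):

```python
import math
import torch

# Head split, scaled dot-product attention, head recombine.
b, n, num_heads, c = 2, 7, 4, 32
c_per_head = c // num_heads

q, k, v = (torch.rand(b, n, c) for _ in range(3))

def separate_heads(x: torch.Tensor) -> torch.Tensor:
    # B x N x C -> B x N_heads x N x C_per_head
    return x.reshape(b, n, num_heads, c_per_head).transpose(1, 2)

qh, kh, vh = map(separate_heads, (q, k, v))
attn = torch.softmax(
    qh @ kh.transpose(-2, -1) / math.sqrt(c_per_head), dim=-1)
out = (attn @ vh).transpose(1, 2).reshape(b, n, c)  # recombine heads
print(out.shape)  # torch.Size([2, 7, 32])
```

Each row of `attn` is a softmax over the key tokens, so it sums to 1 per query.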
diff --git a/projects/sam_inference_demo/sam/sam_inferencer.py b/projects/sam_inference_demo/sam/sam_inferencer.py
new file mode 100644
index 00000000000..2da2e959c09
--- /dev/null
+++ b/projects/sam_inference_demo/sam/sam_inferencer.py
@@ -0,0 +1,688 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from mmengine.runner.checkpoint import load_checkpoint
+# yapf: disable
+from sam.utils import (MaskData, area_from_rle, batch_iterator,
+ batched_mask_to_box, box_xyxy_to_xywh,
+ build_all_layer_point_grids, calculate_stability_score,
+ coco_encode_rle, generate_crop_boxes,
+ is_box_near_crop_edge, mask_to_rle_pytorch,
+ remove_small_regions, rle_to_mask, uncrop_boxes_xyxy,
+ uncrop_masks, uncrop_points)
+from torchvision.ops.boxes import batched_nms, box_area
+
+from mmseg.registry import MODELS, TRANSFORMS
+
+# yapf: enable
+
+model_zoo = {
+ 'base':
+ 'https://download.openmmlab.com/mmsegmentation/v0.5/sam/sam_vit-base-p16_3rdparty_sa1b-1024x1024_20230413-78a25eed.pth', # noqa
+ 'large':
+ 'https://download.openmmlab.com/mmsegmentation/v0.5/sam/sam_vit-large-p16_3rdparty_sa1b-1024x1024_20230413-940520da.pth', # noqa
+ 'huge':
+ 'https://download.openmmlab.com/mmsegmentation/v0.5/sam/sam_vit-huge-p16_3rdparty_sa1b-1024x1024_20230413-faaf96f6.pth', # noqa
+}
+
+
+class SAMInferencer:
+
+ def __init__(self, arch: str = 'base') -> None:
+ assert arch in ['base', 'large', 'huge']
+ self.model = self.init_model(arch)
+ self.transform = TRANSFORMS.build(
+ dict(
+ type='ResizeLongestSide',
+ target_length=max(self.model.image_encoder.img_size)))
+
+ def set_image(
+ self,
+ image: np.ndarray,
+ image_format: str = 'RGB',
+ ) -> None:
+ """Calculates the image embeddings for the provided image, allowing
+ masks to be predicted with the 'predict' method.
+
+ Arguments:
+ image (np.ndarray): The image for calculating masks. Expects an
+ image in HWC uint8 format, with pixel values in [0, 255].
+ image_format (str): The color format of the image, in ['RGB', 'BGR'].
+ """
+ assert image_format in [
+ 'RGB',
+ 'BGR',
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
+ if image_format != self.model.image_format:
+ image = image[..., ::-1]
+
+ # Transform the image to the form expected by the model
+ input_image = self.transform.apply_image(image)
+ input_image_torch = torch.as_tensor(input_image, device=self.device)
+ input_image_torch = input_image_torch.permute(
+ 2, 0, 1).contiguous()[None, :, :, :]
+
+ self.set_torch_image(input_image_torch, image.shape[:2])
+
+ @torch.no_grad()
+ def set_torch_image(
+ self,
+ transformed_image: torch.Tensor,
+ original_image_size: Tuple[int, ...],
+ ) -> None:
+ """Calculates the image embeddings for the provided image, allowing
+ masks to be predicted with the 'predict' method. Expects the input
+ image to be already transformed to the format expected by the model.
+
+ Arguments:
+ transformed_image (torch.Tensor): The input image, with shape
+ 1x3xHxW, which has been transformed with ResizeLongestSide.
+ original_image_size (tuple(int, int)): The size of the image
+ before transformation, in (H, W) format.
+ """
+        assert (len(transformed_image.shape) == 4
+                and transformed_image.shape[1] == 3
+                and max(*transformed_image.shape[2:]) == max(
+                    self.model.image_encoder.img_size)
+                ), ('set_torch_image input must be BCHW with long side '
+                    f'{max(self.model.image_encoder.img_size)}.')
+ self.reset_image()
+
+ self.original_size = original_image_size
+ self.input_size = tuple(transformed_image.shape[-2:])
+ input_image = self.model.preprocess(transformed_image)
+ self.features = self.model.image_encoder(input_image)[0]
+ self.is_image_set = True
+
+ def predict(
+ self,
+ point_coords: Optional[np.ndarray] = None,
+ point_labels: Optional[np.ndarray] = None,
+ box: Optional[np.ndarray] = None,
+ mask_input: Optional[np.ndarray] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Predict masks for the given input prompts, using the currently set
+ image.
+
+ Arguments:
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (np.ndarray or None): A length N array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+ box (np.ndarray or None): A length 4 array given a box prompt to the
+ model, in XYXY format.
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form 1xHxW, where
+ for SAM, H=W=256.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (np.ndarray): The output masks in CxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (np.ndarray): An array of length C containing the model's
+ predictions for the quality of each mask.
+ (np.ndarray): An array of shape CxHxW, where C is the number
+ of masks and H=W=256. These low resolution logits can be passed to
+ a subsequent iteration as mask input.
+ """ # noqa
+ if not self.is_image_set:
+ raise RuntimeError(
+                'An image must be set with .set_image(...) before mask '
+                'prediction.')
+
+ # Transform input prompts
+ coords_torch = None
+ labels_torch = None
+ box_torch = None
+ mask_input_torch = None
+
+ if point_coords is not None:
+ assert (
+ point_labels is not None
+ ), 'point_labels must be supplied if point_coords is supplied.'
+ point_coords = self.transform.apply_coords(point_coords,
+ self.original_size)
+ coords_torch = torch.as_tensor(
+ point_coords, dtype=torch.float, device=self.device)
+ labels_torch = torch.as_tensor(
+ point_labels, dtype=torch.int, device=self.device)
+ coords_torch, labels_torch = coords_torch[
+ None, :, :], labels_torch[None, :]
+ if box is not None:
+ box = self.transform.apply_boxes(box, self.original_size)
+ box_torch = torch.as_tensor(
+ box, dtype=torch.float, device=self.device)
+ box_torch = box_torch[None, :]
+ if mask_input is not None:
+ mask_input_torch = torch.as_tensor(
+ mask_input, dtype=torch.float, device=self.device)
+ mask_input_torch = mask_input_torch[None, :, :, :]
+
+ masks, iou_predictions, low_res_masks = self.predict_torch(
+ coords_torch,
+ labels_torch,
+ box_torch,
+ mask_input_torch,
+ multimask_output,
+ return_logits=return_logits,
+ )
+
+ masks = masks[0].detach().cpu().numpy()
+ iou_predictions = iou_predictions[0].detach().cpu().numpy()
+ low_res_masks = low_res_masks[0].detach().cpu().numpy()
+ return masks, iou_predictions, low_res_masks
+
+ @torch.no_grad()
+ def predict_torch(
+ self,
+ point_coords: Optional[torch.Tensor],
+ point_labels: Optional[torch.Tensor],
+ boxes: Optional[torch.Tensor] = None,
+ mask_input: Optional[torch.Tensor] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Predict masks for the given input prompts, using the currently set
+ image. Input prompts are batched torch tensors and are expected to
+ already be transformed to the input frame using ResizeLongestSide.
+
+ Arguments:
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (torch.Tensor or None): A BxN array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+ box (np.ndarray or None): A Bx4 array given a box prompt to the
+ model, in XYXY format.
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
+ for SAM, H=W=256. Masks returned by a previous iteration of the
+ predict method do not need further transformation.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (torch.Tensor): An array of shape BxC containing the model's
+ predictions for the quality of each mask.
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
+ of masks and H=W=256. These low res logits can be passed to
+ a subsequent iteration as mask input.
+ """ # noqa
+ if not self.is_image_set:
+ raise RuntimeError(
+ 'An image must be set with .set_image(...) before mask '
+ 'prediction.')
+
+ if point_coords is not None:
+ points = (point_coords, point_labels)
+ else:
+ points = None
+
+ # Embed prompts
+ sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
+ points=points,
+ boxes=boxes,
+ masks=mask_input,
+ )
+
+ # Predict masks
+ low_res_masks, iou_predictions = self.model.mask_decoder(
+ image_embeddings=self.features,
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ )
+
+ # Upscale the masks to the original image resolution
+ masks = self.model.postprocess_masks(low_res_masks, self.input_size,
+ self.original_size)
+
+ if not return_logits:
+ masks = masks > self.model.mask_threshold
+
+ return masks, iou_predictions, low_res_masks
+
+ def get_image_embedding(self) -> torch.Tensor:
+ """Returns the image embeddings for the currently set image, with shape
+ 1xCxHxW, where C is the embedding dimension and (H,W) are the embedding
+ spatial dimension of SAM (typically C=256, H=W=64)."""
+ if not self.is_image_set:
+ raise RuntimeError(
+ 'An image must be set with .set_image(...) to generate an '
+ 'embedding.')
+        assert self.features is not None, (
+            'Features must exist if an image has been set.')
+ return self.features
+
+ @property
+ def device(self) -> torch.device:
+ return self.model.device
+
+ def reset_image(self) -> None:
+ """Resets the currently set image."""
+ self.is_image_set = False
+ self.features = None
+ self.orig_h = None
+ self.orig_w = None
+ self.input_h = None
+ self.input_w = None
+
+ def init_model(self, arch: str):
+ model = MODELS.build(
+ dict(
+ type='SAM',
+ image_encoder_cfg=dict(
+ type='mmpretrain.ViTSAM',
+ arch=arch,
+ img_size=1024,
+ patch_size=16,
+ out_channels=256,
+ use_abs_pos=True,
+ use_rel_pos=True,
+ window_size=14,
+ ),
+ prompt_encoder_cfg=dict(
+ type='PromptEncoder',
+ embed_dim=256,
+ image_embedding_size=(64, 64),
+ input_image_size=(1024, 1024),
+ mask_in_chans=16,
+ ),
+ mask_decoder_cfg=dict(
+ type='MaskDecoder',
+ num_multimask_outputs=3,
+ transformer=dict(
+ type='TwoWayTransformer',
+ depth=2,
+ embedding_dim=256,
+ mlp_dim=2048,
+ num_heads=8,
+ ),
+ transformer_dim=256,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+ )))
+ load_checkpoint(model, model_zoo.get(arch), strict=True)
+ if torch.cuda.is_available():
+ model = model.cuda()
+ return model
+
+
+class SamAutomaticMaskGenerator:
+
+ def __init__(
+ self,
+ arch: str = 'base',
+ points_per_side: Optional[int] = 32,
+ points_per_batch: int = 64,
+ pred_iou_thresh: float = 0.88,
+ stability_score_thresh: float = 0.95,
+ stability_score_offset: float = 1.0,
+ box_nms_thresh: float = 0.7,
+ crop_n_layers: int = 0,
+ crop_nms_thresh: float = 0.7,
+ crop_overlap_ratio: float = 512 / 1500,
+ crop_n_points_downscale_factor: int = 1,
+ point_grids: Optional[List[np.ndarray]] = None,
+ min_mask_region_area: int = 0,
+ output_mode: str = 'binary_mask',
+ ) -> None:
+ """Using a SAM model, generates masks for the entire image. Generates a
+ grid of point prompts over the image, then filters low quality and
+ duplicate masks. The default settings are chosen for SAM with a ViT-H
+ backbone.
+
+ Arguments:
+ arch (str): The SAM model to use for mask prediction.
+ points_per_side (int or None): The number of points to be sampled
+ along one side of the image. The total number of points is
+ points_per_side**2. If None, 'point_grids' must provide explicit
+ point sampling.
+ points_per_batch (int): Sets the number of points run simultaneously
+ by the model. Higher numbers may be faster but use more GPU memory.
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
+ model's predicted mask quality.
+ stability_score_thresh (float): A filtering threshold in [0,1], using
+ the stability of the mask under changes to the cutoff used to binarize
+ the model's mask predictions.
+ stability_score_offset (float): The amount to shift the cutoff when
+            calculating the stability score.
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks.
+          crop_n_layers (int): If >0, mask prediction will be run again on
+ crops of the image. Sets the number of layers to run, where each
+ layer has 2**i_layer number of image crops.
+          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks between different crops.
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
+ In the first crop layer, crops will overlap by this fraction of
+ the image length. Later layers with more crops scale down this overlap.
+ crop_n_points_downscale_factor (int): The number of points-per-side
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+ point_grids (list(np.ndarray) or None): A list over explicit grids
+ of points used for sampling, normalized to [0,1]. The nth grid in the
+ list is used in the nth crop layer. Exclusive with points_per_side.
+ min_mask_region_area (int): If >0, postprocessing will be applied
+ to remove disconnected regions and holes in masks with area smaller
+ than min_mask_region_area. Requires opencv.
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
+ For large resolutions, 'binary_mask' may consume large amounts of
+ memory.
+ """ # noqa
+
+ assert (points_per_side is None) != (
+ point_grids is None
+ ), 'Exactly one of points_per_side or point_grid must be provided.'
+ if points_per_side is not None:
+ self.point_grids = build_all_layer_point_grids(
+ points_per_side,
+ crop_n_layers,
+ crop_n_points_downscale_factor,
+ )
+ elif point_grids is not None:
+ self.point_grids = point_grids
+ else:
+ raise ValueError(
+ "Can't have both points_per_side and point_grid be None.")
+
+ assert output_mode in [
+ 'binary_mask',
+ 'uncompressed_rle',
+ 'coco_rle',
+ ], f'Unknown output_mode {output_mode}.'
+ if output_mode == 'coco_rle':
+ from pycocotools import \
+ mask as mask_utils # type: ignore # noqa: F401
+
+ if min_mask_region_area > 0:
+ import cv2 # type: ignore # noqa: F401
+
+ self.predictor = SAMInferencer(arch)
+ self.points_per_batch = points_per_batch
+ self.pred_iou_thresh = pred_iou_thresh
+ self.stability_score_thresh = stability_score_thresh
+ self.stability_score_offset = stability_score_offset
+ self.box_nms_thresh = box_nms_thresh
+ self.crop_n_layers = crop_n_layers
+ self.crop_nms_thresh = crop_nms_thresh
+ self.crop_overlap_ratio = crop_overlap_ratio
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+ self.min_mask_region_area = min_mask_region_area
+ self.output_mode = output_mode
+
+ @torch.no_grad()
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+ """Generates masks for the given image.
+
+ Arguments:
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+
+ Returns:
+ list(dict(str, any)): A list over records for masks. Each record is
+ a dict containing the following keys:
+ segmentation (dict(str, any) or np.ndarray): The mask. If
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
+ is a dictionary containing the RLE.
+ bbox (list(float)): The box around the mask, in XYWH format.
+ area (int): The area in pixels of the mask.
+ predicted_iou (float): The model's own prediction of the mask's
+ quality. This is filtered by the pred_iou_thresh parameter.
+ point_coords (list(list(float))): The point coordinates input
+ to the model to generate this mask.
+ stability_score (float): A measure of the mask's quality. This
+ is filtered on using the stability_score_thresh parameter.
+ crop_box (list(float)): The crop of the image used to generate
+ the mask, given in XYWH format.
+ """ # noqa
+
+ # Generate masks
+ mask_data = self._generate_masks(image)
+
+ # Filter small disconnected regions and holes in masks
+ if self.min_mask_region_area > 0:
+ mask_data = self.postprocess_small_regions(
+ mask_data,
+ self.min_mask_region_area,
+ max(self.box_nms_thresh, self.crop_nms_thresh),
+ )
+
+ # Encode masks
+ if self.output_mode == 'coco_rle':
+ mask_data['segmentations'] = [
+ coco_encode_rle(rle) for rle in mask_data['rles']
+ ]
+ elif self.output_mode == 'binary_mask':
+ mask_data['segmentations'] = [
+ rle_to_mask(rle) for rle in mask_data['rles']
+ ]
+ else:
+ mask_data['segmentations'] = mask_data['rles']
+
+ # Write mask records
+ curr_anns = []
+ for idx in range(len(mask_data['segmentations'])):
+ ann = {
+ 'segmentation':
+ mask_data['segmentations'][idx],
+ 'area':
+ area_from_rle(mask_data['rles'][idx]),
+ 'bbox':
+ box_xyxy_to_xywh(mask_data['boxes'][idx]).tolist(),
+ 'predicted_iou':
+ mask_data['iou_preds'][idx].item(),
+ 'point_coords': [mask_data['points'][idx].tolist()],
+ 'stability_score':
+ mask_data['stability_score'][idx].item(),
+ 'crop_box':
+ box_xyxy_to_xywh(mask_data['crop_boxes'][idx]).tolist(),
+ }
+ curr_anns.append(ann)
+
+ return curr_anns
+
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
+ orig_size = image.shape[:2]
+ crop_boxes, layer_idxs = generate_crop_boxes(orig_size,
+ self.crop_n_layers,
+ self.crop_overlap_ratio)
+
+ # Iterate over image crops
+ data = MaskData()
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
+ crop_data = self._process_crop(image, crop_box, layer_idx,
+ orig_size)
+ data.cat(crop_data)
+
+ # Remove duplicate masks between crops
+ if len(crop_boxes) > 1:
+ # Prefer masks from smaller crops
+ scores = 1 / box_area(data['crop_boxes'])
+ scores = scores.to(data['boxes'].device)
+ keep_by_nms = batched_nms(
+ data['boxes'].float(),
+ scores,
+ torch.zeros(len(data['boxes'])), # categories
+ iou_threshold=self.crop_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+
+ data.to_numpy()
+ return data
+
+ def _process_crop(
+ self,
+ image: np.ndarray,
+ crop_box: List[int],
+ crop_layer_idx: int,
+ orig_size: Tuple[int, ...],
+ ) -> MaskData:
+ # Crop the image and calculate embeddings
+ x0, y0, x1, y1 = crop_box
+ cropped_im = image[y0:y1, x0:x1, :]
+ cropped_im_size = cropped_im.shape[:2]
+ self.predictor.set_image(cropped_im)
+
+ # Get points for this crop
+ points_scale = np.array(cropped_im_size)[None, ::-1]
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
+
+ # Generate masks for this crop in batches
+ data = MaskData()
+ for (points, ) in batch_iterator(self.points_per_batch,
+ points_for_image):
+ batch_data = self._process_batch(points, cropped_im_size, crop_box,
+ orig_size)
+ data.cat(batch_data)
+ del batch_data
+ self.predictor.reset_image()
+
+ # Remove duplicates within this crop.
+ keep_by_nms = batched_nms(
+ data['boxes'].float(),
+ data['iou_preds'],
+ torch.zeros(len(data['boxes'])), # categories
+ iou_threshold=self.box_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+
+ # Return to the original image frame
+ data['boxes'] = uncrop_boxes_xyxy(data['boxes'], crop_box)
+ data['points'] = uncrop_points(data['points'], crop_box)
+ data['crop_boxes'] = torch.tensor(
+ [crop_box for _ in range(len(data['rles']))])
+
+ return data
+
+ def _process_batch(
+ self,
+ points: np.ndarray,
+ im_size: Tuple[int, ...],
+ crop_box: List[int],
+ orig_size: Tuple[int, ...],
+ ) -> MaskData:
+ orig_h, orig_w = orig_size
+
+ # Run model on this batch
+ transformed_points = self.predictor.transform.apply_coords(
+ points, im_size)
+ in_points = torch.as_tensor(
+ transformed_points, device=self.predictor.device)
+ in_labels = torch.ones(
+ in_points.shape[0], dtype=torch.int, device=in_points.device)
+ masks, iou_preds, _ = self.predictor.predict_torch(
+ in_points[:, None, :],
+ in_labels[:, None],
+ multimask_output=True,
+ return_logits=True,
+ )
+
+ # Serialize predictions and store in MaskData
+ data = MaskData(
+ masks=masks.flatten(0, 1),
+ iou_preds=iou_preds.flatten(0, 1),
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
+ )
+ del masks
+
+ # Filter by predicted IoU
+ if self.pred_iou_thresh > 0.0:
+ keep_mask = data['iou_preds'] > self.pred_iou_thresh
+ data.filter(keep_mask)
+
+ # Calculate stability score
+ data['stability_score'] = calculate_stability_score(
+ data['masks'], self.predictor.model.mask_threshold,
+ self.stability_score_offset)
+ if self.stability_score_thresh > 0.0:
+ keep_mask = data['stability_score'] >= self.stability_score_thresh
+ data.filter(keep_mask)
+
+ # Threshold masks and calculate boxes
+ data['masks'] = data['masks'] > self.predictor.model.mask_threshold
+ data['boxes'] = batched_mask_to_box(data['masks'])
+
+ # Filter boxes that touch crop boundaries
+ keep_mask = ~is_box_near_crop_edge(data['boxes'], crop_box,
+ [0, 0, orig_w, orig_h])
+ if not torch.all(keep_mask):
+ data.filter(keep_mask)
+
+ # Compress to RLE
+ data['masks'] = uncrop_masks(data['masks'], crop_box, orig_h, orig_w)
+ data['rles'] = mask_to_rle_pytorch(data['masks'])
+ del data['masks']
+
+ return data
+
+ @staticmethod
+ def postprocess_small_regions(mask_data: MaskData, min_area: int,
+ nms_thresh: float) -> MaskData:
+ """Removes small disconnected regions and holes in masks, then reruns
+ box NMS to remove any new duplicates.
+
+ Edits mask_data in place.
+
+ Requires open-cv as a dependency.
+ """
+ if len(mask_data['rles']) == 0:
+ return mask_data
+
+ # Filter small disconnected regions and holes
+ new_masks = []
+ scores = []
+ for rle in mask_data['rles']:
+ mask = rle_to_mask(rle)
+
+ mask, changed = remove_small_regions(mask, min_area, mode='holes')
+ unchanged = not changed
+ mask, changed = remove_small_regions(
+ mask, min_area, mode='islands')
+ unchanged = unchanged and not changed
+
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+ # Give score=0 to changed masks and score=1 to unchanged masks
+ # so NMS will prefer ones that didn't need postprocessing
+ scores.append(float(unchanged))
+
+ # Recalculate boxes and remove any new duplicates
+ masks = torch.cat(new_masks, dim=0)
+ boxes = batched_mask_to_box(masks)
+ keep_by_nms = batched_nms(
+ boxes.float(),
+ torch.as_tensor(scores),
+ torch.zeros(len(boxes)), # categories
+ iou_threshold=nms_thresh,
+ )
+
+ # Only recalculate RLEs for masks that have changed
+ for i_mask in keep_by_nms:
+ if scores[i_mask] == 0.0:
+ mask_torch = masks[i_mask].unsqueeze(0)
+ mask_data['rles'][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
+ mask_data['boxes'][i_mask] = boxes[
+ i_mask] # update res directly
+ mask_data.filter(keep_by_nms)
+
+ return mask_data
diff --git a/projects/sam_inference_demo/sam/utils/__init__.py b/projects/sam_inference_demo/sam/utils/__init__.py
new file mode 100644
index 00000000000..5d33e33aeef
--- /dev/null
+++ b/projects/sam_inference_demo/sam/utils/__init__.py
@@ -0,0 +1,2 @@
+from .amg import * # noqa: F403 F401
+from .transforms import ResizeLongestSide # noqa: F403 F401
diff --git a/projects/sam_inference_demo/sam/utils/amg.py b/projects/sam_inference_demo/sam/utils/amg.py
new file mode 100644
index 00000000000..3ba359901f7
--- /dev/null
+++ b/projects/sam_inference_demo/sam/utils/amg.py
@@ -0,0 +1,355 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# https://github.com/facebookresearch/segment-anything
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, Generator, ItemsView, List, Tuple
+
+import numpy as np
+import torch
+
+
+class MaskData:
+ """A structure for storing masks and their related data in batched format.
+
+ Implements basic filtering and concatenation.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ for v in kwargs.values():
+ assert isinstance(
+ v, (list, np.ndarray, torch.Tensor)
+ ), 'MaskData only supports list, numpy arrays, and torch tensors.'
+ self._stats = dict(**kwargs)
+
+ def __setitem__(self, key: str, item: Any) -> None:
+ assert isinstance(
+ item, (list, np.ndarray, torch.Tensor)
+ ), 'MaskData only supports list, numpy arrays, and torch tensors.'
+ self._stats[key] = item
+
+ def __delitem__(self, key: str) -> None:
+ del self._stats[key]
+
+ def __getitem__(self, key: str) -> Any:
+ return self._stats[key]
+
+ def items(self) -> ItemsView[str, Any]:
+ return self._stats.items()
+
+ def filter(self, keep: torch.Tensor) -> None:
+ for k, v in self._stats.items():
+ if v is None:
+ self._stats[k] = None
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = v[keep.detach().cpu().numpy()]
+ elif isinstance(v, list) and keep.dtype == torch.bool:
+ self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+ elif isinstance(v, list):
+ self._stats[k] = [v[i] for i in keep]
+ else:
+ raise TypeError(
+ f'MaskData key {k} has an unsupported type {type(v)}.')
+
+ def cat(self, new_stats: 'MaskData') -> None:
+ for k, v in new_stats.items():
+ if k not in self._stats or self._stats[k] is None:
+ self._stats[k] = deepcopy(v)
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+ elif isinstance(v, list):
+ self._stats[k] = self._stats[k] + deepcopy(v)
+ else:
+ raise TypeError(
+ f'MaskData key {k} has an unsupported type {type(v)}.')
+
+ def to_numpy(self) -> None:
+ for k, v in self._stats.items():
+ if isinstance(v, torch.Tensor):
+ self._stats[k] = v.detach().cpu().numpy()
+
+
+def is_box_near_crop_edge(boxes: torch.Tensor,
+ crop_box: List[int],
+ orig_box: List[int],
+ atol: float = 20.0) -> torch.Tensor:
+ """Filter masks at the edge of a crop, but not at the edge of the original
+ image."""
+ crop_box_torch = torch.as_tensor(
+ crop_box, dtype=torch.float, device=boxes.device)
+ orig_box_torch = torch.as_tensor(
+ orig_box, dtype=torch.float, device=boxes.device)
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+ near_crop_edge = torch.isclose(
+ boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+ near_image_edge = torch.isclose(
+ boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+ return torch.any(near_crop_edge, dim=1)
+
+
+def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
+ box_xywh = deepcopy(box_xyxy)
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
+ return box_xywh
+
+
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+ assert len(args) > 0 and all(
+ len(a) == len(args[0]) for a in
+ args), 'Batched iteration must have inputs of all the same size.'
+ n_batches = len(args[0]) // batch_size + int(
+ len(args[0]) % batch_size != 0)
+ for b in range(n_batches):
+ yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args]
+
+
+def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+ """Encodes masks to an uncompressed RLE, in the format expected by pycoco
+ tools."""
+ # Put in fortran order and flatten h,w
+ b, h, w = tensor.shape
+ tensor = tensor.permute(0, 2, 1).flatten(1)
+
+ # Compute change indices
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
+ change_indices = diff.nonzero()
+
+ # Encode run length
+ out = []
+ for i in range(b):
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+ cur_idxs = torch.cat([
+ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ cur_idxs + 1,
+ torch.tensor([h * w], dtype=cur_idxs.dtype,
+ device=cur_idxs.device),
+ ])
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+ counts = [] if tensor[i, 0] == 0 else [0]
+ counts.extend(btw_idxs.detach().cpu().tolist())
+ out.append({'size': [h, w], 'counts': counts})
+ return out
+
+
+def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+ """Compute a binary mask from an uncompressed RLE."""
+ h, w = rle['size']
+ mask = np.empty(h * w, dtype=bool)
+ idx = 0
+ parity = False
+ for count in rle['counts']:
+ mask[idx:idx + count] = parity
+ idx += count
+ parity ^= True
+ mask = mask.reshape(w, h)
+ return mask.transpose() # Put in C order
+
+
+def area_from_rle(rle: Dict[str, Any]) -> int:
+ return sum(rle['counts'][1::2])
+
+
+def calculate_stability_score(masks: torch.Tensor, mask_threshold: float,
+ threshold_offset: float) -> torch.Tensor:
+ """Computes the stability score for a batch of masks.
+
+ The stability score is the IoU between the binary masks obtained by
+ thresholding the predicted mask logits at high and low values.
+ """
+ # One mask is always contained inside the other.
+ # Save memory by preventing unnecessary cast to torch.int64
+ intersections = ((masks > (mask_threshold + threshold_offset)).sum(
+ -1, dtype=torch.int16).sum(-1, dtype=torch.int32))
+ unions = ((masks > (mask_threshold - threshold_offset)).sum(
+ -1, dtype=torch.int16).sum(-1, dtype=torch.int32))
+ return intersections / unions
+
+
+def build_point_grid(n_per_side: int) -> np.ndarray:
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+ offset = 1 / (2 * n_per_side)
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+ return points
+
+
+def build_all_layer_point_grids(n_per_side: int, n_layers: int,
+ scale_per_layer: int) -> List[np.ndarray]:
+ """Generates point grids for all crop layers."""
+ points_by_layer = []
+ for i in range(n_layers + 1):
+ n_points = int(n_per_side / (scale_per_layer**i))
+ points_by_layer.append(build_point_grid(n_points))
+ return points_by_layer
+
+
+def generate_crop_boxes(
+ im_size: Tuple[int, ...], n_layers: int,
+ overlap_ratio: float) -> Tuple[List[List[int]], List[int]]:
+ """Generates a list of crop boxes of different sizes.
+
+    The ith layer has (2**i)**2 boxes.
+ """
+ crop_boxes, layer_idxs = [], []
+ im_h, im_w = im_size
+ short_side = min(im_h, im_w)
+
+ # Original image
+ crop_boxes.append([0, 0, im_w, im_h])
+ layer_idxs.append(0)
+
+ def crop_len(orig_len, n_crops, overlap):
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+ for i_layer in range(n_layers):
+ n_crops_per_side = 2**(i_layer + 1)
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+ crop_box_x0 = [
+ int((crop_w - overlap) * i) for i in range(n_crops_per_side)
+ ]
+ crop_box_y0 = [
+ int((crop_h - overlap) * i) for i in range(n_crops_per_side)
+ ]
+
+        # Crops in XYXY format
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+ crop_boxes.append(box)
+ layer_idxs.append(i_layer + 1)
+
+ return crop_boxes, layer_idxs
+
+
+def uncrop_boxes_xyxy(boxes: torch.Tensor,
+ crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+ # Check if boxes has a channel dimension
+ if len(boxes.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return boxes + offset
+
+
+def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0]], device=points.device)
+ # Check if points has a channel dimension
+ if len(points.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return points + offset
+
+
+def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int,
+ orig_w: int) -> torch.Tensor:
+ x0, y0, x1, y1 = crop_box
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+ return masks
+ # Coordinate transform masks
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
+ return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def remove_small_regions(mask: np.ndarray, area_thresh: float,
+ mode: str) -> Tuple[np.ndarray, bool]:
+ """Removes small disconnected regions and holes in a mask.
+
+    Returns the mask and an indicator of whether the mask has been modified.
+ """
+ import cv2 # type: ignore
+
+ assert mode in ['holes', 'islands']
+ correct_holes = mode == 'holes'
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(
+ working_mask, 8)
+ sizes = stats[:, -1][1:] # Row 0 is background label
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+ if len(small_regions) == 0:
+ return mask, False
+ fill_labels = [0] + small_regions
+ if not correct_holes:
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+ # If every region is below threshold, keep largest
+ if len(fill_labels) == 0:
+ fill_labels = [int(np.argmax(sizes)) + 1]
+ mask = np.isin(regions, fill_labels)
+ return mask, True
+
+
+def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
+ from pycocotools import mask as mask_utils # type: ignore
+
+ h, w = uncompressed_rle['size']
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+ rle['counts'] = rle['counts'].decode(
+ 'utf-8') # Necessary to serialize with json
+ return rle
+
+
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+ """Calculates boxes in XYXY format around masks.
+
+ Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the
+ output shape is C1xC2x...x4.
+ """
+ # torch.max below raises an error on empty inputs, just skip in this case
+ if torch.numel(masks) == 0:
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+ # Normalize shape to CxHxW
+ shape = masks.shape
+ h, w = shape[-2:]
+ if len(shape) > 2:
+ masks = masks.flatten(0, -3)
+ else:
+ masks = masks.unsqueeze(0)
+
+ # Get top and bottom edges
+ in_height, _ = torch.max(masks, dim=-1)
+ in_height_coords = in_height * torch.arange(
+ h, device=in_height.device)[None, :]
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+ in_height_coords = in_height_coords + h * (~in_height)
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+ # Get left and right edges
+ in_width, _ = torch.max(masks, dim=-2)
+ in_width_coords = in_width * torch.arange(
+ w, device=in_width.device)[None, :]
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
+ in_width_coords = in_width_coords + w * (~in_width)
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+ # If the mask is empty the right edge will be to the left of the left edge.
+ # Replace these boxes with [0, 0, 0, 0]
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges],
+ dim=-1)
+ out = out * (~empty_filter).unsqueeze(-1)
+
+ # Return to original shape
+ if len(shape) > 2:
+ out = out.reshape(*shape[:-2], 4)
+ else:
+ out = out[0]
+
+ return out
diff --git a/projects/sam_inference_demo/sam/utils/transforms.py b/projects/sam_inference_demo/sam/utils/transforms.py
new file mode 100644
index 00000000000..484fd6691cf
--- /dev/null
+++ b/projects/sam_inference_demo/sam/utils/transforms.py
@@ -0,0 +1,110 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from copy import deepcopy
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torchvision.transforms.functional import resize # type: ignore
+from torchvision.transforms.functional import to_pil_image
+
+from mmseg.registry import TRANSFORMS
+
+
+@TRANSFORMS.register_module()
+class ResizeLongestSide:
+    """Resizes images so that the longest side equals 'target_length', and
+    provides methods for resizing coordinates and boxes accordingly.
+
+ Provides methods for transforming both numpy array and batched torch
+ tensors.
+ """
+
+ def __init__(self, target_length: int) -> None:
+ self.target_length = target_length
+
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
+ """Expects a numpy array with shape HxWxC in uint8 format."""
+ target_size = self.get_preprocess_shape(image.shape[0], image.shape[1],
+ self.target_length)
+ return np.array(resize(to_pil_image(image), target_size))
+
+ def apply_coords(self, coords: np.ndarray,
+ original_size: Tuple[int, ...]) -> np.ndarray:
+ """Expects a numpy array of length 2 in the final dimension.
+
+ Requires the original image size in (H, W) format.
+ """
+ old_h, old_w = original_size
+ new_h, new_w = self.get_preprocess_shape(original_size[0],
+ original_size[1],
+ self.target_length)
+ coords = deepcopy(coords).astype(float)
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+ return coords
+
+ def apply_boxes(self, boxes: np.ndarray,
+ original_size: Tuple[int, ...]) -> np.ndarray:
+ """Expects a numpy array shape Bx4.
+
+ Requires the original image size in (H, W) format.
+ """
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
+ return boxes.reshape(-1, 4)
+
+ def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
+ """Expects batched images with shape BxCxHxW and float format.
+
+ This transformation may not exactly match apply_image. apply_image is
+ the transformation expected by the model.
+ """
+        # For BCHW input, height and width are the last two dimensions
+        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3],
+                                                self.target_length)
+ return F.interpolate(
+ image,
+ target_size,
+ mode='bilinear',
+ align_corners=False,
+ antialias=True)
+
+ def apply_coords_torch(self, coords: torch.Tensor,
+ original_size: Tuple[int, ...]) -> torch.Tensor:
+ """Expects a torch tensor with length 2 in the last dimension.
+
+ Requires the original image size in (H, W) format.
+ """
+ old_h, old_w = original_size
+ new_h, new_w = self.get_preprocess_shape(original_size[0],
+ original_size[1],
+ self.target_length)
+ coords = deepcopy(coords).to(torch.float)
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+ return coords
+
+ def apply_boxes_torch(self, boxes: torch.Tensor,
+ original_size: Tuple[int, ...]) -> torch.Tensor:
+ """Expects a torch tensor with shape Bx4.
+
+ Requires the original image size in (H, W) format.
+ """
+ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
+ return boxes.reshape(-1, 4)
+
+ @staticmethod
+ def get_preprocess_shape(oldh: int, oldw: int,
+ long_side_length: int) -> Tuple[int, int]:
+ """Compute the output size given input size and target long side
+ length."""
+ scale = long_side_length * 1.0 / max(oldh, oldw)
+ newh, neww = oldh * scale, oldw * scale
+ neww = int(neww + 0.5)
+ newh = int(newh + 0.5)
+ return (newh, neww)
diff --git a/projects/sam_inference_demo/sam_image_demo.ipynb b/projects/sam_inference_demo/sam_image_demo.ipynb
new file mode 100644
index 00000000000..1cb433fae9b
--- /dev/null
+++ b/projects/sam_inference_demo/sam_image_demo.ipynb
@@ -0,0 +1,122 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "import cv2\n",
+ "\n",
+ "import sam # noqa: F401\n",
+ "from sam.sam_inferencer import SAMInferencer\n",
+ "\n",
+ "\n",
+ "def show_mask(mask, ax, random_color=False):\n",
+ " if random_color:\n",
+ " color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
+ " else:\n",
+ " color = np.array([30/255, 144/255, 255/255, 0.6])\n",
+ " h, w = mask.shape[-2:]\n",
+ " mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n",
+ " ax.imshow(mask_image)\n",
+ " \n",
+ "def show_points(coords, labels, ax, marker_size=375):\n",
+ " pos_points = coords[labels==1]\n",
+ " neg_points = coords[labels==0]\n",
+ " ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
+ " ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) \n",
+ " \n",
+ "def show_box(box, ax):\n",
+ " x0, y0 = box[0], box[1]\n",
+ " w, h = box[2] - box[0], box[3] - box[1]\n",
+ " ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))\n",
+ "\n",
+ "image = cv2.imread('../../demo/demo.png')\n",
+ "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
+ "plt.figure(figsize=(10,10))\n",
+ "plt.imshow(image)\n",
+ "plt.axis('on')\n",
+ "plt.show()\n",
+ "print(image.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "inferencer = SAMInferencer(arch='huge')\n",
+ "inferencer.set_image(image)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "input_point = np.array([[280, 230], [500, 300]])\n",
+ "input_label = np.array([1, 1])\n",
+ "plt.figure(figsize=(10,10))\n",
+ "plt.imshow(image)\n",
+ "show_points(input_point, input_label, plt.gca())\n",
+ "plt.axis('on')\n",
+ "plt.show() "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "masks, scores, logits = inferencer.predict(\n",
+ " point_coords=input_point,\n",
+ " point_labels=input_label,\n",
+ " multimask_output=True,\n",
+ ")\n",
+ "for i, (mask, score) in enumerate(zip(masks, scores)):\n",
+ " plt.figure(figsize=(10,10))\n",
+ " plt.imshow(image)\n",
+ " show_mask(mask, plt.gca(), random_color=True)\n",
+ " show_points(input_point, input_label, plt.gca())\n",
+ " plt.title(f\"Mask {i+1}, Score: {score:.3f}\", fontsize=18)\n",
+ " plt.axis('off')\n",
+ " plt.show()"
+ ]
+  }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "pt1.13",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/projects/van/README.md b/projects/van/README.md
new file mode 100644
index 00000000000..be0ba362faa
--- /dev/null
+++ b/projects/van/README.md
@@ -0,0 +1,101 @@
+# Visual Attention Network (VAN) for Segmentation
+
+This repo is a PyTorch implementation of applying **VAN** (**Visual Attention Network**) to semantic segmentation.
+
+The code is adapted from [VAN-Segmentation](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/README.md?plain=1).
+
+More details can be found in [**Visual Attention Network**](https://arxiv.org/abs/2202.09741).
+
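The core operator of VAN is Large Kernel Attention (LKA), which the `VANAttentionModule` in this project implements as a 5x5 depthwise convolution, a 7x7 depthwise convolution with dilation 3, and a 1x1 pointwise convolution, with the result gating the input elementwise. As a quick sanity check (a sketch for illustration, not part of the model code), the neighborhood covered by the two stacked depthwise convolutions can be computed with the standard stride-1 receptive-field formula:

```python
def stacked_receptive_field(kernels_and_dilations):
    """Receptive field of stride-1 convolutions applied in sequence.

    Each layer grows the field by (kernel_size - 1) * dilation.
    """
    rf = 1
    for kernel_size, dilation in kernels_and_dilations:
        rf += (kernel_size - 1) * dilation
    return rf


# conv0: 5x5 depthwise; conv_spatial: 7x7 depthwise with dilation 3
print(stacked_receptive_field([(5, 1), (7, 3)]))  # 23
```

So the two small depthwise convolutions together cover a 23x23 neighborhood at a fraction of the parameter cost of a single large dense kernel, which is the point of the decomposition.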
+## Citation
+
+```bib
+@article{guo2022visual,
+ title={Visual Attention Network},
+ author={Guo, Meng-Hao and Lu, Cheng-Ze and Liu, Zheng-Ning and Cheng, Ming-Ming and Hu, Shi-Min},
+ journal={arXiv preprint arXiv:2202.09741},
+ year={2022}
+}
+```
+
+## Results
+
+**Note:** Pre-trained models can be found at [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/0100f0cea37d41ba8d08/).
+
+Results can be found in [VAN-Segmentation](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/README.md?plain=1)
+
+We provide evaluation results of the converted weights.
+
+| Method | Backbone | mIoU | Download |
+| :-----: | :----------: | :---: | :--------------------------------------------------------------------------------------------------------------------------------------------: |
+| UPerNet | VAN-B2 | 49.35 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2-in1kpre_upernet_3rdparty_512x512-ade20k_20230522-19c58aee.pth) |
+| UPerNet | VAN-B3 | 49.71 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3-in1kpre_upernet_3rdparty_512x512-ade20k_20230522-653bd6b7.pth) |
+| UPerNet | VAN-B4 | 51.56 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4-in1kpre_upernet_3rdparty_512x512-ade20k_20230522-653bd6b7.pth) |
+| UPerNet | VAN-B4-in22k | 52.61 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4-in22kpre_upernet_3rdparty_512x512-ade20k_20230522-4a4d744a.pth) |
+| UPerNet | VAN-B5-in22k | 53.11 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b5-in22kpre_upernet_3rdparty_512x512-ade20k_20230522-5bb6f2b4.pth) |
+| UPerNet | VAN-B6-in22k | 54.25 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b6-in22kpre_upernet_3rdparty_512x512-ade20k_20230522-e226b363.pth) |
+| FPN | VAN-B0 | 38.65 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b0-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-75a76298.pth) |
+| FPN | VAN-B1 | 43.22 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b1-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-104499ff.pth) |
+| FPN | VAN-B2 | 46.84 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-7074e6f8.pth) |
+| FPN | VAN-B3 | 48.32 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3-in1kpre_fpn_3rdparty_512x512-ade20k_20230522-2c3b7f5e.pth) |
+
+## Preparation
+
+Install MMSegmentation and download the ADE20K dataset following the MMSegmentation guidelines.
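
The base config in this project inherits MMSegmentation's ADE20K dataset config, which reads data from `data/ade/`. The expected layout is roughly the following (verify against the dataset preparation guide for the version you install):

```none
mmsegmentation
├── data
│   └── ade
│       └── ADEChallengeData2016
│           ├── annotations
│           │   ├── training
│           │   └── validation
│           └── images
│               ├── training
│               └── validation
```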
+
+## Requirements
+
+**Step 0.** Install [MMCV](https://github.com/open-mmlab/mmcv) using [MIM](https://github.com/open-mmlab/mim).
+
+```shell
+pip install -U openmim
+mim install mmengine
+mim install "mmcv>=2.0.0"
+```
+
+**Step 1.** Install MMSegmentation.
+
+Case a: If you develop and run mmseg directly, install it from source:
+
+```shell
+git clone -b main https://github.com/open-mmlab/mmsegmentation.git
+cd mmsegmentation
+pip install -v -e .
+```
+
+Case b: If you use mmsegmentation as a dependency or third-party package, install it with pip:
+
+```shell
+pip install "mmsegmentation>=1.0.0"
+```
+
+## Training
+
+We use 4 GPUs for training by default. Run:
+
+```bash
+bash tools/dist_train.sh projects/van/configs/van/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py 4
+```
+
+## Evaluation
+
+To evaluate a trained model, run for example:
+
+```bash
+bash tools/dist_test.sh projects/van/configs/van/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py work_dirs/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512/iter_160000.pth 4
+```
+
+## FLOPs
+
+To calculate FLOPs for a model, run:
+
+```bash
+python tools/analysis_tools/get_flops.py projects/van/configs/van/van-b2_pre1k_upernet_4xb2-160k_ade20k-512x512.py --shape 512 512
+```
+
+## Acknowledgment
+
+Our implementation is mainly based on [mmsegmentation](https://github.com/open-mmlab/mmsegmentation/tree/v0.12.0), [Swin-Transformer](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation), [PoolFormer](https://github.com/sail-sg/poolformer), [Enjoy-Hamburger](https://github.com/Gsunshine/Enjoy-Hamburger) and [VAN-Segmentation](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/README.md?plain=1). Thanks to their authors.
+
+## LICENSE
+
+This repo is under the Apache-2.0 license. For commercial use, please contact the authors.
diff --git a/projects/van/backbones/__init__.py b/projects/van/backbones/__init__.py
new file mode 100644
index 00000000000..071995de294
--- /dev/null
+++ b/projects/van/backbones/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .van import VAN
+
+__all__ = ['VAN']
diff --git a/projects/van/backbones/van.py b/projects/van/backbones/van.py
new file mode 100644
index 00000000000..301834a7587
--- /dev/null
+++ b/projects/van/backbones/van.py
@@ -0,0 +1,124 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+from mmengine.model import BaseModule
+
+from mmseg.models.backbones.mscan import (MSCAN, MSCABlock,
+ MSCASpatialAttention,
+ OverlapPatchEmbed)
+from mmseg.registry import MODELS
+
+
+class VANAttentionModule(BaseModule):
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv0 = nn.Conv2d(
+ in_channels, in_channels, 5, padding=2, groups=in_channels)
+ self.conv_spatial = nn.Conv2d(
+ in_channels,
+ in_channels,
+ 7,
+ stride=1,
+ padding=9,
+ groups=in_channels,
+ dilation=3)
+ self.conv1 = nn.Conv2d(in_channels, in_channels, 1)
+
+ def forward(self, x):
+ u = x.clone()
+ attn = self.conv0(x)
+ attn = self.conv_spatial(attn)
+ attn = self.conv1(attn)
+ return u * attn
+
+
+class VANSpatialAttention(MSCASpatialAttention):
+
+ def __init__(self, in_channels, act_cfg=dict(type='GELU')):
+ super().__init__(in_channels, act_cfg=act_cfg)
+ self.spatial_gating_unit = VANAttentionModule(in_channels)
+
+
+class VANBlock(MSCABlock):
+
+ def __init__(self,
+ channels,
+ mlp_ratio=4.,
+ drop=0.,
+ drop_path=0.,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='SyncBN', requires_grad=True)):
+ super().__init__(
+ channels,
+ mlp_ratio=mlp_ratio,
+ drop=drop,
+ drop_path=drop_path,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg)
+ self.attn = VANSpatialAttention(channels)
+
+
+@MODELS.register_module()
+class VAN(MSCAN):
+
+ def __init__(self,
+ in_channels=3,
+ embed_dims=[64, 128, 256, 512],
+ mlp_ratios=[8, 8, 4, 4],
+ drop_rate=0.,
+ drop_path_rate=0.,
+ depths=[3, 4, 6, 3],
+ num_stages=4,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
+ pretrained=None,
+ init_cfg=None):
+ super(MSCAN, self).__init__(init_cfg=init_cfg)
+
+ assert not (init_cfg and pretrained), \
+ 'init_cfg and pretrained cannot be set at the same time'
+ if isinstance(pretrained, str):
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ elif pretrained is not None:
+ raise TypeError('pretrained must be a str or None')
+
+ self.depths = depths
+ self.num_stages = num_stages
+
+ # stochastic depth decay rule
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+ ]
+ cur = 0
+
+ for i in range(num_stages):
+ patch_embed = OverlapPatchEmbed(
+ patch_size=7 if i == 0 else 3,
+ stride=4 if i == 0 else 2,
+ in_channels=in_channels if i == 0 else embed_dims[i - 1],
+ embed_dim=embed_dims[i],
+ norm_cfg=norm_cfg)
+
+ block = nn.ModuleList([
+ VANBlock(
+ channels=embed_dims[i],
+ mlp_ratio=mlp_ratios[i],
+ drop=drop_rate,
+ drop_path=dpr[cur + j],
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg) for j in range(depths[i])
+ ])
+ norm = nn.LayerNorm(embed_dims[i])
+ cur += depths[i]
+
+ setattr(self, f'patch_embed{i + 1}', patch_embed)
+ setattr(self, f'block{i + 1}', block)
+ setattr(self, f'norm{i + 1}', norm)
+
+ def init_weights(self):
+ return super().init_weights()
diff --git a/projects/van/configs/_base_/datasets/ade20k.py b/projects/van/configs/_base_/datasets/ade20k.py
new file mode 100644
index 00000000000..69b3c2a73b8
--- /dev/null
+++ b/projects/van/configs/_base_/datasets/ade20k.py
@@ -0,0 +1,14 @@
+# dataset settings
+_base_ = '../../../../../configs/_base_/datasets/ade20k.py'
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=(2048, 512), keep_ratio=True),
+ dict(type='ResizeToMultiple', size_divisor=32),
+    # load annotations after ``Resize`` because the ground truth
+    # does not need the resize data transform
+ dict(type='LoadAnnotations', reduce_zero_label=True),
+ dict(type='PackSegInputs')
+]
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = val_dataloader
diff --git a/projects/van/configs/_base_/models/van_fpn.py b/projects/van/configs/_base_/models/van_fpn.py
new file mode 100644
index 00000000000..c7fd7391f77
--- /dev/null
+++ b/projects/van/configs/_base_/models/van_fpn.py
@@ -0,0 +1,43 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255,
+ size=(512, 512))
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ type='VAN',
+ embed_dims=[32, 64, 160, 256],
+ drop_rate=0.0,
+ drop_path_rate=0.1,
+ depths=[3, 3, 5, 2],
+ act_cfg=dict(type='GELU'),
+ norm_cfg=norm_cfg,
+ init_cfg=dict()),
+ neck=dict(
+ type='FPN',
+ in_channels=[32, 64, 160, 256],
+ out_channels=256,
+ num_outs=4),
+ decode_head=dict(
+ type='FPNHead',
+ in_channels=[256, 256, 256, 256],
+ in_index=[0, 1, 2, 3],
+ feature_strides=[4, 8, 16, 32],
+ channels=128,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/projects/van/configs/_base_/models/van_upernet.py b/projects/van/configs/_base_/models/van_upernet.py
new file mode 100644
index 00000000000..8f94c0d9d84
--- /dev/null
+++ b/projects/van/configs/_base_/models/van_upernet.py
@@ -0,0 +1,51 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255,
+ size=(512, 512))
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ type='VAN',
+ embed_dims=[32, 64, 160, 256],
+ drop_rate=0.0,
+ drop_path_rate=0.1,
+ depths=[3, 3, 5, 2],
+ act_cfg=dict(type='GELU'),
+ norm_cfg=norm_cfg,
+ init_cfg=dict()),
+ decode_head=dict(
+ type='UPerHead',
+ in_channels=[32, 64, 160, 256],
+ in_index=[0, 1, 2, 3],
+ pool_scales=(1, 2, 3, 6),
+ channels=512,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=160,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/projects/van/configs/van/van-b0_fpn_8xb4-40k_ade20k-512x512.py b/projects/van/configs/van/van-b0_fpn_8xb4-40k_ade20k-512x512.py
new file mode 100644
index 00000000000..2faf3788a71
--- /dev/null
+++ b/projects/van/configs/van/van-b0_fpn_8xb4-40k_ade20k-512x512.py
@@ -0,0 +1,8 @@
+_base_ = './van-b2_fpn_8xb4-40k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b0_3rdparty_20230522-956f5e0d.pth' # noqa
+model = dict(
+ backbone=dict(
+ embed_dims=[32, 64, 160, 256],
+ depths=[3, 3, 5, 2],
+ init_cfg=dict(type='Pretrained', checkpoint=ckpt_path)),
+ neck=dict(in_channels=[32, 64, 160, 256]))
diff --git a/projects/van/configs/van/van-b1_fpn_8xb4-40k_ade20k-512x512.py b/projects/van/configs/van/van-b1_fpn_8xb4-40k_ade20k-512x512.py
new file mode 100644
index 00000000000..cf64a7138b2
--- /dev/null
+++ b/projects/van/configs/van/van-b1_fpn_8xb4-40k_ade20k-512x512.py
@@ -0,0 +1,6 @@
+_base_ = './van-b2_fpn_8xb4-40k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b1_3rdparty_20230522-3adb117f.pth' # noqa
+model = dict(
+ backbone=dict(
+ depths=[2, 2, 4, 2],
+ init_cfg=dict(type='Pretrained', checkpoint=ckpt_path)))
diff --git a/projects/van/configs/van/van-b2_fpn_8xb4-40k_ade20k-512x512.py b/projects/van/configs/van/van-b2_fpn_8xb4-40k_ade20k-512x512.py
new file mode 100644
index 00000000000..965fa1cd363
--- /dev/null
+++ b/projects/van/configs/van/van-b2_fpn_8xb4-40k_ade20k-512x512.py
@@ -0,0 +1,53 @@
+_base_ = [
+ '../_base_/models/van_fpn.py',
+ '../_base_/datasets/ade20k.py',
+ '../../../../configs/_base_/default_runtime.py',
+]
+custom_imports = dict(imports=['projects.van.backbones'])
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2_3rdparty_20230522-636fac93.pth' # noqa
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ embed_dims=[64, 128, 320, 512],
+ depths=[3, 3, 12, 3],
+ init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
+ drop_path_rate=0.2),
+ neck=dict(in_channels=[64, 128, 320, 512]),
+ decode_head=dict(num_classes=150))
+
+train_dataloader = dict(batch_size=4)
+
+# we use 8 GPUs instead of 4 as in mmsegmentation, so lr*2 and max_iters/2
+gpu_multiples = 2
+max_iters = 80000 // gpu_multiples
+interval = 8000 // gpu_multiples
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW',
+ lr=0.0001 * gpu_multiples,
+ # betas=(0.9, 0.999),
+ weight_decay=0.0001),
+ clip_grad=None)
+# learning policy
+param_scheduler = [
+ dict(
+ type='PolyLR',
+ power=0.9,
+ eta_min=0.0,
+ begin=0,
+ end=max_iters,
+ by_epoch=False,
+ )
+]
+train_cfg = dict(
+ type='IterBasedTrainLoop', max_iters=max_iters, val_interval=interval)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=interval),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'))
diff --git a/projects/van/configs/van/van-b2_upernet_4xb2-160k_ade20k-512x512.py b/projects/van/configs/van/van-b2_upernet_4xb2-160k_ade20k-512x512.py
new file mode 100644
index 00000000000..c529606a202
--- /dev/null
+++ b/projects/van/configs/van/van-b2_upernet_4xb2-160k_ade20k-512x512.py
@@ -0,0 +1,46 @@
+_base_ = [
+ '../_base_/models/van_upernet.py', '../_base_/datasets/ade20k.py',
+ '../../../../configs/_base_/default_runtime.py',
+ '../../../../configs/_base_/schedules/schedule_160k.py'
+]
+custom_imports = dict(imports=['projects.van.backbones'])
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b2_3rdparty_20230522-636fac93.pth' # noqa
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ embed_dims=[64, 128, 320, 512],
+ depths=[3, 3, 12, 3],
+ init_cfg=dict(type='Pretrained', checkpoint=ckpt_path)),
+ decode_head=dict(in_channels=[64, 128, 320, 512], num_classes=150),
+ auxiliary_head=dict(in_channels=320, num_classes=150))
+
+# AdamW optimizer
+# no weight decay for position embedding & layer norm in backbone
+optim_wrapper = dict(
+ _delete_=True,
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ clip_grad=None,
+ paramwise_cfg=dict(
+ custom_keys={
+ 'absolute_pos_embed': dict(decay_mult=0.),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.)
+ }))
+# learning policy
+param_scheduler = [
+ dict(
+ type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
+ dict(
+ type='PolyLR',
+ power=1.0,
+ begin=1500,
+ end=_base_.train_cfg.max_iters,
+ eta_min=0.0,
+ by_epoch=False,
+ )
+]
+
+# By default, models are trained on 4 GPUs with 2 images per GPU
+train_dataloader = dict(batch_size=2)
diff --git a/projects/van/configs/van/van-b3_fpn_8xb4-40k_ade20k-512x512.py b/projects/van/configs/van/van-b3_fpn_8xb4-40k_ade20k-512x512.py
new file mode 100644
index 00000000000..b0493fe4f9f
--- /dev/null
+++ b/projects/van/configs/van/van-b3_fpn_8xb4-40k_ade20k-512x512.py
@@ -0,0 +1,11 @@
+_base_ = './van-b2_fpn_8xb4-40k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3_3rdparty_20230522-a184e051.pth' # noqa
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ embed_dims=[64, 128, 320, 512],
+ depths=[3, 5, 27, 3],
+ init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
+ drop_path_rate=0.3),
+ neck=dict(in_channels=[64, 128, 320, 512]))
+train_dataloader = dict(batch_size=4)
diff --git a/projects/van/configs/van/van-b3_upernet_4xb2-160k_ade20k-512x512.py b/projects/van/configs/van/van-b3_upernet_4xb2-160k_ade20k-512x512.py
new file mode 100644
index 00000000000..8201801d992
--- /dev/null
+++ b/projects/van/configs/van/van-b3_upernet_4xb2-160k_ade20k-512x512.py
@@ -0,0 +1,8 @@
+_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b3_3rdparty_20230522-a184e051.pth' # noqa
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ depths=[3, 5, 27, 3],
+ init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
+ drop_path_rate=0.3))
diff --git a/projects/van/configs/van/van-b4-in22kpre_upernet_4xb4-160k_ade20k-512x512.py b/projects/van/configs/van/van-b4-in22kpre_upernet_4xb4-160k_ade20k-512x512.py
new file mode 100644
index 00000000000..15c8f7ca6ee
--- /dev/null
+++ b/projects/van/configs/van/van-b4-in22kpre_upernet_4xb4-160k_ade20k-512x512.py
@@ -0,0 +1,10 @@
+_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4-in22k_3rdparty_20230522-5e31cafb.pth' # noqa
+model = dict(
+ backbone=dict(
+ depths=[3, 6, 40, 3],
+ init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
+ drop_path_rate=0.4))
+
+# By default, models are trained on 4 GPUs with 4 images per GPU
+train_dataloader = dict(batch_size=4)
diff --git a/projects/van/configs/van/van-b4_upernet_4xb4-160k_ade20k-512x512.py b/projects/van/configs/van/van-b4_upernet_4xb4-160k_ade20k-512x512.py
new file mode 100644
index 00000000000..33ae049d0c0
--- /dev/null
+++ b/projects/van/configs/van/van-b4_upernet_4xb4-160k_ade20k-512x512.py
@@ -0,0 +1,10 @@
+_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b4_3rdparty_20230522-1d71c077.pth' # noqa
+model = dict(
+ backbone=dict(
+ depths=[3, 6, 40, 3],
+ init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
+ drop_path_rate=0.4))
+
+# By default, models are trained on 4 GPUs with 4 images per GPU
+train_dataloader = dict(batch_size=4)
diff --git a/projects/van/configs/van/van-b5-in22kpre_upernet_4xb2-160k_ade20k-512x512.py b/projects/van/configs/van/van-b5-in22kpre_upernet_4xb2-160k_ade20k-512x512.py
new file mode 100644
index 00000000000..f36c6242bdf
--- /dev/null
+++ b/projects/van/configs/van/van-b5-in22kpre_upernet_4xb2-160k_ade20k-512x512.py
@@ -0,0 +1,10 @@
+_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b5-in22k_3rdparty_20230522-b26134d7.pth' # noqa
+model = dict(
+ backbone=dict(
+ embed_dims=[96, 192, 480, 768],
+ depths=[3, 3, 24, 3],
+ init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
+ drop_path_rate=0.4),
+ decode_head=dict(in_channels=[96, 192, 480, 768], num_classes=150),
+ auxiliary_head=dict(in_channels=480, num_classes=150))
diff --git a/projects/van/configs/van/van-b6-in22kpre_upernet_4xb2-160k_ade20k-512x512.py b/projects/van/configs/van/van-b6-in22kpre_upernet_4xb2-160k_ade20k-512x512.py
new file mode 100644
index 00000000000..aa529efed8c
--- /dev/null
+++ b/projects/van/configs/van/van-b6-in22kpre_upernet_4xb2-160k_ade20k-512x512.py
@@ -0,0 +1,10 @@
+_base_ = './van-b2_upernet_4xb2-160k_ade20k-512x512.py'
+ckpt_path = 'https://download.openmmlab.com/mmsegmentation/v0.5/van_3rdparty/van-b6-in22k_3rdparty_20230522-5e5172a3.pth' # noqa
+model = dict(
+ backbone=dict(
+ embed_dims=[96, 192, 384, 768],
+ depths=[6, 6, 90, 6],
+ init_cfg=dict(type='Pretrained', checkpoint=ckpt_path),
+ drop_path_rate=0.5),
+ decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=150),
+ auxiliary_head=dict(in_channels=384, num_classes=150))
diff --git a/requirements.txt b/requirements.txt
index 6da5adea757..501bddc8843 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
-r requirements/optional.txt
-r requirements/runtime.txt
-r requirements/tests.txt
+-r requirements/multimodal.txt
diff --git a/requirements/albu.txt b/requirements/albu.txt
new file mode 100644
index 00000000000..f421fbbdc47
--- /dev/null
+++ b/requirements/albu.txt
@@ -0,0 +1 @@
+albumentations>=0.3.2 --no-binary qudida,albumentations
diff --git a/requirements/docs.txt b/requirements/docs.txt
index 8e98c16fc72..19632d36aba 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -4,3 +4,4 @@ myst-parser
sphinx==4.0.2
sphinx_copybutton
sphinx_markdown_tables
+urllib3<2.0.0
diff --git a/requirements/multimodal.txt b/requirements/multimodal.txt
new file mode 100644
index 00000000000..2195d0d9ef8
--- /dev/null
+++ b/requirements/multimodal.txt
@@ -0,0 +1,2 @@
+ftfy
+regex
diff --git a/requirements/optional.txt b/requirements/optional.txt
index 5eca649247e..b0310f52960 100644
--- a/requirements/optional.txt
+++ b/requirements/optional.txt
@@ -1,2 +1,22 @@
cityscapesscripts
+-e git+https://github.com/openai/CLIP.git@main#egg=clip
+
+# for vpd model
+diffusers
+einops==0.3.0
+imageio==2.9.0
+imageio-ffmpeg==0.4.2
+invisible-watermark
+kornia==0.6
+-e git+https://github.com/CompVis/stable-diffusion@21f890f#egg=latent-diffusion
nibabel
+omegaconf==2.1.1
+pudb==2019.2
+pytorch-lightning==1.4.2
+streamlit>=0.73.1
+-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+test-tube>=0.7.5
+timm
+torch-fidelity==0.3.0
+torchmetrics==0.6.0
+transformers==4.19.2
diff --git a/requirements/tests.txt b/requirements/tests.txt
index 74fc76146d8..3fff2520d7c 100644
--- a/requirements/tests.txt
+++ b/requirements/tests.txt
@@ -1,6 +1,8 @@
codecov
flake8
+ftfy
interrogate
pytest
+regex
xdoctest>=0.10.0
yapf
diff --git a/resources/miaomiao_qrcode.jpg b/resources/miaomiao_qrcode.jpg
new file mode 100644
index 00000000000..d34cbae6fd1
Binary files /dev/null and b/resources/miaomiao_qrcode.jpg differ
diff --git a/setup.cfg b/setup.cfg
index dc5ea071110..2ea07600c0b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -16,4 +16,4 @@ default_section = THIRDPARTY
skip = *.po,*.ts,*.ipynb
count =
quiet-level = 3
-ignore-words-list = formating,sur,hist,dota,warmup
+ignore-words-list = formating,sur,hist,dota,warmup,damon
diff --git a/setup.py b/setup.py
index 854dd186054..45d923db60e 100755
--- a/setup.py
+++ b/setup.py
@@ -123,7 +123,7 @@ def add_mim_extension():
else:
return
- filenames = ['tools', 'configs', 'model-index.yml']
+ filenames = ['tools', 'configs', 'model-index.yml', 'dataset-index.yml']
repo_path = osp.dirname(__file__)
mim_path = osp.join(repo_path, 'mmseg', '.mim')
os.makedirs(mim_path, exist_ok=True)
@@ -175,7 +175,7 @@ def add_mim_extension():
author='MMSegmentation Contributors',
author_email='openmmlab@gmail.com',
keywords='computer vision, semantic segmentation',
- url='http://github.com/open-mmlab/mmsegmentation',
+ url='https://github.com/open-mmlab/mmsegmentation',
packages=find_packages(exclude=('configs', 'tools', 'demo')),
include_package_data=True,
classifiers=[
@@ -194,6 +194,7 @@ def add_mim_extension():
'tests': parse_requirements('requirements/tests.txt'),
'optional': parse_requirements('requirements/optional.txt'),
'mim': parse_requirements('requirements/mminstall.txt'),
+ 'multimodal': parse_requirements('requirements/multimodal.txt'),
},
ext_modules=[],
zip_safe=False)
diff --git a/tests/data/dsdl_seg/config.py b/tests/data/dsdl_seg/config.py
new file mode 100755
index 00000000000..8eed751c2f7
--- /dev/null
+++ b/tests/data/dsdl_seg/config.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+local = dict(
+ type='LocalFileReader',
+ working_dir='/nvme/share_data/VOC2012',
+)
+
+ali_oss = dict(
+ type='AliOSSFileReader',
+ access_key_secret='your secret key of aliyun oss',
+ endpoint='your endpoint of aliyun oss',
+ access_key_id='your access key of aliyun oss',
+ bucket_name='your bucket name of aliyun oss',
+ working_dir='the relative path of your media dir in the bucket')
diff --git a/tests/data/dsdl_seg/defs/class-dom.yaml b/tests/data/dsdl_seg/defs/class-dom.yaml
new file mode 100755
index 00000000000..e5dd598c4ae
--- /dev/null
+++ b/tests/data/dsdl_seg/defs/class-dom.yaml
@@ -0,0 +1,24 @@
+$dsdl-version: "0.5.0"
+VOCClassDom:
+ $def: class_domain
+ classes:
+ - aeroplane
+ - bicycle
+ - bird
+ - boat
+ - bottle
+ - bus
+ - car
+ - cat
+ - chair
+ - cow
+ - diningtable
+ - dog
+ - horse
+ - motorbike
+ - person
+ - pottedplant
+ - sheep
+ - sofa
+ - train
+ - tvmonitor
diff --git a/tests/data/dsdl_seg/defs/segmentation-def.yaml b/tests/data/dsdl_seg/defs/segmentation-def.yaml
new file mode 100755
index 00000000000..057139ed57e
--- /dev/null
+++ b/tests/data/dsdl_seg/defs/segmentation-def.yaml
@@ -0,0 +1,15 @@
+$dsdl-version: "0.5.0"
+
+ImageMedia:
+ $def: struct
+ $fields:
+ image: Image
+ image_shape: ImageShape
+
+SegmentationSample:
+ $def: struct
+ $params: ['cdom']
+ $fields:
+ media: ImageMedia
+ label_map: LabelMap[dom=$cdom]
+ instance_map: InstanceMap
diff --git a/tests/data/dsdl_seg/set-train/train.yaml b/tests/data/dsdl_seg/set-train/train.yaml
new file mode 100755
index 00000000000..69872445a53
--- /dev/null
+++ b/tests/data/dsdl_seg/set-train/train.yaml
@@ -0,0 +1,15 @@
+$dsdl-version: "0.5.0"
+$import:
+ - ../defs/segmentation-def
+ - ../defs/class-dom
+meta:
+ dataset_name: "VOC2012"
+ sub_dataset_name: "train"
+ task_type: "Segmentation"
+ dataset_homepage: "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html"
+ dataset_publisher: "University of Leeds | ETHZ, Zurich | University of Edinburgh\
+ \ |Microsoft Research Cambridge | University of Oxford"
+ OpenDataLab_address: "https://opendatalab.com/PASCAL_VOC2012/download"
+data:
+ sample-type: SegmentationSample[cdom=VOCClassDom]
+ sample-path: train_samples.json
diff --git a/tests/data/dsdl_seg/set-train/train_samples.json b/tests/data/dsdl_seg/set-train/train_samples.json
new file mode 100755
index 00000000000..559f5845721
--- /dev/null
+++ b/tests/data/dsdl_seg/set-train/train_samples.json
@@ -0,0 +1 @@
+{"samples": [{"media": {"image": "JPEGImages/2007_000032.jpg", "image_shape": [281, 500]}, "label_map": "SegmentationClass/2007_000032.png", "instance_map": "SegmentationObject/2007_000032.png"}, {"media": {"image": "JPEGImages/2007_000039.jpg", "image_shape": [375, 500]}, "label_map": "SegmentationClass/2007_000039.png", "instance_map": "SegmentationObject/2007_000039.png"}, {"media": {"image": "JPEGImages/2007_000063.jpg", "image_shape": [375, 500]}, "label_map": "SegmentationClass/2007_000063.png", "instance_map": "SegmentationObject/2007_000063.png"}]}
diff --git a/tests/data/pseudo_bdd100k_dataset/images/10k/train/0004a4c0-d4dff0ad.jpg b/tests/data/pseudo_bdd100k_dataset/images/10k/train/0004a4c0-d4dff0ad.jpg
new file mode 100644
index 00000000000..4724a3d9302
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/images/10k/train/0004a4c0-d4dff0ad.jpg differ
diff --git a/tests/data/pseudo_bdd100k_dataset/images/10k/train/00054602-3bf57337.jpg b/tests/data/pseudo_bdd100k_dataset/images/10k/train/00054602-3bf57337.jpg
new file mode 100644
index 00000000000..5efe06b99a4
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/images/10k/train/00054602-3bf57337.jpg differ
diff --git a/tests/data/pseudo_bdd100k_dataset/images/10k/train/00067cfb-e535423e.jpg b/tests/data/pseudo_bdd100k_dataset/images/10k/train/00067cfb-e535423e.jpg
new file mode 100644
index 00000000000..2233c03c763
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/images/10k/train/00067cfb-e535423e.jpg differ
diff --git a/tests/data/pseudo_bdd100k_dataset/images/10k/val/7d06fefd-f7be05a6.jpg b/tests/data/pseudo_bdd100k_dataset/images/10k/val/7d06fefd-f7be05a6.jpg
new file mode 100644
index 00000000000..535087b5a95
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/images/10k/val/7d06fefd-f7be05a6.jpg differ
diff --git a/tests/data/pseudo_bdd100k_dataset/images/10k/val/7d128593-0ccfea4c.jpg b/tests/data/pseudo_bdd100k_dataset/images/10k/val/7d128593-0ccfea4c.jpg
new file mode 100644
index 00000000000..7f2971afdee
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/images/10k/val/7d128593-0ccfea4c.jpg differ
diff --git a/tests/data/pseudo_bdd100k_dataset/images/10k/val/7d15b18b-1e0d6e3f.jpg b/tests/data/pseudo_bdd100k_dataset/images/10k/val/7d15b18b-1e0d6e3f.jpg
new file mode 100644
index 00000000000..31a951d4830
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/images/10k/val/7d15b18b-1e0d6e3f.jpg differ
diff --git a/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/train/0004a4c0-d4dff0ad.png b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/train/0004a4c0-d4dff0ad.png
new file mode 100644
index 00000000000..086a8d50647
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/train/0004a4c0-d4dff0ad.png differ
diff --git a/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/train/00054602-3bf57337.png b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/train/00054602-3bf57337.png
new file mode 100644
index 00000000000..43338c283cc
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/train/00054602-3bf57337.png differ
diff --git a/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/train/00067cfb-e535423e.png b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/train/00067cfb-e535423e.png
new file mode 100644
index 00000000000..7c0ad1d5d90
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/train/00067cfb-e535423e.png differ
diff --git a/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/val/7d128593-0ccfea4c.png b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/val/7d128593-0ccfea4c.png
new file mode 100644
index 00000000000..43338c283cc
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/val/7d128593-0ccfea4c.png differ
diff --git a/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/val/7d15b18b-1e0d6e3f.png b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/val/7d15b18b-1e0d6e3f.png
new file mode 100644
index 00000000000..43338c283cc
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/val/7d15b18b-1e0d6e3f.png differ
diff --git a/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/val/7d2f7975-e0c1c5a7.png b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/val/7d2f7975-e0c1c5a7.png
new file mode 100644
index 00000000000..7c0ad1d5d90
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/colormaps/val/7d2f7975-e0c1c5a7.png differ
diff --git a/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/train/0004a4c0-d4dff0ad.png b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/train/0004a4c0-d4dff0ad.png
new file mode 100644
index 00000000000..5c6bf5e1581
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/train/0004a4c0-d4dff0ad.png differ
diff --git a/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/train/00054602-3bf57337.png b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/train/00054602-3bf57337.png
new file mode 100644
index 00000000000..c525a768885
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/train/00054602-3bf57337.png differ
diff --git a/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/train/00067cfb-e535423e.png b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/train/00067cfb-e535423e.png
new file mode 100644
index 00000000000..7dfd3af4e3a
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/train/00067cfb-e535423e.png differ
diff --git a/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/val/7d06fefd-f7be05a6.png b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/val/7d06fefd-f7be05a6.png
new file mode 100644
index 00000000000..7dfd3af4e3a
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/val/7d06fefd-f7be05a6.png differ
diff --git a/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/val/7d128593-0ccfea4c.png b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/val/7d128593-0ccfea4c.png
new file mode 100644
index 00000000000..c525a768885
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/val/7d128593-0ccfea4c.png differ
diff --git a/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/val/7d15b18b-1e0d6e3f.png b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/val/7d15b18b-1e0d6e3f.png
new file mode 100644
index 00000000000..c525a768885
Binary files /dev/null and b/tests/data/pseudo_bdd100k_dataset/labels/sem_seg/masks/val/7d15b18b-1e0d6e3f.png differ
diff --git a/tests/data/pseudo_nyu_dataset/annotations/bookstore_0001d_00001.png b/tests/data/pseudo_nyu_dataset/annotations/bookstore_0001d_00001.png
new file mode 100644
index 00000000000..77e343603ad
Binary files /dev/null and b/tests/data/pseudo_nyu_dataset/annotations/bookstore_0001d_00001.png differ
diff --git a/tests/data/pseudo_nyu_dataset/images/bookstore_0001d_00001.jpg b/tests/data/pseudo_nyu_dataset/images/bookstore_0001d_00001.jpg
new file mode 100644
index 00000000000..7892ed47e7d
Binary files /dev/null and b/tests/data/pseudo_nyu_dataset/images/bookstore_0001d_00001.jpg differ
diff --git a/tests/test_apis/test_inferencer.py b/tests/test_apis/test_inferencer.py
index 497eae4a011..d8dbce8f385 100644
--- a/tests/test_apis/test_inferencer.py
+++ b/tests/test_apis/test_inferencer.py
@@ -3,76 +3,16 @@
import numpy as np
import torch
-import torch.nn as nn
from mmengine import ConfigDict
-from torch.utils.data import DataLoader, Dataset
+from utils import * # noqa: F401, F403
from mmseg.apis import MMSegInferencer
-from mmseg.models import EncoderDecoder
-from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
from mmseg.utils import register_all_modules
-@MODELS.register_module(name='InferExampleHead')
-class ExampleDecodeHead(BaseDecodeHead):
-
- def __init__(self, num_classes=19, out_channels=None):
- super().__init__(
- 3, 3, num_classes=num_classes, out_channels=out_channels)
-
- def forward(self, inputs):
- return self.cls_seg(inputs[0])
-
-
-@MODELS.register_module(name='InferExampleBackbone')
-class ExampleBackbone(nn.Module):
-
- def __init__(self):
- super().__init__()
- self.conv = nn.Conv2d(3, 3, 3)
-
- def init_weights(self, pretrained=None):
- pass
-
- def forward(self, x):
- return [self.conv(x)]
-
-
-@MODELS.register_module(name='InferExampleModel')
-class ExampleModel(EncoderDecoder):
-
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
-
-
-class ExampleDataset(Dataset):
-
- def __init__(self) -> None:
- super().__init__()
- self.pipeline = [
- dict(type='LoadImageFromFile'),
- dict(type='LoadAnnotations'),
- dict(type='PackSegInputs')
- ]
-
- def __getitem__(self, idx):
- return dict(img=torch.tensor([1]), img_metas=dict())
-
- def __len__(self):
- return 1
-
-
def test_inferencer():
register_all_modules()
- test_dataset = ExampleDataset()
- data_loader = DataLoader(
- test_dataset,
- batch_size=1,
- sampler=None,
- num_workers=0,
- shuffle=False,
- )
visualizer = dict(
type='SegLocalVisualizer',
@@ -87,7 +27,14 @@ def test_inferencer():
decode_head=dict(type='InferExampleHead'),
test_cfg=dict(mode='whole')),
visualizer=visualizer,
- test_dataloader=data_loader)
+ test_dataloader=dict(
+ dataset=dict(
+ type='ExampleDataset',
+ pipeline=[
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+ ]), ))
cfg = ConfigDict(cfg_dict)
model = MODELS.build(cfg.model)
diff --git a/tests/test_apis/test_rs_inferencer.py b/tests/test_apis/test_rs_inferencer.py
new file mode 100644
index 00000000000..03423d9680e
--- /dev/null
+++ b/tests/test_apis/test_rs_inferencer.py
@@ -0,0 +1,73 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from unittest import TestCase
+
+import numpy as np
+from mmengine import ConfigDict, init_default_scope
+from utils import * # noqa: F401, F403
+
+from mmseg.apis import RSImage, RSInferencer
+from mmseg.registry import MODELS
+
+
+class TestRSImage(TestCase):
+
+ def test_read_whole_image(self):
+ init_default_scope('mmseg')
+ img_path = osp.join(
+ osp.dirname(__file__),
+ '../data/pseudo_loveda_dataset/img_dir/0.png')
+ rs_image = RSImage(img_path)
+ window_size = (16, 16)
+ rs_image.create_grids(window_size)
+ image_data = rs_image.read(rs_image.grids[0])
+ self.assertIsNotNone(image_data)
+
+ def test_write_image_data(self):
+ init_default_scope('mmseg')
+ img_path = osp.join(
+ osp.dirname(__file__),
+ '../data/pseudo_loveda_dataset/img_dir/0.png')
+ rs_image = RSImage(img_path)
+ window_size = (16, 16)
+ rs_image.create_grids(window_size)
+ data = np.random.random((16, 16)).astype(np.int8)
+ rs_image.write(data, rs_image.grids[0])
+
+
+class TestRSInferencer(TestCase):
+
+ def test_read_and_inference(self):
+ init_default_scope('mmseg')
+ cfg_dict = dict(
+ model=dict(
+ type='InferExampleModel',
+ data_preprocessor=dict(type='SegDataPreProcessor'),
+ backbone=dict(type='InferExampleBackbone'),
+ decode_head=dict(type='InferExampleHead'),
+ test_cfg=dict(mode='whole')),
+ test_dataloader=dict(
+ dataset=dict(
+ type='ExampleDataset',
+ pipeline=[
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+ ])),
+ test_pipeline=[
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='PackSegInputs')
+ ])
+ cfg = ConfigDict(cfg_dict)
+ model = MODELS.build(cfg.model)
+ model.cfg = cfg
+ inferencer = RSInferencer.from_model(model)
+
+ img_path = osp.join(
+ osp.dirname(__file__),
+ '../data/pseudo_loveda_dataset/img_dir/0.png')
+ rs_image = RSImage(img_path)
+ window_size = (16, 16)
+ stride = (16, 16)
+ inferencer.run(rs_image, window_size, stride)
diff --git a/tests/test_apis/utils.py b/tests/test_apis/utils.py
new file mode 100644
index 00000000000..0a9928fccf1
--- /dev/null
+++ b/tests/test_apis/utils.py
@@ -0,0 +1,38 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from mmseg.models import EncoderDecoder
+from mmseg.models.decode_heads.decode_head import BaseDecodeHead
+from mmseg.registry import MODELS
+
+
+@MODELS.register_module(name='InferExampleHead')
+class ExampleDecodeHead(BaseDecodeHead):
+
+ def __init__(self, num_classes=19, out_channels=None):
+ super().__init__(
+ 3, 3, num_classes=num_classes, out_channels=out_channels)
+
+ def forward(self, inputs):
+ return self.cls_seg(inputs[0])
+
+
+@MODELS.register_module(name='InferExampleBackbone')
+class ExampleBackbone(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.conv = nn.Conv2d(3, 3, 3)
+
+ def init_weights(self, pretrained=None):
+ pass
+
+ def forward(self, x):
+ return [self.conv(x)]
+
+
+@MODELS.register_module(name='InferExampleModel')
+class ExampleModel(EncoderDecoder):
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
diff --git a/tests/test_config.py b/tests/test_config.py
index 13de460181c..cdd85ff57cd 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -30,6 +30,7 @@ def _get_config_directory():
def test_config_build_segmentor():
"""Test that all segmentation models defined in the configs can be
initialized."""
+ init_default_scope('mmseg')
config_dpath = _get_config_directory()
print(f'Found config_dpath = {config_dpath!r}')
@@ -92,7 +93,12 @@ def test_config_data_pipeline():
del config_mod.train_pipeline[0]
del config_mod.test_pipeline[0]
# remove loading annotation in test pipeline
- del config_mod.test_pipeline[-2]
+ load_anno_idx = -1
+ for i in range(len(config_mod.test_pipeline)):
+ if config_mod.test_pipeline[i].type in ('LoadAnnotations',
+ 'LoadDepthAnnotation'):
+ load_anno_idx = i
+ del config_mod.test_pipeline[load_anno_idx]
train_pipeline = Compose(config_mod.train_pipeline)
test_pipeline = Compose(config_mod.test_pipeline)
@@ -101,6 +107,7 @@ def test_config_data_pipeline():
if to_float32:
img = img.astype(np.float32)
seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
+ depth = np.random.rand(1024, 2048).astype(np.float32)
results = dict(
filename='test_img.png',
@@ -108,24 +115,30 @@ def test_config_data_pipeline():
img=img,
img_shape=img.shape,
ori_shape=img.shape,
- gt_seg_map=seg)
+ gt_seg_map=seg,
+ gt_depth_map=depth)
results['seg_fields'] = ['gt_seg_map']
-
+ _check_concat_cd_input(config_mod, results)
print(f'Test training data pipeline: \n{train_pipeline!r}')
output_results = train_pipeline(results)
assert output_results is not None
- results = dict(
- filename='test_img.png',
- ori_filename='test_img.png',
- img=img,
- img_shape=img.shape,
- ori_shape=img.shape)
+ _check_concat_cd_input(config_mod, results)
print(f'Test testing data pipeline: \n{test_pipeline!r}')
output_results = test_pipeline(results)
assert output_results is not None
+def _check_concat_cd_input(config_mod: Config, results: dict):
+ keys = []
+ pipeline = config_mod.train_pipeline.copy()
+ pipeline.extend(config_mod.test_pipeline)
+ for t in pipeline:
+ keys.append(t.type)
+ if 'ConcatCDInput' in keys:
+ results.update({'img2': results['img']})
+
+
def _check_decode_head(decode_head_cfg, decode_head):
if isinstance(decode_head_cfg, list):
assert isinstance(decode_head, nn.ModuleList)
@@ -149,14 +162,14 @@ def _check_decode_head(decode_head_cfg, decode_head):
elif input_transform == 'resize_concat':
assert sum(in_channels) == decode_head.in_channels
else:
- assert isinstance(in_channels, int)
assert in_channels == decode_head.in_channels
- assert isinstance(decode_head.in_index, int)
if decode_head_cfg['type'] == 'PointHead':
assert decode_head_cfg.channels+decode_head_cfg.num_classes == \
decode_head.fc_seg.in_channels
assert decode_head.fc_seg.out_channels == decode_head_cfg.num_classes
+ elif decode_head_cfg['type'] == 'VPDDepthHead':
+ assert decode_head.out_channels == 1
else:
assert decode_head_cfg.channels == decode_head.conv_seg.in_channels
assert decode_head.conv_seg.out_channels == decode_head_cfg.num_classes
diff --git a/tests/test_datasets/test_dataset.py b/tests/test_datasets/test_dataset.py
index db4a7799069..2904e09cedd 100644
--- a/tests/test_datasets/test_dataset.py
+++ b/tests/test_datasets/test_dataset.py
@@ -5,15 +5,21 @@
import pytest
-from mmseg.datasets import (ADE20KDataset, BaseSegDataset, CityscapesDataset,
- COCOStuffDataset, DecathlonDataset, ISPRSDataset,
+from mmseg.datasets import (ADE20KDataset, BaseSegDataset, BDD100KDataset,
+ CityscapesDataset, COCOStuffDataset,
+ DecathlonDataset, DSDLSegDataset, ISPRSDataset,
LIPDataset, LoveDADataset, MapillaryDataset_v1,
- MapillaryDataset_v2, PascalVOCDataset,
+ MapillaryDataset_v2, NYUDataset, PascalVOCDataset,
PotsdamDataset, REFUGEDataset, SynapseDataset,
iSAIDDataset)
from mmseg.registry import DATASETS
from mmseg.utils import get_classes, get_palette
+try:
+ from dsdl.dataset import DSDLDataset
+except ImportError:
+ DSDLDataset = None
+
def test_classes():
assert list(
@@ -32,7 +38,7 @@ def test_classes():
MapillaryDataset_v1.METAINFO['classes']) == get_classes('mapillary_v1')
assert list(
MapillaryDataset_v2.METAINFO['classes']) == get_classes('mapillary_v2')
-
+ assert list(BDD100KDataset.METAINFO['classes']) == get_classes('bdd100k')
with pytest.raises(ValueError):
get_classes('unsupported')
@@ -89,7 +95,7 @@ def test_palette():
MapillaryDataset_v1.METAINFO['palette']) == get_palette('mapillary_v1')
assert list(
MapillaryDataset_v2.METAINFO['palette']) == get_palette('mapillary_v2')
-
+ assert list(BDD100KDataset.METAINFO['palette']) == get_palette('bdd100k')
with pytest.raises(ValueError):
get_palette('unsupported')
@@ -326,6 +332,19 @@ def test_mapillary():
assert len(test_dataset) == 1
+def test_bdd100k():
+ test_dataset = BDD100KDataset(
+ pipeline=[],
+ data_prefix=dict(
+ img_path=osp.join(
+ osp.dirname(__file__),
+ '../data/pseudo_bdd100k_dataset/images/10k/val'),
+ seg_map_path=osp.join(
+ osp.dirname(__file__),
+ '../data/pseudo_bdd100k_dataset/labels/sem_seg/masks/val')))
+ assert len(test_dataset) == 3
+
+
@pytest.mark.parametrize('dataset, classes', [
('ADE20KDataset', ('wall', 'building')),
('CityscapesDataset', ('road', 'sidewalk')),
@@ -433,3 +452,24 @@ def test_custom_dataset_custom_palette():
ann_file=tempfile.mkdtemp(),
metainfo=dict(classes=('bus', 'car'), palette=[[200, 200, 200]]),
lazy_init=True)
+
+
+def test_dsdlseg_dataset():
+ if DSDLDataset is not None:
+ dataset = DSDLSegDataset(
+ data_root='tests/data/dsdl_seg', ann_file='set-train/train.yaml')
+ assert len(dataset) == 3
+ assert len(dataset.metainfo['classes']) == 21
+    else:
+        pytest.skip('Package `dsdl` is not installed.')
+
+
+def test_nyu_dataset():
+ dataset = NYUDataset(
+ data_root='tests/data/pseudo_nyu_dataset',
+ data_prefix=dict(img_path='images', depth_map_path='annotations'),
+ )
+ assert len(dataset) == 1
+ data = dataset[0]
+ assert data.get('depth_map_path', None) is not None
+ assert data.get('category_id', -1) == 26
diff --git a/tests/test_datasets/test_loading.py b/tests/test_datasets/test_loading.py
index 5ce624bff6a..3eea6e3f9dd 100644
--- a/tests/test_datasets/test_loading.py
+++ b/tests/test_datasets/test_loading.py
@@ -7,10 +7,11 @@
import numpy as np
from mmcv.transforms import LoadImageFromFile
-from mmseg.datasets.transforms import (LoadAnnotations,
- LoadBiomedicalAnnotation,
+from mmseg.datasets.transforms import LoadAnnotations # noqa
+from mmseg.datasets.transforms import (LoadBiomedicalAnnotation,
LoadBiomedicalData,
LoadBiomedicalImageFromFile,
+ LoadDepthAnnotation,
LoadImageFromNDArray)
@@ -276,3 +277,19 @@ def test_load_biomedical_data(self):
"decode_backend='numpy', "
'to_xyz=False, '
'backend_args=None)')
+
+ def test_load_depth_annotation(self):
+ input_results = dict(
+ img_path='tests/data/pseudo_nyu_dataset/images/'
+ 'bookstore_0001d_00001.jpg',
+ depth_map_path='tests/data/pseudo_nyu_dataset/'
+ 'annotations/bookstore_0001d_00001.png',
+ category_id=-1,
+ seg_fields=[])
+ transform = LoadDepthAnnotation(depth_rescale_factor=0.001)
+ results = transform(input_results)
+ assert 'gt_depth_map' in results
+ assert results['gt_depth_map'].shape[:2] == mmcv.imread(
+ input_results['depth_map_path']).shape[:2]
+ assert results['gt_depth_map'].dtype == np.float32
+ assert 'gt_depth_map' in results['seg_fields']
diff --git a/tests/test_datasets/test_transform.py b/tests/test_datasets/test_transform.py
index 92d6c6106d2..e73e558ee8e 100644
--- a/tests/test_datasets/test_transform.py
+++ b/tests/test_datasets/test_transform.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
+from unittest import TestCase
import mmcv
import numpy as np
@@ -11,7 +12,8 @@
from mmseg.datasets.transforms import * # noqa
from mmseg.datasets.transforms import (LoadBiomedicalData,
LoadBiomedicalImageFromFile,
- PhotoMetricDistortion, RandomCrop)
+ PhotoMetricDistortion, RandomCrop,
+ RandomDepthMix)
from mmseg.registry import TRANSFORMS
init_default_scope('mmseg')
@@ -184,6 +186,14 @@ def test_flip():
assert np.equal(original_img, results['img']).all()
assert np.equal(original_seg, results['gt_semantic_seg']).all()
+ results['gt_depth_map'] = seg
+ results['seg_fields'] = ['gt_depth_map']
+ results = flip_module(results)
+ flip_module = TRANSFORMS.build(transform)
+ results = flip_module(results)
+ assert np.equal(original_img, results['img']).all()
+ assert np.equal(original_seg, results['gt_depth_map']).all()
+
def test_random_rotate_flip():
with pytest.raises(AssertionError):
@@ -1160,3 +1170,104 @@ def test_biomedical_3d_flip():
results = transform(results)
assert np.equal(original_img, results['img']).all()
assert np.equal(original_seg, results['gt_seg_map']).all()
+
+
+def test_albu_transform():
+ results = dict(
+ img_path=osp.join(osp.dirname(__file__), '../data/color.jpg'))
+
+ # Define simple pipeline
+ load = dict(type='LoadImageFromFile')
+ load = TRANSFORMS.build(load)
+
+ albu_transform = dict(
+ type='Albu', transforms=[dict(type='ChannelShuffle', p=1)])
+ albu_transform = TRANSFORMS.build(albu_transform)
+
+ normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True)
+ normalize = TRANSFORMS.build(normalize)
+
+ # Execute transforms
+ results = load(results)
+ results = albu_transform(results)
+ results = normalize(results)
+
+ assert results['img'].dtype == np.float32
+
+
+def test_albu_channel_order():
+ results = dict(
+ img_path=osp.join(osp.dirname(__file__), '../data/color.jpg'))
+
+ # Define simple pipeline
+ load = dict(type='LoadImageFromFile')
+ load = TRANSFORMS.build(load)
+
+ # Transform is modifying B channel
+ albu_transform = dict(
+ type='Albu',
+ transforms=[
+ dict(
+ type='RGBShift',
+ r_shift_limit=0,
+ g_shift_limit=0,
+ b_shift_limit=200,
+ p=1)
+ ])
+ albu_transform = TRANSFORMS.build(albu_transform)
+
+ # Execute transforms
+ results_load = load(results)
+ results_albu = albu_transform(results_load)
+
+ # assert only Green and Red channel are not modified
+ np.testing.assert_array_equal(results_albu['img'][..., 1:],
+ results_load['img'][..., 1:])
+
+ # assert Blue channel is modified
+ with pytest.raises(AssertionError):
+ np.testing.assert_array_equal(results_albu['img'][..., 0],
+ results_load['img'][..., 0])
+
+
+class TestRandomDepthMix(TestCase):
+
+ def setUp(self):
+ self.transform = RandomDepthMix(prob=1.0)
+
+ def test_transform_shape(self):
+ # Create a dummy result dict
+ results = {
+ 'img_shape': (10, 10),
+ 'img': np.random.rand(10, 10, 3),
+ 'gt_depth_map': np.random.rand(10, 10)
+ }
+ transformed = self.transform.transform(results)
+
+ # Check if the shape remains the same
+ self.assertEqual(results['img'].shape, transformed['img'].shape)
+
+ def test_transform_values(self):
+ # Create a dummy result dict
+ results = {
+ 'img_shape': (10, 10),
+ 'img': np.zeros((10, 10, 3)),
+ 'gt_depth_map': np.ones((10, 10))
+ }
+ transformed = self.transform.transform(results)
+
+ # Assuming the transformation modifies a portion of the image,
+ # it shouldn't remain all zeros
+ self.assertFalse(np.all(transformed['img'] == 0))
+
+ def test_invalid_image_dimension(self):
+ # Create a dummy result dict with invalid image dimension
+ results = {
+ 'img_shape': (10, 10),
+ 'img': np.random.rand(10, 10, 3, 3),
+ 'gt_depth_map': np.random.rand(10, 10)
+ }
+
+ # Check if a ValueError is raised for invalid dimension
+ with self.assertRaises(ValueError):
+ self.transform.transform(results)
diff --git a/tests/test_evaluation/test_metrics/test_depth_metric.py b/tests/test_evaluation/test_metrics/test_depth_metric.py
new file mode 100644
index 00000000000..a172db8fa20
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_depth_metric.py
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import shutil
+from unittest import TestCase
+
+import torch
+from mmengine.structures import PixelData
+
+from mmseg.evaluation import DepthMetric
+from mmseg.structures import SegDataSample
+
+
+class TestDepthMetric(TestCase):
+
+ def _demo_mm_inputs(self,
+ batch_size=2,
+ image_shapes=(3, 64, 64),
+ num_classes=5):
+ """Create a superset of inputs needed to run test or train batches.
+
+        Args:
+            batch_size (int): Batch size. Defaults to 2.
+            image_shapes (List[tuple], Optional): Image shape.
+                Defaults to (3, 64, 64).
+            num_classes (int): Number of different classes.
+                Defaults to 5.
+        """
+ if isinstance(image_shapes, list):
+ assert len(image_shapes) == batch_size
+ else:
+ image_shapes = [image_shapes] * batch_size
+
+ data_samples = []
+ for idx in range(batch_size):
+ image_shape = image_shapes[idx]
+ _, h, w = image_shape
+
+ data_sample = SegDataSample()
+ gt_depth_map = torch.rand((1, h, w)) * 10
+ data_sample.gt_depth_map = PixelData(data=gt_depth_map)
+
+ data_samples.append(data_sample.to_dict())
+
+ return data_samples
+
+ def _demo_mm_model_output(self,
+ data_samples,
+ batch_size=2,
+ image_shapes=(3, 64, 64),
+ num_classes=5):
+
+ _, h, w = image_shapes
+
+ for data_sample in data_samples:
+ data_sample['pred_depth_map'] = dict(data=torch.randn(1, h, w))
+
+ data_sample[
+ 'img_path'] = 'tests/data/pseudo_dataset/imgs/00000_img.jpg'
+ return data_samples
+
+ def test_evaluate(self):
+        """Test using the metric in the same way as Evaluator."""
+
+ data_samples = self._demo_mm_inputs()
+ data_samples = self._demo_mm_model_output(data_samples)
+
+ depth_metric = DepthMetric()
+ depth_metric.process([0] * len(data_samples), data_samples)
+ res = depth_metric.compute_metrics(depth_metric.results)
+ self.assertIsInstance(res, dict)
+
+ # test save depth map file in output_dir
+ depth_metric = DepthMetric(output_dir='tmp')
+ depth_metric.process([0] * len(data_samples), data_samples)
+ assert osp.exists('tmp')
+ assert osp.isfile('tmp/00000_img.png')
+ shutil.rmtree('tmp')
+
+ # test format_only
+ depth_metric = DepthMetric(output_dir='tmp', format_only=True)
+ depth_metric.process([0] * len(data_samples), data_samples)
+ assert depth_metric.results == []
+ assert osp.exists('tmp')
+ assert osp.isfile('tmp/00000_img.png')
+ shutil.rmtree('tmp')
diff --git a/tests/test_models/test_assigners/test_hungarian_assigner.py b/tests/test_models/test_assigners/test_hungarian_assigner.py
new file mode 100644
index 00000000000..2cdb1de839d
--- /dev/null
+++ b/tests/test_models/test_assigners/test_hungarian_assigner.py
@@ -0,0 +1,77 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+from mmengine.structures import InstanceData
+
+from mmseg.models.assigners import HungarianAssigner
+
+
+class TestHungarianAssigner(TestCase):
+
+ def test_init(self):
+ with self.assertRaises(AssertionError):
+ HungarianAssigner([])
+
+ def test_hungarian_match_assigner(self):
+ assigner = HungarianAssigner([
+ dict(type='ClassificationCost', weight=2.0),
+ dict(type='CrossEntropyLossCost', weight=5.0, use_sigmoid=True),
+ dict(type='DiceCost', weight=5.0, pred_act=True, eps=1.0)
+ ])
+ num_classes = 3
+ num_masks = 10
+ num_points = 20
+ gt_instances = InstanceData()
+ gt_instances.labels = torch.randint(0, num_classes, (num_classes, ))
+ gt_instances.masks = torch.randint(0, 2, (num_classes, num_points))
+ pred_instances = InstanceData()
+ pred_instances.scores = torch.rand((num_masks, num_classes))
+ pred_instances.masks = torch.rand((num_masks, num_points))
+
+        matched_query_inds, matched_label_inds = \
+            assigner.assign(pred_instances, gt_instances)
+        unique_query_inds = torch.unique(matched_query_inds)
+        unique_label_inds = torch.unique(matched_label_inds)
+        self.assertTrue(len(unique_query_inds) == len(matched_query_inds))
+ self.assertTrue(
+ torch.equal(unique_label_inds, torch.arange(0, num_classes)))
+
+ def test_cls_match_cost(self):
+ num_classes = 3
+ num_masks = 10
+ gt_instances = InstanceData()
+ gt_instances.labels = torch.randint(0, num_classes, (num_classes, ))
+ pred_instances = InstanceData()
+ pred_instances.scores = torch.rand((num_masks, num_classes))
+
+ # test ClassificationCost
+ assigner = HungarianAssigner(dict(type='ClassificationCost'))
+        matched_query_inds, matched_label_inds = \
+            assigner.assign(pred_instances, gt_instances)
+        unique_query_inds = torch.unique(matched_query_inds)
+        unique_label_inds = torch.unique(matched_label_inds)
+        self.assertTrue(len(unique_query_inds) == len(matched_query_inds))
+ self.assertTrue(
+ torch.equal(unique_label_inds, torch.arange(0, num_classes)))
+
+ def test_mask_match_cost(self):
+ num_classes = 3
+ num_masks = 10
+ num_points = 20
+ gt_instances = InstanceData()
+ gt_instances.masks = torch.randint(0, 2, (num_classes, num_points))
+ pred_instances = InstanceData()
+ pred_instances.masks = torch.rand((num_masks, num_points))
+
+ # test DiceCost
+ assigner = HungarianAssigner(
+ dict(type='DiceCost', pred_act=True, eps=1.0))
+ assign_result = assigner.assign(pred_instances, gt_instances)
+ self.assertTrue(len(assign_result[0]) == len(assign_result[1]))
+
+ # test CrossEntropyLossCost
+ assigner = HungarianAssigner(
+ dict(type='CrossEntropyLossCost', use_sigmoid=True))
+ assign_result = assigner.assign(pred_instances, gt_instances)
+ self.assertTrue(len(assign_result[0]) == len(assign_result[1]))
diff --git a/tests/test_models/test_backbones/test_clip_text_encoder.py b/tests/test_models/test_backbones/test_clip_text_encoder.py
new file mode 100644
index 00000000000..ea06c5b5b3f
--- /dev/null
+++ b/tests/test_models/test_backbones/test_clip_text_encoder.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmengine import Config
+from mmengine.registry import init_default_scope
+
+from mmseg.models.text_encoder import CLIPTextEncoder
+from mmseg.utils import get_classes
+
+
+def test_clip_text_encoder():
+ init_default_scope('mmseg')
+ # test vocabulary
+ output_dims = 8
+ embed_dims = 32
+ vocabulary = ['cat', 'dog', 'bird', 'car', 'bike']
+ cfg = dict(
+ vocabulary=vocabulary,
+ templates=['a photo of a {}.'],
+ embed_dims=embed_dims,
+ output_dims=output_dims)
+ cfg = Config(cfg)
+
+ text_encoder = CLIPTextEncoder(**cfg)
+ if torch.cuda.is_available():
+ text_encoder = text_encoder.cuda()
+
+ with torch.no_grad():
+ class_embeds = text_encoder()
+ assert class_embeds.shape == (len(vocabulary) + 1, output_dims)
+
+ # test dataset name
+ cfg = dict(
+ dataset_name='vaihingen',
+ templates=['a photo of a {}.'],
+ embed_dims=embed_dims,
+ output_dims=output_dims)
+ cfg = Config(cfg)
+
+ text_encoder = CLIPTextEncoder(**cfg)
+ with torch.no_grad():
+ class_embeds = text_encoder()
+ class_nums = len(get_classes('vaihingen'))
+ assert class_embeds.shape == (class_nums + 1, output_dims)
diff --git a/tests/test_models/test_backbones/test_vpd.py b/tests/test_models/test_backbones/test_vpd.py
new file mode 100644
index 00000000000..a268159155e
--- /dev/null
+++ b/tests/test_models/test_backbones/test_vpd.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from os.path import dirname, join
+from unittest import TestCase
+
+import torch
+from mmengine import Config
+
+import mmseg
+from mmseg.models.backbones import VPD
+
+
+class TestVPD(TestCase):
+
+ def setUp(self) -> None:
+
+ repo_dpath = dirname(dirname(mmseg.__file__))
+ config_dpath = join(repo_dpath, 'configs/_base_/models/vpd_sd.py')
+ vpd_cfg = Config.fromfile(config_dpath).stable_diffusion_cfg
+ vpd_cfg.pop('checkpoint')
+
+ self.vpd_model = VPD(
+ diffusion_cfg=vpd_cfg,
+ class_embed_path='https://download.openmmlab.com/mmsegmentation/'
+ 'v0.5/vpd/nyu_class_embeddings.pth',
+ class_embed_select=True,
+ pad_shape=64,
+ unet_cfg=dict(use_attn=False),
+ )
+
+ def test_forward(self):
+ # test forward without class_id
+ x = torch.randn(1, 3, 60, 60)
+ with torch.no_grad():
+ out = self.vpd_model(x)
+
+ self.assertEqual(len(out), 4)
+ self.assertListEqual(list(out[0].shape), [1, 320, 8, 8])
+ self.assertListEqual(list(out[1].shape), [1, 640, 4, 4])
+ self.assertListEqual(list(out[2].shape), [1, 1280, 2, 2])
+ self.assertListEqual(list(out[3].shape), [1, 1280, 1, 1])
+
+ # test forward with class_id
+ x = torch.randn(1, 3, 60, 60)
+ with torch.no_grad():
+ out = self.vpd_model((x, torch.tensor([2])))
+
+ self.assertEqual(len(out), 4)
+ self.assertListEqual(list(out[0].shape), [1, 320, 8, 8])
+ self.assertListEqual(list(out[1].shape), [1, 640, 4, 4])
+ self.assertListEqual(list(out[2].shape), [1, 1280, 2, 2])
+ self.assertListEqual(list(out[3].shape), [1, 1280, 1, 1])
diff --git a/tests/test_models/test_heads/test_san_head.py b/tests/test_models/test_heads/test_san_head.py
new file mode 100644
index 00000000000..af85a6e2ca0
--- /dev/null
+++ b/tests/test_models/test_heads/test_san_head.py
@@ -0,0 +1,126 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmengine import Config
+from mmengine.structures import PixelData
+
+from mmseg.models.decode_heads import SideAdapterCLIPHead
+from mmseg.structures import SegDataSample
+from .utils import list_to_cuda
+
+
+def test_san_head():
+ H, W = (64, 64)
+ clip_channels = 64
+ img_channels = 4
+ num_queries = 40
+ out_dims = 64
+ num_classes = 19
+ cfg = dict(
+ num_classes=num_classes,
+ deep_supervision_idxs=[4],
+ san_cfg=dict(
+ in_channels=img_channels,
+ embed_dims=128,
+ clip_channels=clip_channels,
+ num_queries=num_queries,
+ cfg_encoder=dict(num_encode_layer=4, mlp_ratio=2, num_heads=2),
+ cfg_decoder=dict(
+ num_heads=4,
+ num_layers=1,
+ embed_channels=32,
+ mlp_channels=32,
+ num_mlp=2,
+ rescale=True)),
+ maskgen_cfg=dict(
+ sos_token_num=num_queries,
+ embed_dims=clip_channels,
+ out_dims=out_dims,
+ num_heads=4,
+ mlp_ratio=2),
+ train_cfg=dict(
+ num_points=100,
+ oversample_ratio=3.0,
+ importance_sample_ratio=0.75,
+ assigner=dict(
+ type='HungarianAssigner',
+ match_costs=[
+ dict(type='ClassificationCost', weight=2.0),
+ dict(
+ type='CrossEntropyLossCost',
+ weight=5.0,
+ use_sigmoid=True),
+ dict(type='DiceCost', weight=5.0, pred_act=True, eps=1.0)
+ ])),
+ loss_decode=[
+ dict(
+ type='CrossEntropyLoss',
+ loss_name='loss_cls_ce',
+ loss_weight=2.0,
+ class_weight=[1.0] * num_classes + [0.1]),
+ dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_name='loss_mask_ce',
+ loss_weight=5.0),
+ dict(
+ type='DiceLoss',
+ ignore_index=None,
+ naive_dice=True,
+ eps=1,
+ loss_name='loss_mask_dice',
+ loss_weight=5.0)
+ ])
+
+ cfg = Config(cfg)
+ head = SideAdapterCLIPHead(**cfg)
+
+ inputs = torch.rand((2, img_channels, H, W))
+ clip_feature = [[
+ torch.rand((2, clip_channels, H // 2, W // 2)),
+ torch.rand((2, clip_channels))
+ ],
+ [
+ torch.rand((2, clip_channels, H // 2, W // 2)),
+ torch.rand((2, clip_channels))
+ ],
+ [
+ torch.rand((2, clip_channels, H // 2, W // 2)),
+ torch.rand((2, clip_channels))
+ ],
+ [
+ torch.rand((2, clip_channels, H // 2, W // 2)),
+ torch.rand((2, clip_channels))
+ ]]
+ class_embed = torch.rand((num_classes + 1, out_dims))
+
+ data_samples = []
+ for i in range(2):
+ data_sample = SegDataSample()
+ img_meta = {}
+ img_meta['img_shape'] = (H, W)
+ img_meta['ori_shape'] = (H, W)
+ data_sample.gt_sem_seg = PixelData(
+ data=torch.randint(0, num_classes, (1, H, W)))
+ data_sample.set_metainfo(img_meta)
+ data_samples.append(data_sample)
+
+ batch_img_metas = []
+ for data_sample in data_samples:
+ batch_img_metas.append(data_sample.metainfo)
+
+ if torch.cuda.is_available():
+ head = head.cuda()
+ data = list_to_cuda([inputs, clip_feature, class_embed])
+ for data_sample in data_samples:
+ data_sample.gt_sem_seg.data = data_sample.gt_sem_seg.data.cuda()
+ else:
+ data = [inputs, clip_feature, class_embed]
+
+ # loss test
+ loss_dict = head.loss(data, data_samples, None)
+ assert isinstance(loss_dict, dict)
+
+ # prediction test
+ with torch.no_grad():
+ seg_logits = head.predict(data, batch_img_metas, None)
+ assert seg_logits.shape == torch.Size((2, num_classes, H, W))
diff --git a/tests/test_models/test_heads/test_vpd_depth_head.py b/tests/test_models/test_heads/test_vpd_depth_head.py
new file mode 100644
index 00000000000..e3a4f7558ec
--- /dev/null
+++ b/tests/test_models/test_heads/test_vpd_depth_head.py
@@ -0,0 +1,50 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+from mmengine.structures import PixelData
+
+from mmseg.models.decode_heads import VPDDepthHead
+from mmseg.structures import SegDataSample
+
+
+class TestVPDDepthHead(TestCase):
+
+ def setUp(self):
+ """Set up common resources."""
+ self.in_channels = [320, 640, 1280, 1280]
+ self.max_depth = 10.0
+        self.loss_decode = dict(
+            type='SiLogLoss'
+        )  # scale-invariant log loss used for depth estimation
+ self.vpd_depth_head = VPDDepthHead(
+ max_depth=self.max_depth,
+ in_channels=self.in_channels,
+ loss_decode=self.loss_decode)
+
+ def test_forward(self):
+ """Test the forward method."""
+        # Create mock multi-scale features matching the configured in_channels
+ x = [
+ torch.randn(1, 320, 32, 32),
+ torch.randn(1, 640, 16, 16),
+ torch.randn(1, 1280, 8, 8),
+ torch.randn(1, 1280, 4, 4)
+ ]
+
+ output = self.vpd_depth_head.forward(x)
+        # the head decodes the features to a fixed 256x256 depth map
+
+ self.assertEqual(output.shape, (1, 1, 256, 256))
+
+ def test_loss_by_feat(self):
+ """Test the loss_by_feat method."""
+ # Create mock data for `pred_depth_map` and `batch_data_samples`.
+ pred_depth_map = torch.randn(1, 1, 32, 32)
+ gt_depth_map = PixelData(data=torch.rand(1, 32, 32))
+ batch_data_samples = [SegDataSample(gt_depth_map=gt_depth_map)]
+
+ loss = self.vpd_depth_head.loss_by_feat(pred_depth_map,
+ batch_data_samples)
+
+ self.assertIsNotNone(loss)
diff --git a/tests/test_models/test_heads/utils.py b/tests/test_models/test_heads/utils.py
index 335e261a5e5..72823401552 100644
--- a/tests/test_models/test_heads/utils.py
+++ b/tests/test_models/test_heads/utils.py
@@ -20,3 +20,12 @@ def to_cuda(module, data):
for i in range(len(data)):
data[i] = data[i].cuda()
return module, data
+
+
+def list_to_cuda(data):
+ if isinstance(data, list):
+ for i in range(len(data)):
+ data[i] = list_to_cuda(data[i])
+ return data
+ else:
+ return data.cuda()
diff --git a/tests/test_models/test_losses/test_dice_loss.py b/tests/test_models/test_losses/test_dice_loss.py
new file mode 100644
index 00000000000..34253dae12e
--- /dev/null
+++ b/tests/test_models/test_losses/test_dice_loss.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+
+from mmseg.models.losses import DiceLoss
+
+
+@pytest.mark.parametrize('naive_dice', [True, False])
+def test_dice_loss(naive_dice):
+ loss_class = DiceLoss
+ pred = torch.rand((1, 10, 4, 4))
+ target = torch.randint(0, 10, (1, 4, 4))
+ weight = torch.rand(1)
+ # Test loss forward
+ loss = loss_class(naive_dice=naive_dice)(pred, target)
+ assert isinstance(loss, torch.Tensor)
+
+ # Test loss forward with weight
+ loss = loss_class(naive_dice=naive_dice)(pred, target, weight)
+ assert isinstance(loss, torch.Tensor)
+
+ # Test loss forward with reduction_override
+ loss = loss_class(naive_dice=naive_dice)(
+ pred, target, reduction_override='mean')
+ assert isinstance(loss, torch.Tensor)
+
+ # Test loss forward with avg_factor
+ loss = loss_class(naive_dice=naive_dice)(pred, target, avg_factor=10)
+ assert isinstance(loss, torch.Tensor)
+
+ with pytest.raises(ValueError):
+ # loss can evaluate with avg_factor only if
+ # reduction is None, 'none' or 'mean'.
+ reduction_override = 'sum'
+ loss_class(naive_dice=naive_dice)(
+ pred, target, avg_factor=10, reduction_override=reduction_override)
+
+ # Test loss forward with avg_factor and reduction
+ for reduction_override in [None, 'none', 'mean']:
+ loss_class(naive_dice=naive_dice)(
+ pred, target, avg_factor=10, reduction_override=reduction_override)
+ assert isinstance(loss, torch.Tensor)
+
+    # Test loss forward with activate=True for both use_sigmoid settings
+ for use_sigmoid in [True, False]:
+ loss_class(
+ use_sigmoid=use_sigmoid, activate=True,
+ naive_dice=naive_dice)(pred, target)
+ assert isinstance(loss, torch.Tensor)
+
+ # Test loss forward with weight.ndim != loss.ndim
+ with pytest.raises(AssertionError):
+ weight = torch.rand((2, 8))
+ loss_class(naive_dice=naive_dice)(pred, target, weight)
+
+ # Test loss forward with len(weight) != len(pred)
+ with pytest.raises(AssertionError):
+ weight = torch.rand(8)
+ loss_class(naive_dice=naive_dice)(pred, target, weight)
+
+ # Test _expand_onehot_labels_dice
+ pred = torch.tensor([[[[1, 1], [1, 0]], [[0, 1], [1, 1]]]]).float()
+ target = torch.tensor([[[0, 0], [0, 1]]])
+ target_onehot = torch.tensor([[[[1, 1], [1, 0]], [[0, 0], [0, 1]]]])
+ weight = torch.rand(1)
+ loss = loss_class(naive_dice=naive_dice)(pred, target, weight)
+ loss_onehot = loss_class(naive_dice=naive_dice)(pred, target_onehot,
+ weight)
+ assert torch.equal(loss, loss_onehot)
+
+    # Test whether loss is 0 when pred == target, eps == 0 and naive_dice=False
+    target = torch.randint(0, 2, (1, 10, 4, 4))
+    pred = target.float()
+    target = pred.sigmoid()
+ weight = torch.rand(1)
+ loss = loss_class(
+ naive_dice=False, use_sigmoid=True, eps=0)(pred, target, weight)
+ assert loss.item() == 0
+
+ # Test ignore_index when ignore_index is the only class
+ with pytest.raises(AssertionError):
+ pred = torch.ones((1, 1, 4, 4))
+ target = torch.randint(0, 1, (1, 4, 4))
+ weight = torch.rand(1)
+ loss = loss_class(
+ naive_dice=naive_dice, use_sigmoid=False, ignore_index=0,
+ eps=0)(pred, target, weight)
+
+ # Test ignore_index with naive_dice = False
+ pred = torch.tensor([[[[1, 1], [1, 0]], [[0, 1], [1, 1]]]]).float()
+    target = torch.tensor([[[[1, 1], [1, 0]], [[1, 0], [0, 1]]]]).float().sigmoid()
+ weight = torch.rand(1)
+ loss = loss_class(
+ naive_dice=False, use_sigmoid=True, ignore_index=1,
+ eps=0)(pred, target, weight)
+ assert loss.item() == 0
diff --git a/tests/test_models/test_losses/test_huasdorff_distance_loss.py b/tests/test_models/test_losses/test_huasdorff_distance_loss.py
new file mode 100644
index 00000000000..29c2732d3f1
--- /dev/null
+++ b/tests/test_models/test_losses/test_huasdorff_distance_loss.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+
+from mmseg.models.losses import HuasdorffDisstanceLoss
+
+
+def test_huasdorff_distance_loss():
+ loss_class = HuasdorffDisstanceLoss
+ pred = torch.rand((10, 8, 6, 6))
+ target = torch.rand((10, 6, 6))
+ class_weight = torch.rand(8)
+
+ # Test loss forward
+ loss = loss_class()(pred, target)
+ assert isinstance(loss, torch.Tensor)
+
+ # Test loss forward with avg_factor
+ loss = loss_class()(pred, target, avg_factor=10)
+ assert isinstance(loss, torch.Tensor)
+
+ # Test loss forward with avg_factor when reduction is None, 'sum' or 'mean'
+ for reduction in [None, 'sum', 'mean']:
+ loss = loss_class()(pred, target, avg_factor=10, reduction=reduction)
+ assert isinstance(loss, torch.Tensor)
+
+ # Test loss forward with class_weight
+ with pytest.raises(AssertionError):
+ loss_class(class_weight=class_weight)(pred, target)
diff --git a/tests/test_models/test_losses/test_kldiv_loss.py b/tests/test_models/test_losses/test_kldiv_loss.py
new file mode 100644
index 00000000000..48bcc4bfd9f
--- /dev/null
+++ b/tests/test_models/test_losses/test_kldiv_loss.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmseg.models.losses.kldiv_loss import KLDivLoss
+
+
+def test_kldiv_loss_with_none_reduction():
+ loss_class = KLDivLoss
+ pred = torch.rand((8, 5, 5))
+ target = torch.rand((8, 5, 5))
+ reduction = 'none'
+
+ # Test loss forward
+ loss = loss_class(reduction=reduction)(pred, target)
+ assert isinstance(loss, torch.Tensor)
+ assert loss.shape == (8, 5, 5), f'{loss.shape}'
+
+
+def test_kldiv_loss_with_mean_reduction():
+ loss_class = KLDivLoss
+ pred = torch.rand((8, 5, 5))
+ target = torch.rand((8, 5, 5))
+ reduction = 'mean'
+
+ # Test loss forward
+ loss = loss_class(reduction=reduction)(pred, target)
+ assert isinstance(loss, torch.Tensor)
+ assert loss.shape == (8, ), f'{loss.shape}'
+
+
+def test_kldiv_loss_with_sum_reduction():
+ loss_class = KLDivLoss
+ pred = torch.rand((8, 5, 5))
+ target = torch.rand((8, 5, 5))
+ reduction = 'sum'
+
+ # Test loss forward
+ loss = loss_class(reduction=reduction)(pred, target)
+ assert isinstance(loss, torch.Tensor)
+ assert loss.shape == (8, ), f'{loss.shape}'
diff --git a/tests/test_models/test_losses/test_silog_loss.py b/tests/test_models/test_losses/test_silog_loss.py
new file mode 100644
index 00000000000..022434bcc14
--- /dev/null
+++ b/tests/test_models/test_losses/test_silog_loss.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+
+from mmseg.models.losses import SiLogLoss
+
+
+class TestSiLogLoss(TestCase):
+
+ def test_SiLogLoss_forward(self):
+ pred = torch.tensor([[1.0, 2.0], [3.5, 4.0]], dtype=torch.float32)
+ target = torch.tensor([[0.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
+ weight = torch.tensor([1.0, 0.5], dtype=torch.float32)
+
+ loss_module = SiLogLoss()
+ loss = loss_module.forward(pred, target, weight)
+
+ expected_loss = 0.02
+ self.assertAlmostEqual(loss.item(), expected_loss, places=2)
diff --git a/tests/test_models/test_segmentors/test_depth_estimator.py b/tests/test_models/test_segmentors/test_depth_estimator.py
new file mode 100644
index 00000000000..e819c9e7633
--- /dev/null
+++ b/tests/test_models/test_segmentors/test_depth_estimator.py
@@ -0,0 +1,64 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+from os.path import dirname, join
+from unittest import TestCase
+
+import torch
+from mmengine import Config, ConfigDict
+from mmengine.structures import PixelData
+
+import mmseg
+from mmseg.models.segmentors import DepthEstimator
+from mmseg.structures import SegDataSample
+
+
+class TestDepthEstimator(TestCase):
+
+ def setUp(self) -> None:
+ repo_dpath = dirname(dirname(mmseg.__file__))
+ config_dpath = join(repo_dpath, 'configs/_base_/models/vpd_sd.py')
+ vpd_cfg = Config.fromfile(config_dpath).stable_diffusion_cfg
+ vpd_cfg.pop('checkpoint')
+
+ backbone_cfg = dict(
+ type='VPD',
+ diffusion_cfg=vpd_cfg,
+ class_embed_path='https://download.openmmlab.com/mmsegmentation/'
+ 'v0.5/vpd/nyu_class_embeddings.pth',
+ class_embed_select=True,
+ pad_shape=64,
+ unet_cfg=dict(use_attn=False),
+ )
+
+ head_cfg = dict(
+ type='VPDDepthHead',
+ max_depth=10,
+ )
+
+ self.model = DepthEstimator(
+ backbone=backbone_cfg, decode_head=head_cfg)
+
+ inputs = torch.randn(1, 3, 64, 80)
+ data_sample = SegDataSample()
+ data_sample.gt_depth_map = PixelData(data=torch.rand(1, 64, 80))
+ data_sample.set_metainfo(dict(img_shape=(64, 80), ori_shape=(64, 80)))
+ self.data = dict(inputs=inputs, data_samples=[data_sample])
+
+ def test_slide_flip_inference(self):
+
+ self.model.test_cfg = ConfigDict(
+ dict(mode='slide_flip', crop_size=(64, 64), stride=(16, 16)))
+
+ with torch.no_grad():
+ out = self.model.predict(**deepcopy(self.data))
+
+ self.assertEqual(len(out), 1)
+ self.assertIn('pred_depth_map', out[0].keys())
+ self.assertListEqual(list(out[0].pred_depth_map.shape), [64, 80])
+
+ def test__forward(self):
+ data = deepcopy(self.data)
+ data['inputs'] = data['inputs'][:, :, :64, :64]
+ with torch.no_grad():
+ out = self.model._forward(**data)
+ self.assertListEqual(list(out.shape), [1, 1, 64, 64])
diff --git a/tests/test_models/test_segmentors/test_multimodal_encoder_decoder.py b/tests/test_models/test_segmentors/test_multimodal_encoder_decoder.py
new file mode 100644
index 00000000000..75258d89a7d
--- /dev/null
+++ b/tests/test_models/test_segmentors/test_multimodal_encoder_decoder.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine import ConfigDict
+
+from mmseg.models import build_segmentor
+from tests.test_models.test_segmentors.utils import \
+ _segmentor_forward_train_test
+
+
+def test_multimodal_encoder_decoder():
+
+ cfg = ConfigDict(
+ type='MultimodalEncoderDecoder',
+ asymetric_input=False,
+ image_encoder=dict(type='ExampleBackbone', out_indices=[1, 2, 3, 4]),
+ text_encoder=dict(
+ type='ExampleTextEncoder',
+ vocabulary=['A', 'B', 'C'],
+ output_dims=3),
+ decode_head=dict(
+ type='ExampleDecodeHead', out_channels=1, num_classes=2),
+ train_cfg=None,
+ test_cfg=dict(mode='whole'))
+ segmentor = build_segmentor(cfg)
+ _segmentor_forward_train_test(segmentor)
diff --git a/tests/test_models/test_segmentors/test_seg_tta_model.py b/tests/test_models/test_segmentors/test_seg_tta_model.py
index 3c9699e8df4..1e152ed0565 100644
--- a/tests/test_models/test_segmentors/test_seg_tta_model.py
+++ b/tests/test_models/test_segmentors/test_seg_tta_model.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import tempfile
+
import torch
from mmengine import ConfigDict
from mmengine.model import BaseTTAModel
@@ -37,7 +39,8 @@ def test_encoder_decoder_tta():
ori_shape=(10, 10),
img_shape=(10 + i, 10 + i),
flip=(i % 2 == 0),
- flip_direction=flip_direction),
+ flip_direction=flip_direction,
+ img_path=tempfile.mktemp()),
gt_sem_seg=PixelData(data=torch.randint(0, 19, (1, 10, 10))))
])
diff --git a/tests/test_models/test_segmentors/utils.py b/tests/test_models/test_segmentors/utils.py
index 6b440df906d..ac31e2b2774 100644
--- a/tests/test_models/test_segmentors/utils.py
+++ b/tests/test_models/test_segmentors/utils.py
@@ -52,15 +52,22 @@ def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
@MODELS.register_module()
class ExampleBackbone(nn.Module):
- def __init__(self):
+ def __init__(self, out_indices=None):
super().__init__()
self.conv = nn.Conv2d(3, 3, 3)
+ self.out_indices = out_indices
def init_weights(self, pretrained=None):
pass
def forward(self, x):
- return [self.conv(x)]
+ if self.out_indices is None:
+ return [self.conv(x)]
+ else:
+ outs = []
+ for i in self.out_indices:
+ outs.append(self.conv(x))
+ return outs
@MODELS.register_module()
@@ -74,6 +81,18 @@ def forward(self, inputs):
return self.cls_seg(inputs[0])
+@MODELS.register_module()
+class ExampleTextEncoder(nn.Module):
+
+ def __init__(self, vocabulary=None, output_dims=None):
+ super().__init__()
+ self.vocabulary = vocabulary
+ self.output_dims = output_dims
+
+ def forward(self):
+ return torch.randn((len(self.vocabulary), self.output_dims))
+
+
@MODELS.register_module()
class ExampleCascadeDecodeHead(BaseCascadeDecodeHead):
@@ -132,3 +151,32 @@ def _segmentor_forward_train_test(segmentor):
data_batch = dict(inputs=imgs, data_samples=data_samples)
results = segmentor.forward(imgs, data_samples, mode='tensor')
assert isinstance(results, torch.Tensor)
+
+
+def _segmentor_predict(segmentor):
+ if isinstance(segmentor.decode_head, nn.ModuleList):
+ num_classes = segmentor.decode_head[-1].num_classes
+ else:
+ num_classes = segmentor.decode_head.num_classes
+ # batch_size=2 for BatchNorm
+ mm_inputs = _demo_mm_inputs(num_classes=num_classes)
+
+ # convert to cuda Tensor if applicable
+ if torch.cuda.is_available():
+ segmentor = segmentor.cuda()
+
+ # check data preprocessor
+ if not hasattr(segmentor,
+ 'data_preprocessor') or segmentor.data_preprocessor is None:
+ segmentor.data_preprocessor = SegDataPreProcessor()
+
+ mm_inputs = segmentor.data_preprocessor(mm_inputs, True)
+ imgs = mm_inputs.pop('imgs')
+ data_samples = mm_inputs.pop('data_samples')
+
+ # Test predict
+ with torch.no_grad():
+ segmentor.eval()
+ data_batch = dict(inputs=imgs, data_samples=data_samples)
+ outputs = segmentor.predict(**data_batch)
+ assert isinstance(outputs, list)
diff --git a/tests/test_visualization/test_local_visualizer.py b/tests/test_visualization/test_local_visualizer.py
index b60a9b87507..e3b2a88cfb7 100644
--- a/tests/test_visualization/test_local_visualizer.py
+++ b/tests/test_visualization/test_local_visualizer.py
@@ -155,3 +155,59 @@ def _assert_image_and_shape(self, out_file, out_shape):
assert os.path.exists(out_file)
drawn_img = cv2.imread(out_file)
assert drawn_img.shape == out_shape
+
+ def test_add_datasample_depth(self):
+ h = 10
+ w = 12
+ out_file = 'out_file'
+
+ image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')
+
+ # test gt_depth_map
+ gt_depth_map = PixelData(data=torch.rand(1, h, w))
+
+ def test_add_datasample_forward_depth(gt_depth_map):
+ data_sample = SegDataSample()
+ data_sample.gt_depth_map = gt_depth_map
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ seg_local_visualizer = SegLocalVisualizer(
+ vis_backends=[dict(type='LocalVisBackend')],
+ save_dir=tmp_dir)
+ seg_local_visualizer.dataset_meta = dict(
+ classes=('background', 'foreground'),
+ palette=[[120, 120, 120], [6, 230, 230]])
+
+ # test out_file
+ seg_local_visualizer.add_datasample(out_file, image,
+ data_sample)
+
+ assert os.path.exists(
+ osp.join(tmp_dir, 'vis_data', 'vis_image',
+ out_file + '_0.png'))
+ drawn_img = cv2.imread(
+ osp.join(tmp_dir, 'vis_data', 'vis_image',
+ out_file + '_0.png'))
+ assert drawn_img.shape == (h * 2, w, 3)
+
+ # test pred_depth_map
+
+ pred_depth_map = PixelData(data=torch.rand(1, h, w))
+
+ data_sample.pred_depth_map = pred_depth_map
+
+ seg_local_visualizer.add_datasample(out_file, image,
+ data_sample)
+ self._assert_image_and_shape(
+ osp.join(tmp_dir, 'vis_data', 'vis_image',
+ out_file + '_0.png'), (h * 2, w * 2, 3))
+
+ seg_local_visualizer.add_datasample(
+ out_file, image, data_sample, draw_gt=False)
+ self._assert_image_and_shape(
+ osp.join(tmp_dir, 'vis_data', 'vis_image',
+ out_file + '_0.png'), (h * 2, w, 3))
+
+ if torch.cuda.is_available():
+ test_add_datasample_forward_depth(gt_depth_map.cuda())
+ test_add_datasample_forward_depth(gt_depth_map)
diff --git a/tools/analysis_tools/confusion_matrix.py b/tools/analysis_tools/confusion_matrix.py
index 9a87bc14c9b..39756cdfdd2 100644
--- a/tools/analysis_tools/confusion_matrix.py
+++ b/tools/analysis_tools/confusion_matrix.py
@@ -5,10 +5,14 @@
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
-from mmengine import Config, DictAction
-from mmengine.utils import ProgressBar, load
+from mmengine.config import Config, DictAction
+from mmengine.registry import init_default_scope
+from mmengine.utils import mkdir_or_exist, progressbar
+from PIL import Image
-from mmseg.datasets import build_dataset
+from mmseg.registry import DATASETS
+
+init_default_scope('mmseg')
def parse_args():
@@ -16,7 +20,7 @@ def parse_args():
description='Generate confusion matrix from segmentation results')
parser.add_argument('config', help='test config file path')
parser.add_argument(
- 'prediction_path', help='prediction path where test .pkl result')
+ 'prediction_path', help='folder holding the test prediction results')
parser.add_argument(
'save_dir', help='directory where confusion matrix will be saved')
parser.add_argument(
@@ -50,15 +54,23 @@ def calculate_confusion_matrix(dataset, results):
dataset (Dataset): Test or val dataset.
results (list[ndarray]): A list of segmentation results in each image.
"""
- n = len(dataset.CLASSES)
+ n = len(dataset.METAINFO['classes'])
confusion_matrix = np.zeros(shape=[n, n])
assert len(dataset) == len(results)
- prog_bar = ProgressBar(len(results))
+ ignore_index = dataset.ignore_index
+ reduce_zero_label = dataset.reduce_zero_label
+ prog_bar = progressbar.ProgressBar(len(results))
for idx, per_img_res in enumerate(results):
res_segm = per_img_res
- gt_segm = dataset.get_gt_seg_map_by_idx(idx)
+ gt_segm = dataset[idx]['data_samples'] \
+ .gt_sem_seg.data.squeeze().numpy().astype(np.uint8)
+ gt_segm, res_segm = gt_segm.flatten(), res_segm.flatten()
+ if reduce_zero_label:
+ gt_segm = gt_segm - 1
+ to_ignore = gt_segm == ignore_index
+
+ gt_segm, res_segm = gt_segm[~to_ignore], res_segm[~to_ignore]
inds = n * gt_segm + res_segm
- inds = inds.flatten()
mat = np.bincount(inds, minlength=n**2).reshape(n, n)
confusion_matrix += mat
prog_bar.update()
@@ -70,7 +82,7 @@ def plot_confusion_matrix(confusion_matrix,
save_dir=None,
show=True,
title='Normalized Confusion Matrix',
- color_theme='winter'):
+ color_theme='OrRd'):
"""Draw confusion matrix with matplotlib.
Args:
@@ -89,14 +101,15 @@ def plot_confusion_matrix(confusion_matrix,
num_classes = len(labels)
fig, ax = plt.subplots(
- figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=180)
+ figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=300)
cmap = plt.get_cmap(color_theme)
im = ax.imshow(confusion_matrix, cmap=cmap)
- plt.colorbar(mappable=im, ax=ax)
+ colorbar = plt.colorbar(mappable=im, ax=ax)
+ colorbar.ax.tick_params(labelsize=20)  # set colorbar tick label font size
- title_font = {'weight': 'bold', 'size': 12}
+ title_font = {'weight': 'bold', 'size': 20}
ax.set_title(title, fontdict=title_font)
- label_font = {'size': 10}
+ label_font = {'size': 40}
plt.ylabel('Ground Truth Label', fontdict=label_font)
plt.xlabel('Prediction Label', fontdict=label_font)
@@ -116,8 +129,8 @@ def plot_confusion_matrix(confusion_matrix,
# draw label
ax.set_xticks(np.arange(num_classes))
ax.set_yticks(np.arange(num_classes))
- ax.set_xticklabels(labels)
- ax.set_yticklabels(labels)
+ ax.set_xticklabels(labels, fontsize=20)
+ ax.set_yticklabels(labels, fontsize=20)
ax.tick_params(
axis='x', bottom=False, top=True, labelbottom=False, labeltop=True)
@@ -135,13 +148,14 @@ def plot_confusion_matrix(confusion_matrix,
) if not np.isnan(confusion_matrix[i, j]) else -1),
ha='center',
va='center',
- color='w',
- size=7)
+ color='k',
+ size=20)
ax.set_ylim(len(confusion_matrix) - 0.5, -0.5) # matplotlib>3.1.1
fig.tight_layout()
if save_dir is not None:
+ mkdir_or_exist(save_dir)
plt.savefig(
os.path.join(save_dir, 'confusion_matrix.png'), format='png')
if show:
@@ -155,7 +169,12 @@ def main():
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
- results = load(args.prediction_path)
+ results = []
+ for img in sorted(os.listdir(args.prediction_path)):
+ img = os.path.join(args.prediction_path, img)
+ image = Image.open(img)
+ image = np.copy(image)
+ results.append(image)
assert isinstance(results, list)
if isinstance(results[0], np.ndarray):
@@ -163,17 +182,11 @@ def main():
else:
raise TypeError('invalid type of prediction results')
- if isinstance(cfg.data.test, dict):
- cfg.data.test.test_mode = True
- elif isinstance(cfg.data.test, list):
- for ds_cfg in cfg.data.test:
- ds_cfg.test_mode = True
-
- dataset = build_dataset(cfg.data.test)
+ dataset = DATASETS.build(cfg.test_dataloader.dataset)
confusion_matrix = calculate_confusion_matrix(dataset, results)
plot_confusion_matrix(
confusion_matrix,
- dataset.CLASSES,
+ dataset.METAINFO['classes'],
save_dir=args.save_dir,
show=args.show,
title=args.title,
diff --git a/tools/analysis_tools/visualization_cam.py b/tools/analysis_tools/visualization_cam.py
new file mode 100644
index 00000000000..00cdb3e04ab
--- /dev/null
+++ b/tools/analysis_tools/visualization_cam.py
@@ -0,0 +1,127 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM).
+
+requirement: pip install grad-cam
+"""
+
+from argparse import ArgumentParser
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from mmengine import Config
+from mmengine.model import revert_sync_batchnorm
+from PIL import Image
+from pytorch_grad_cam import GradCAM
+from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
+
+from mmseg.apis import inference_model, init_model, show_result_pyplot
+from mmseg.utils import register_all_modules
+
+
+class SemanticSegmentationTarget:
+ """Wrap the model output as a scalar Grad-CAM target.
+
+ requirement: pip install grad-cam
+
+ Args:
+ category (int): Visualization class.
+ mask (ndarray): Mask of class.
+ size (tuple): Image size.
+ """
+
+ def __init__(self, category, mask, size):
+ self.category = category
+ self.mask = torch.from_numpy(mask)
+ self.size = size
+ if torch.cuda.is_available():
+ self.mask = self.mask.cuda()
+
+ def __call__(self, model_output):
+ model_output = torch.unsqueeze(model_output, dim=0)
+ model_output = F.interpolate(
+ model_output, size=self.size, mode='bilinear')
+ model_output = torch.squeeze(model_output, dim=0)
+
+ return (model_output[self.category, :, :] * self.mask).sum()
+
+
+def main():
+ parser = ArgumentParser()
+ parser.add_argument('img', help='Image file')
+ parser.add_argument('config', help='Config file')
+ parser.add_argument('checkpoint', help='Checkpoint file')
+ parser.add_argument(
+ '--out-file',
+ default='prediction.png',
+ help='Path to output prediction file')
+ parser.add_argument(
+ '--cam-file', default='vis_cam.png', help='Path to output cam file')
+ parser.add_argument(
+ '--target-layers',
+ default='backbone.layer4[2]',
+ help='Target layers to visualize CAM')
+ parser.add_argument(
+ '--category-index', default='7', help='Category to visualize CAM')
+ parser.add_argument(
+ '--device', default='cuda:0', help='Device used for inference')
+ args = parser.parse_args()
+
+ # build the model from a config file and a checkpoint file
+ register_all_modules()
+ model = init_model(args.config, args.checkpoint, device=args.device)
+ if args.device == 'cpu':
+ model = revert_sync_batchnorm(model)
+
+ # test a single image
+ result = inference_model(model, args.img)
+
+ # show the results
+ show_result_pyplot(
+ model,
+ args.img,
+ result,
+ draw_gt=False,
+ show=False if args.out_file is not None else True,
+ out_file=args.out_file)
+
+ # result data conversion
+ prediction_data = result.pred_sem_seg.data
+ pre_np_data = prediction_data.cpu().numpy().squeeze(0)
+
+ target_layers = args.target_layers
+ target_layers = [eval(f'model.{target_layers}')]
+
+ category = int(args.category_index)
+ mask_float = np.float32(pre_np_data == category)
+
+ # data processing
+ image = np.array(Image.open(args.img).convert('RGB'))
+ height, width = image.shape[0], image.shape[1]
+ rgb_img = np.float32(image) / 255
+ config = Config.fromfile(args.config)
+ image_mean = config.data_preprocessor['mean']
+ image_std = config.data_preprocessor['std']
+ input_tensor = preprocess_image(
+ rgb_img,
+ mean=[x / 255 for x in image_mean],
+ std=[x / 255 for x in image_std])
+
+ # Grad CAM(Class Activation Maps)
+ # Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
+ targets = [
+ SemanticSegmentationTarget(category, mask_float, (height, width))
+ ]
+ with GradCAM(
+ model=model,
+ target_layers=target_layers,
+ use_cuda=torch.cuda.is_available()) as cam:
+ grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
+ cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
+
+ # save cam file
+ Image.fromarray(cam_image).save(args.cam_file)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/dataset_converters/isaid.py b/tools/dataset_converters/isaid.py
index 1da264d975e..1d5ccd9c776 100644
--- a/tools/dataset_converters/isaid.py
+++ b/tools/dataset_converters/isaid.py
@@ -91,7 +91,7 @@ def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap):
x_end) + '.png'
# print(image)
save_path_image = osp.join(out_dir, 'img_dir', mode, str(image))
- img_patch.save(save_path_image)
+ img_patch.save(save_path_image, format='BMP')
def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap):
diff --git a/tools/dataset_converters/levircd.py b/tools/dataset_converters/levircd.py
new file mode 100644
index 00000000000..8717f3e856b
--- /dev/null
+++ b/tools/dataset_converters/levircd.py
@@ -0,0 +1,99 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import glob
+import math
+import os
+import os.path as osp
+
+import mmcv
+import numpy as np
+from mmengine.utils import ProgressBar
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description='Convert levir-cd dataset to mmsegmentation format')
+ parser.add_argument('--dataset_path', help='LEVIR-CD folder path')
+ parser.add_argument('-o', '--out_dir', help='output path')
+ parser.add_argument(
+ '--clip_size',
+ type=int,
+ help='clipped size of image after preparation',
+ default=256)
+ parser.add_argument(
+ '--stride_size',
+ type=int,
+ help='stride of clipping original images',
+ default=256)
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+ input_folder = args.dataset_path
+ png_files = glob.glob(
+ os.path.join(input_folder, '**/*.png'), recursive=True)
+ output_folder = args.out_dir
+ prog_bar = ProgressBar(len(png_files))
+ for png_file in png_files:
+ new_path = os.path.join(
+ output_folder,
+ os.path.relpath(os.path.dirname(png_file), input_folder))
+ os.makedirs(os.path.dirname(new_path), exist_ok=True)
+ label = False
+ if 'label' in png_file:
+ label = True
+ clip_big_image(png_file, new_path, args, label)
+ prog_bar.update()
+
+
+def clip_big_image(image_path, clip_save_dir, args, to_label=False):
+ image = mmcv.imread(image_path)
+
+ h, w, c = image.shape
+ clip_size = args.clip_size
+ stride_size = args.stride_size
+
+ num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
+ (h - clip_size) /
+ stride_size) * stride_size + clip_size >= h else math.ceil(
+ (h - clip_size) / stride_size) + 1
+ num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
+ (w - clip_size) /
+ stride_size) * stride_size + clip_size >= w else math.ceil(
+ (w - clip_size) / stride_size) + 1
+
+ x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
+ xmin = x * clip_size
+ ymin = y * clip_size
+
+ xmin = xmin.ravel()
+ ymin = ymin.ravel()
+ xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
+ np.zeros_like(xmin))
+ ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
+ np.zeros_like(ymin))
+ boxes = np.stack([
+ xmin + xmin_offset, ymin + ymin_offset,
+ np.minimum(xmin + clip_size, w),
+ np.minimum(ymin + clip_size, h)
+ ],
+ axis=1)
+
+ if to_label:
+ image[image == 255] = 1
+ image = image[:, :, 0]
+ for box in boxes:
+ start_x, start_y, end_x, end_y = box
+ clipped_image = image[start_y:end_y, start_x:end_x] \
+ if to_label else image[start_y:end_y, start_x:end_x, :]
+ idx = osp.basename(image_path).split('.')[0]
+ mmcv.imwrite(
+ clipped_image.astype(np.uint8),
+ osp.join(clip_save_dir,
+ f'{idx}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/dataset_converters/nyu.py b/tools/dataset_converters/nyu.py
new file mode 100644
index 00000000000..49e09e7af68
--- /dev/null
+++ b/tools/dataset_converters/nyu.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os.path as osp
+import shutil
+import tempfile
+import zipfile
+
+from mmengine.utils import mkdir_or_exist
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description='Convert NYU Depth dataset to mmsegmentation format')
+ parser.add_argument('raw_data', help='the path of raw data')
+ parser.add_argument(
+ '-o', '--out_dir', help='output path', default='./data/nyu')
+ args = parser.parse_args()
+ return args
+
+
+def reorganize(raw_data_dir: str, out_dir: str):
+ """Reorganize NYU Depth dataset files into the required directory
+ structure.
+
+ Args:
+ raw_data_dir (str): Path to the raw data directory.
+ out_dir (str): Output directory for the organized dataset.
+ """
+
+ def move_data(data_list, dst_prefix, fname_func):
+ """Move data files from source to destination directory.
+
+ Args:
+ data_list (list): List of data file paths.
+ dst_prefix (str): Prefix to be added to destination paths.
+ fname_func (callable): Function to process file names
+ """
+ for data_item in data_list:
+ data_item = data_item.strip().strip('/')
+ new_item = fname_func(data_item)
+ shutil.move(
+ osp.join(raw_data_dir, data_item),
+ osp.join(out_dir, dst_prefix, new_item))
+
+ def process_phase(phase):
+ """Process a dataset phase (e.g., 'train' or 'test')."""
+ with open(osp.join(raw_data_dir, f'nyu_{phase}.txt')) as f:
+ data = filter(lambda x: len(x.strip()) > 0, f.readlines())
+ data = map(lambda x: x.split()[:2], data)
+ images, annos = zip(*data)
+
+ move_data(images, f'images/{phase}',
+ lambda x: x.replace('/rgb', ''))
+ move_data(annos, f'annotations/{phase}',
+ lambda x: x.replace('/sync_depth', ''))
+
+ process_phase('train')
+ process_phase('test')
+
+
+def main():
+ args = parse_args()
+
+ print('Making directories...')
+ mkdir_or_exist(args.out_dir)
+ for subdir in [
+ 'images/train', 'images/test', 'annotations/train',
+ 'annotations/test'
+ ]:
+ mkdir_or_exist(osp.join(args.out_dir, subdir))
+
+ print('Generating images and annotations...')
+
+ if args.raw_data.endswith('.zip'):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ zip_file = zipfile.ZipFile(args.raw_data)
+ zip_file.extractall(tmp_dir)
+ reorganize(osp.join(tmp_dir, 'nyu'), args.out_dir)
+ else:
+ assert osp.isdir(
+ args.raw_data
+ ), 'the argument raw_data should be either a zip file or a directory.'
+ reorganize(args.raw_data, args.out_dir)
+
+ print('Done!')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/misc/publish_model.py b/tools/misc/publish_model.py
index c1bbc9ac1ad..e035ad90e85 100644
--- a/tools/misc/publish_model.py
+++ b/tools/misc/publish_model.py
@@ -1,9 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import subprocess
+from hashlib import sha256
import torch
+BLOCK_SIZE = 128 * 1024
+
def parse_args():
parser = argparse.ArgumentParser(
@@ -14,6 +17,17 @@ def parse_args():
return args
+def sha256sum(filename: str) -> str:
+ """Compute SHA256 message digest from a file."""
+ hash_func = sha256()
+ byte_array = bytearray(BLOCK_SIZE)
+ memory_view = memoryview(byte_array)
+ with open(filename, 'rb', buffering=0) as file:
+ for block in iter(lambda: file.readinto(memory_view), 0):
+ hash_func.update(memory_view[:block])
+ return hash_func.hexdigest()
+
+
def process_checkpoint(in_file, out_file):
checkpoint = torch.load(in_file, map_location='cpu')
# remove optimizer for smaller file size
@@ -22,7 +36,7 @@ def process_checkpoint(in_file, out_file):
# if it is necessary to remove some sensitive data in checkpoint['meta'],
# add the code here.
torch.save(checkpoint, out_file)
- sha = subprocess.check_output(['sha256sum', out_file]).decode()
+ sha = sha256sum(out_file)  # hash the published checkpoint, not the input
final_file = out_file.rstrip('.pth') + f'-{sha[:8]}.pth'
subprocess.Popen(['mv', out_file, final_file])
diff --git a/tools/model_converters/clip2mmseg.py b/tools/model_converters/clip2mmseg.py
new file mode 100644
index 00000000000..9a97e4b04ab
--- /dev/null
+++ b/tools/model_converters/clip2mmseg.py
@@ -0,0 +1,163 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os.path as osp
+from collections import OrderedDict
+
+import mmengine
+import torch
+from mmengine.runner import CheckpointLoader
+
+
+def convert_vitlayer(paras):
+ new_para_name = ''
+ if paras[0] == 'ln_1':
+ new_para_name = '.'.join(['ln1'] + paras[1:])
+ elif paras[0] == 'attn':
+ new_para_name = '.'.join(['attn.attn'] + paras[1:])
+ elif paras[0] == 'ln_2':
+ new_para_name = '.'.join(['ln2'] + paras[1:])
+ elif paras[0] == 'mlp':
+ if paras[1] == 'c_fc':
+ new_para_name = '.'.join(['ffn.layers.0.0'] + paras[-1:])
+ else:
+ new_para_name = '.'.join(['ffn.layers.1'] + paras[-1:])
+ else:
+ print(f'Wrong for {paras}')
+ return new_para_name
+
+
+def convert_translayer(paras):
+ new_para_name = ''
+ if paras[0] == 'attn':
+ new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
+ elif paras[0] == 'ln_1':
+ new_para_name = '.'.join(['norms.0'] + paras[1:])
+ elif paras[0] == 'ln_2':
+ new_para_name = '.'.join(['norms.1'] + paras[1:])
+ elif paras[0] == 'mlp':
+ if paras[1] == 'c_fc':
+ new_para_name = '.'.join(['ffns.0.layers.0.0'] + paras[2:])
+ elif paras[1] == 'c_proj':
+ new_para_name = '.'.join(['ffns.0.layers.1'] + paras[2:])
+ else:
+ print(f'Wrong for {paras}')
+ else:
+ print(f'Wrong for {paras}')
+ return new_para_name
+
+
+def convert_key_name(ckpt, visual_split):
+ new_ckpt = OrderedDict()
+ for k, v in ckpt.items():
+ key_list = k.split('.')
+ if key_list[0] == 'visual':
+ new_transform_name = 'image_encoder'
+ if key_list[1] == 'class_embedding':
+ new_name = '.'.join([new_transform_name, 'cls_token'])
+ elif key_list[1] == 'positional_embedding':
+ new_name = '.'.join([new_transform_name, 'pos_embed'])
+ elif key_list[1] == 'conv1':
+ new_name = '.'.join([
+ new_transform_name, 'patch_embed.projection', key_list[2]
+ ])
+ elif key_list[1] == 'ln_pre':
+ new_name = '.'.join(
+ [new_transform_name, key_list[1], key_list[2]])
+ elif key_list[1] == 'transformer':
+ new_layer_name = 'layers'
+ layer_index = key_list[3]
+ paras = key_list[4:]
+ if int(layer_index) < visual_split:
+ new_para_name = convert_vitlayer(paras)
+ new_name = '.'.join([
+ new_transform_name, new_layer_name, layer_index,
+ new_para_name
+ ])
+ else:
+ new_para_name = convert_translayer(paras)
+ new_transform_name = 'decode_head.rec_with_attnbias'
+ new_layer_name = 'layers'
+ layer_index = str(int(layer_index) - visual_split)
+ new_name = '.'.join([
+ new_transform_name, new_layer_name, layer_index,
+ new_para_name
+ ])
+ elif key_list[1] == 'proj':
+ new_name = 'decode_head.rec_with_attnbias.proj.weight'
+ elif key_list[1] == 'ln_post':
+ new_name = k.replace('visual', 'decode_head.rec_with_attnbias')
+ else:
+ print(f'pop parameter: {k}')
+ continue
+ else:
+ text_encoder_name = 'text_encoder'
+ if key_list[0] == 'transformer':
+ layer_name = 'transformer'
+ layer_index = key_list[2]
+ paras = key_list[3:]
+ new_para_name = convert_translayer(paras)
+ new_name = '.'.join([
+ text_encoder_name, layer_name, layer_index, new_para_name
+ ])
+ elif key_list[0] in [
+ 'positional_embedding', 'text_projection', 'bg_embed',
+ 'attn_mask', 'logit_scale', 'token_embedding', 'ln_final'
+ ]:
+ new_name = 'text_encoder.' + k
+ else:
+ print(f'pop parameter: {k}')
+ continue
+ new_ckpt[new_name] = v
+
+ return new_ckpt
+
+
+def convert_tensor(ckpt):
+ cls_token = ckpt['image_encoder.cls_token']
+ new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
+ ckpt['image_encoder.cls_token'] = new_cls_token
+ pos_embed = ckpt['image_encoder.pos_embed']
+ new_pos_embed = pos_embed.unsqueeze(0)
+ ckpt['image_encoder.pos_embed'] = new_pos_embed
+ proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
+ new_proj_weight = proj_weight.transpose(1, 0)
+ ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
+ return ckpt
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Convert keys in CLIP pretrained ViT models to '
+ 'MMSegmentation style.')
+ parser.add_argument('src', help='src model path or url')
+ # The dst path must be a full path of the new checkpoint.
+ parser.add_argument('dst', help='save path')
+ args = parser.parse_args()
+
+ if any([s in args.src for s in ['B-16', 'b16', 'base_patch16']]):
+ visual_split = 9
+ elif any([s in args.src for s in ['L-14', 'l14', 'large_patch14']]):
+ visual_split = 18
+ else:
+ print('Make sure the clip model is ViT-B/16 or ViT-L/14!')
+ visual_split = -1
+ checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+ if isinstance(checkpoint, torch.jit.RecursiveScriptModule):
+ state_dict = checkpoint.state_dict()
+ else:
+ if 'state_dict' in checkpoint:
+ # timm checkpoint
+ state_dict = checkpoint['state_dict']
+ elif 'model' in checkpoint:
+ # deit checkpoint
+ state_dict = checkpoint['model']
+ else:
+ state_dict = checkpoint
+ weight = convert_key_name(state_dict, visual_split)
+ weight = convert_tensor(weight)
+ mmengine.mkdir_or_exist(osp.dirname(args.dst))
+ torch.save(weight, args.dst)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/model_converters/san2mmseg.py b/tools/model_converters/san2mmseg.py
new file mode 100644
index 00000000000..301a46608e0
--- /dev/null
+++ b/tools/model_converters/san2mmseg.py
@@ -0,0 +1,220 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os.path as osp
+from collections import OrderedDict
+
+import mmengine
+import torch
+from mmengine.runner import CheckpointLoader
+
+
+def convert_key_name(ckpt):
+ new_ckpt = OrderedDict()
+
+ for k, v in ckpt.items():
+ key_list = k.split('.')
+ if key_list[0] == 'clip_visual_extractor':
+ new_transform_name = 'image_encoder'
+ if key_list[1] == 'class_embedding':
+ new_name = '.'.join([new_transform_name, 'cls_token'])
+ elif key_list[1] == 'positional_embedding':
+ new_name = '.'.join([new_transform_name, 'pos_embed'])
+ elif key_list[1] == 'conv1':
+ new_name = '.'.join([
+ new_transform_name, 'patch_embed.projection', key_list[2]
+ ])
+ elif key_list[1] == 'ln_pre':
+ new_name = '.'.join(
+ [new_transform_name, key_list[1], key_list[2]])
+ elif key_list[1] == 'resblocks':
+ new_layer_name = 'layers'
+ layer_index = key_list[2]
+ paras = key_list[3:]
+ if paras[0] == 'ln_1':
+ new_para_name = '.'.join(['ln1'] + key_list[4:])
+ elif paras[0] == 'attn':
+ new_para_name = '.'.join(['attn.attn'] + key_list[4:])
+ elif paras[0] == 'ln_2':
+ new_para_name = '.'.join(['ln2'] + key_list[4:])
+ elif paras[0] == 'mlp':
+ if paras[1] == 'c_fc':
+ new_para_name = '.'.join(['ffn.layers.0.0'] +
+ key_list[-1:])
+ else:
+ new_para_name = '.'.join(['ffn.layers.1'] +
+ key_list[-1:])
+ new_name = '.'.join([
+ new_transform_name, new_layer_name, layer_index,
+ new_para_name
+ ])
+ elif key_list[0] == 'side_adapter_network':
+ decode_head_name = 'decode_head'
+ module_name = 'side_adapter_network'
+ if key_list[1] == 'vit_model':
+ if key_list[2] == 'blocks':
+ layer_name = 'encode_layers'
+ layer_index = key_list[3]
+ paras = key_list[4:]
+ if paras[0] == 'norm1':
+ new_para_name = '.'.join(['ln1'] + key_list[5:])
+ elif paras[0] == 'attn':
+ new_para_name = '.'.join(key_list[4:])
+ new_para_name = new_para_name.replace(
+ 'attn.qkv.', 'attn.attn.in_proj_')
+ new_para_name = new_para_name.replace(
+ 'attn.proj', 'attn.attn.out_proj')
+ elif paras[0] == 'norm2':
+ new_para_name = '.'.join(['ln2'] + key_list[5:])
+ elif paras[0] == 'mlp':
+ new_para_name = '.'.join(['ffn'] + key_list[5:])
+ new_para_name = new_para_name.replace(
+ 'fc1', 'layers.0.0')
+ new_para_name = new_para_name.replace(
+ 'fc2', 'layers.1')
+ else:
+ print(f'Wrong for {k}')
+                        continue
+ new_name = '.'.join([
+ decode_head_name, module_name, layer_name, layer_index,
+ new_para_name
+ ])
+ elif key_list[2] == 'pos_embed':
+ new_name = '.'.join(
+ [decode_head_name, module_name, 'pos_embed'])
+ elif key_list[2] == 'patch_embed':
+ new_name = '.'.join([
+ decode_head_name, module_name, 'patch_embed',
+ 'projection', key_list[4]
+ ])
+ else:
+ print(f'Wrong for {k}')
+                    continue
+            elif key_list[1] in ('query_embed', 'query_pos_embed'):
+ new_name = '.'.join(
+ [decode_head_name, module_name, key_list[1]])
+ elif key_list[1] == 'fusion_layers':
+ layer_name = 'conv_clips'
+ layer_index = key_list[2][-1]
+ paras = '.'.join(key_list[3:])
+ new_para_name = paras.replace('input_proj.0', '0')
+ new_para_name = new_para_name.replace('input_proj.1', '1.conv')
+ new_name = '.'.join([
+ decode_head_name, module_name, layer_name, layer_index,
+ new_para_name
+ ])
+ elif key_list[1] == 'mask_decoder':
+ new_name = 'decode_head.' + k
+ else:
+ print(f'Wrong for {k}')
+                continue
+ elif key_list[0] == 'clip_rec_head':
+            decode_head_name = 'decode_head'
+            module_name = 'rec_with_attnbias'
+ if key_list[1] == 'proj':
+ new_name = '.'.join(
+ [decode_head_name, module_name, 'proj.weight'])
+ elif key_list[1] == 'ln_post':
+ new_name = '.'.join(
+ [decode_head_name, module_name, 'ln_post', key_list[2]])
+ elif key_list[1] == 'resblocks':
+ new_layer_name = 'layers'
+ layer_index = key_list[2]
+ paras = key_list[3:]
+ if paras[0] == 'ln_1':
+ new_para_name = '.'.join(['norms.0'] + paras[1:])
+ elif paras[0] == 'attn':
+ new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
+ elif paras[0] == 'ln_2':
+ new_para_name = '.'.join(['norms.1'] + paras[1:])
+ elif paras[0] == 'mlp':
+ if paras[1] == 'c_fc':
+ new_para_name = '.'.join(['ffns.0.layers.0.0'] +
+ paras[2:])
+ elif paras[1] == 'c_proj':
+ new_para_name = '.'.join(['ffns.0.layers.1'] +
+ paras[2:])
+ else:
+ print(f'Wrong for {k}')
+                    continue
+ new_name = '.'.join([
+ decode_head_name, module_name, new_layer_name, layer_index,
+ new_para_name
+ ])
+ else:
+ print(f'Wrong for {k}')
+                continue
+ elif key_list[0] == 'ov_classifier':
+ text_encoder_name = 'text_encoder'
+ if key_list[1] == 'transformer':
+ layer_name = 'transformer'
+ layer_index = key_list[3]
+ paras = key_list[4:]
+ if paras[0] == 'attn':
+ new_para_name = '.'.join(['attentions.0.attn'] + paras[1:])
+ elif paras[0] == 'ln_1':
+ new_para_name = '.'.join(['norms.0'] + paras[1:])
+ elif paras[0] == 'ln_2':
+ new_para_name = '.'.join(['norms.1'] + paras[1:])
+ elif paras[0] == 'mlp':
+ if paras[1] == 'c_fc':
+ new_para_name = '.'.join(['ffns.0.layers.0.0'] +
+ paras[2:])
+ elif paras[1] == 'c_proj':
+ new_para_name = '.'.join(['ffns.0.layers.1'] +
+ paras[2:])
+ else:
+ print(f'Wrong for {k}')
+                    continue
+ else:
+ print(f'Wrong for {k}')
+                continue
+ new_name = '.'.join([
+ text_encoder_name, layer_name, layer_index, new_para_name
+ ])
+ elif key_list[1] in [
+ 'positional_embedding', 'text_projection', 'bg_embed',
+ 'attn_mask', 'logit_scale', 'token_embedding', 'ln_final'
+ ]:
+ new_name = k.replace('ov_classifier', 'text_encoder')
+ else:
+ print(f'Wrong for {k}')
+                continue
+ elif key_list[0] == 'criterion':
+ new_name = k
+ else:
+ print(f'Wrong for {k}')
+            continue
+ new_ckpt[new_name] = v
+ return new_ckpt
+
+
+def convert_tensor(ckpt):
+ cls_token = ckpt['image_encoder.cls_token']
+ new_cls_token = cls_token.unsqueeze(0).unsqueeze(0)
+ ckpt['image_encoder.cls_token'] = new_cls_token
+ pos_embed = ckpt['image_encoder.pos_embed']
+ new_pos_embed = pos_embed.unsqueeze(0)
+ ckpt['image_encoder.pos_embed'] = new_pos_embed
+ proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight']
+ new_proj_weight = proj_weight.transpose(1, 0)
+ ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight
+ return ckpt
+
+
+def main():
+ parser = argparse.ArgumentParser(
+        description='Convert keys in SAN pretrained models to '
+ 'MMSegmentation style.')
+ parser.add_argument('src', help='src model path or url')
+ # The dst path must be a full path of the new checkpoint.
+ parser.add_argument('dst', help='save path')
+ args = parser.parse_args()
+
+ checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+ if 'state_dict' in checkpoint:
+ # timm checkpoint
+ state_dict = checkpoint['state_dict']
+ elif 'model' in checkpoint:
+ # deit checkpoint
+ state_dict = checkpoint['model']
+ else:
+ state_dict = checkpoint
+ weight = convert_key_name(state_dict)
+ weight = convert_tensor(weight)
+ mmengine.mkdir_or_exist(osp.dirname(args.dst))
+ torch.save(weight, args.dst)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/test.py b/tools/test.py
index 64da2bc2616..0d7f39b3a8b 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -47,7 +47,10 @@ def parse_args():
help='job launcher')
parser.add_argument(
'--tta', action='store_true', help='Test time augmentation')
- parser.add_argument('--local_rank', type=int, default=0)
+ # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
+    # will pass the `--local-rank` parameter to `tools/test.py` instead
+ # of `--local_rank`.
+ parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
@@ -65,8 +68,8 @@ def trigger_visualization_hook(cfg, args):
visualization_hook['show'] = True
visualization_hook['wait_time'] = args.wait_time
if args.show_dir:
- visulizer = cfg.visualizer
- visulizer['save_dir'] = args.show_dir
+ visualizer = cfg.visualizer
+ visualizer['save_dir'] = args.show_dir
else:
raise RuntimeError(
'VisualizationHook must be included in default_hooks.'
diff --git a/tools/train.py b/tools/train.py
index 17213066648..10fdaa1874b 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -40,7 +40,10 @@ def parse_args():
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
- parser.add_argument('--local_rank', type=int, default=0)
+ # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
+ # will pass the `--local-rank` parameter to `tools/train.py` instead
+ # of `--local_rank`.
+ parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)