Skip to content

Commit

Permalink
[Fix] Update confusion_matrix.py (#3291)
Browse files Browse the repository at this point in the history
## Motivation



## Modification

The confusion_matrix.py is not compatible with the current version of
mmseg.

---------

Co-authored-by: xiexinch <[email protected]>
  • Loading branch information
XDWang97 and xiexinch committed Aug 31, 2023
1 parent 72e20a8 commit ebd5695
Showing 1 changed file with 39 additions and 26 deletions.
65 changes: 39 additions & 26 deletions tools/analysis_tools/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
from mmengine import Config, DictAction
from mmengine.utils import ProgressBar, load
from mmengine.config import Config, DictAction
from mmengine.registry import init_default_scope
from mmengine.utils import mkdir_or_exist, progressbar
from PIL import Image

from mmseg.datasets import build_dataset
from mmseg.registry import DATASETS

init_default_scope('mmseg')


def parse_args():
parser = argparse.ArgumentParser(
description='Generate confusion matrix from segmentation results')
parser.add_argument('config', help='test config file path')
parser.add_argument(
'prediction_path', help='prediction path where test .pkl result')
'prediction_path', help='prediction path where test folder result')
parser.add_argument(
'save_dir', help='directory where confusion matrix will be saved')
parser.add_argument(
Expand Down Expand Up @@ -50,15 +54,23 @@ def calculate_confusion_matrix(dataset, results):
dataset (Dataset): Test or val dataset.
results (list[ndarray]): A list of segmentation results in each image.
"""
n = len(dataset.CLASSES)
n = len(dataset.METAINFO['classes'])
confusion_matrix = np.zeros(shape=[n, n])
assert len(dataset) == len(results)
prog_bar = ProgressBar(len(results))
ignore_index = dataset.ignore_index
reduce_zero_label = dataset.reduce_zero_label
prog_bar = progressbar.ProgressBar(len(results))
for idx, per_img_res in enumerate(results):
res_segm = per_img_res
gt_segm = dataset.get_gt_seg_map_by_idx(idx)
gt_segm = dataset[idx]['data_samples'] \
.gt_sem_seg.data.squeeze().numpy().astype(np.uint8)
gt_segm, res_segm = gt_segm.flatten(), res_segm.flatten()
if reduce_zero_label:
gt_segm = gt_segm - 1
to_ignore = gt_segm == ignore_index

gt_segm, res_segm = gt_segm[~to_ignore], res_segm[~to_ignore]
inds = n * gt_segm + res_segm
inds = inds.flatten()
mat = np.bincount(inds, minlength=n**2).reshape(n, n)
confusion_matrix += mat
prog_bar.update()
Expand All @@ -70,7 +82,7 @@ def plot_confusion_matrix(confusion_matrix,
save_dir=None,
show=True,
title='Normalized Confusion Matrix',
color_theme='winter'):
color_theme='OrRd'):
"""Draw confusion matrix with matplotlib.
Args:
Expand All @@ -89,14 +101,15 @@ def plot_confusion_matrix(confusion_matrix,

num_classes = len(labels)
fig, ax = plt.subplots(
figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=180)
figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=300)
cmap = plt.get_cmap(color_theme)
im = ax.imshow(confusion_matrix, cmap=cmap)
plt.colorbar(mappable=im, ax=ax)
colorbar = plt.colorbar(mappable=im, ax=ax)
colorbar.ax.tick_params(labelsize=20) # 设置 colorbar 标签的字体大小

title_font = {'weight': 'bold', 'size': 12}
title_font = {'weight': 'bold', 'size': 20}
ax.set_title(title, fontdict=title_font)
label_font = {'size': 10}
label_font = {'size': 40}
plt.ylabel('Ground Truth Label', fontdict=label_font)
plt.xlabel('Prediction Label', fontdict=label_font)

Expand All @@ -116,8 +129,8 @@ def plot_confusion_matrix(confusion_matrix,
# draw label
ax.set_xticks(np.arange(num_classes))
ax.set_yticks(np.arange(num_classes))
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)
ax.set_xticklabels(labels, fontsize=20)
ax.set_yticklabels(labels, fontsize=20)

ax.tick_params(
axis='x', bottom=False, top=True, labelbottom=False, labeltop=True)
Expand All @@ -135,13 +148,14 @@ def plot_confusion_matrix(confusion_matrix,
) if not np.isnan(confusion_matrix[i, j]) else -1),
ha='center',
va='center',
color='w',
size=7)
color='k',
size=20)

ax.set_ylim(len(confusion_matrix) - 0.5, -0.5) # matplotlib>3.1.1

fig.tight_layout()
if save_dir is not None:
mkdir_or_exist(save_dir)
plt.savefig(
os.path.join(save_dir, 'confusion_matrix.png'), format='png')
if show:
Expand All @@ -155,25 +169,24 @@ def main():
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

results = load(args.prediction_path)
results = []
for img in sorted(os.listdir(args.prediction_path)):
img = os.path.join(args.prediction_path, img)
image = Image.open(img)
image = np.copy(image)
results.append(image)

assert isinstance(results, list)
if isinstance(results[0], np.ndarray):
pass
else:
raise TypeError('invalid type of prediction results')

if isinstance(cfg.data.test, dict):
cfg.data.test.test_mode = True
elif isinstance(cfg.data.test, list):
for ds_cfg in cfg.data.test:
ds_cfg.test_mode = True

dataset = build_dataset(cfg.data.test)
dataset = DATASETS.build(cfg.test_dataloader.dataset)
confusion_matrix = calculate_confusion_matrix(dataset, results)
plot_confusion_matrix(
confusion_matrix,
dataset.CLASSES,
dataset.METAINFO['classes'],
save_dir=args.save_dir,
show=args.show,
title=args.title,
Expand Down

0 comments on commit ebd5695

Please sign in to comment.