[feature] Visualizer compatible with MultiTaskDataSample #1702

Open · wants to merge 1 commit into base: dev
mmpretrain/visualization/visualizer.py (93 changes: 65 additions & 28 deletions)
@@ -11,7 +11,7 @@
from mmengine.visualization.utils import img_from_canvas

from mmpretrain.registry import VISUALIZERS
-from mmpretrain.structures import DataSample
+from mmpretrain.structures import DataSample, MultiTaskDataSample
from .utils import create_figure, get_adaptive_scale


@@ -99,6 +99,67 @@ def visualize_cls(self,
        Returns:
            np.ndarray: The visualization image.
        """

+        def _draw_gt(data_sample: DataSample,
+                     classes: Optional[Sequence[str]],
+                     draw_gt: bool,
+                     texts: Sequence[str],
+                     parent_task: str = ''):
+            if isinstance(data_sample, MultiTaskDataSample):
+                for task in data_sample.tasks:
+                    sub_task = f'{parent_task}_{task}' if parent_task else task
+                    _draw_gt(
+                        data_sample.get(task), classes, draw_gt, texts,
+                        sub_task)
+            else:
+                if draw_gt and 'gt_label' in data_sample:
+                    idx = data_sample.gt_label.tolist()
+                    class_labels = [''] * len(idx)
+                    if classes is not None:
+                        class_labels = [f' ({classes[i]})' for i in idx]
+                    labels = [
+                        str(idx[i]) + class_labels[i] for i in range(len(idx))
+                    ]
+                    prefix = f'{parent_task} Ground truth: ' if parent_task \
+                        else 'Ground truth: '
+                    texts.append(prefix +
+                                 ('\n' + ' ' * len(prefix)).join(labels))
+
+        def _draw_pred(data_sample: DataSample,
+                       classes: Optional[Sequence[str]],
+                       draw_pred: bool,
+                       draw_score: bool,
+                       texts: Sequence[str],
+                       parent_task: str = ''):
+            if isinstance(data_sample, MultiTaskDataSample):
+                for task in data_sample.tasks:
+                    sub_task = f'{parent_task}_{task}' if parent_task else task
+                    _draw_pred(
+                        data_sample.get(task), classes, draw_pred, draw_score,
+                        texts, sub_task)
+            else:
+                if draw_pred and 'pred_label' in data_sample:
+                    idx = data_sample.pred_label.tolist()
+                    score_labels = [''] * len(idx)
+                    class_labels = [''] * len(idx)
+                    if draw_score and 'pred_score' in data_sample:
+                        score_labels = [
+                            f', {data_sample.pred_score[i].item():.2f}'
+                            for i in idx
+                        ]
+
+                    if classes is not None:
+                        class_labels = [f' ({classes[i]})' for i in idx]
+
+                    labels = [
+                        str(idx[i]) + score_labels[i] + class_labels[i]
+                        for i in range(len(idx))
+                    ]
+                    prefix = f'{parent_task} Prediction: ' if parent_task \
+                        else 'Prediction: '
+                    texts.append(prefix +
+                                 ('\n' + ' ' * len(prefix)).join(labels))

        if self.dataset_meta is not None:
            classes = classes or self.dataset_meta.get('classes', None)

@@ -114,33 +175,9 @@
        texts = []
        self.set_image(image)

-        if draw_gt and 'gt_label' in data_sample:
-            idx = data_sample.gt_label.tolist()
-            class_labels = [''] * len(idx)
-            if classes is not None:
-                class_labels = [f' ({classes[i]})' for i in idx]
-            labels = [str(idx[i]) + class_labels[i] for i in range(len(idx))]
-            prefix = 'Ground truth: '
-            texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
-
-        if draw_pred and 'pred_label' in data_sample:
-            idx = data_sample.pred_label.tolist()
-            score_labels = [''] * len(idx)
-            class_labels = [''] * len(idx)
-            if draw_score and 'pred_score' in data_sample:
-                score_labels = [
-                    f', {data_sample.pred_score[i].item():.2f}' for i in idx
-                ]
-
-            if classes is not None:
-                class_labels = [f' ({classes[i]})' for i in idx]
-
-            labels = [
-                str(idx[i]) + score_labels[i] + class_labels[i]
-                for i in range(len(idx))
-            ]
-            prefix = 'Prediction: '
-            texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))
+        _draw_gt(data_sample, classes, draw_gt, texts)
+
+        _draw_pred(data_sample, classes, draw_pred, draw_score, texts)

        img_scale = get_adaptive_scale(image.shape[:2])
        text_cfg = {
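For reviewers, here is a minimal usage sketch of the new behavior (adapted from the test added below; the bare `UniversalVisualizer()` construction and the image size are placeholders, not part of this PR). The `_draw_gt`/`_draw_pred` helpers recurse into nested `MultiTaskDataSample`s and join task names with underscores, so the sample here renders text prefixed with `task1`, `task0_task00`, and `task0_task01`:

```python
import numpy as np
import torch

from mmpretrain.structures import DataSample, MultiTaskDataSample
from mmpretrain.visualization import UniversalVisualizer

# A nested sample: 'task1' at the top level, plus a nested
# MultiTaskDataSample under 'task0' holding two sub-tasks.
sample = MultiTaskDataSample()
sample.set_field(
    DataSample().set_gt_label(1).set_pred_label(1).set_pred_score(
        torch.tensor([0.1, 0.8, 0.1])), 'task1')
sample.set_field(MultiTaskDataSample(), 'task0')
for name, gt in {'task00': 2, 'task01': 1}.items():
    sample.task0.set_field(
        DataSample().set_gt_label(gt).set_pred_label(2).set_pred_score(
            torch.tensor([0.1, 0.4, 0.5])), name)

vis = UniversalVisualizer()
image = np.ones((224, 224, 3), np.uint8)
# Ground-truth lines are drawn first, then prediction lines; each is
# prefixed with the flattened task name, e.g. 'task0_task00 Prediction: ...'.
drawn = vis.visualize_cls(image=image, data_sample=sample)
```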
tests/test_visualization/test_visualizer.py (42 changes: 41 additions & 1 deletion)
@@ -7,7 +7,7 @@
import numpy as np
import torch

-from mmpretrain.structures import DataSample
+from mmpretrain.structures import DataSample, MultiTaskDataSample
from mmpretrain.visualization import UniversalVisualizer


@@ -123,6 +123,46 @@ def draw_texts(text, font_sizes, *_, **__):
            data_sample,
            rescale_factor=2.)

+    def test_visualize_multitask_cls(self):
+        image = np.ones((1000, 1000, 3), np.uint8)
+        gt_label = {'task0': {'task00': 2, 'task01': 1}, 'task1': 1}
+        data_sample = MultiTaskDataSample()
+        task_sample = DataSample().set_gt_label(
+            gt_label['task1']).set_pred_label(1).set_pred_score(
+                torch.tensor([0.1, 0.8, 0.1]))
+        data_sample.set_field(task_sample, 'task1')
+        data_sample.set_field(MultiTaskDataSample(), 'task0')
+        for task_name in gt_label['task0']:
+            task_sample = DataSample().set_gt_label(
+                gt_label['task0'][task_name]).set_pred_label(2).set_pred_score(
+                    torch.tensor([0.1, 0.4, 0.5]))
+            data_sample.task0.set_field(task_sample, task_name)
+
+        # Test show
+        def mock_show(drawn_img, win_name, wait_time):
+            self.assertFalse((image == drawn_img).all())
+            self.assertEqual(win_name, 'test_cls')
+            self.assertEqual(wait_time, 0)
+
+        with patch.object(self.vis, 'show', mock_show):
+            self.vis.visualize_cls(
+                image=image,
+                data_sample=data_sample,
+                show=True,
+                name='test_cls',
+                step=2)
+
+        # Test storage backend.
+        save_file = osp.join(self.tmpdir.name,
+                             'vis_data/vis_image/test_cls_2.png')
+        self.assertTrue(osp.exists(save_file))
+
+        # Test out_file
+        out_file = osp.join(self.tmpdir.name, 'results_2.png')
+        self.vis.visualize_cls(
+            image=image, data_sample=data_sample, out_file=out_file)
+        self.assertTrue(osp.exists(out_file))
+
    def test_visualize_image_retrieval(self):
        image = np.ones((10, 10, 3), np.uint8)
        data_sample = DataSample().set_pred_score([0.1, 0.8, 0.1])
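For orientation, the per-task text drawn for the test sample above should look roughly as follows. No `classes` are passed, so only label indices appear; ground-truth lines are appended before prediction lines, and the task ordering shown assumes `MultiTaskDataSample.tasks` iterates in field-insertion order (an assumption, not verified here):

```
task1 Ground truth: 1
task0_task00 Ground truth: 2
task0_task01 Ground truth: 1
task1 Prediction: 1, 0.80
task0_task00 Prediction: 2, 0.50
task0_task01 Prediction: 2, 0.50
```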