diff --git a/mmseg/engine/hooks/visualization_hook.py b/mmseg/engine/hooks/visualization_hook.py index ea238c6969..21cddde89d 100644 --- a/mmseg/engine/hooks/visualization_hook.py +++ b/mmseg/engine/hooks/visualization_hook.py @@ -4,7 +4,7 @@ from typing import Optional, Sequence import mmcv -import mmengine.fileio as fileio +from mmengine.fileio import get from mmengine.hooks import Hook from mmengine.runner import Runner from mmengine.visualization import Visualizer @@ -61,37 +61,69 @@ def __init__(self, 'hook for visualization will not take ' 'effect. The results will NOT be ' 'visualized or stored.') + self._test_index = 0 - def _after_iter(self, - runner: Runner, - batch_idx: int, - data_batch: dict, - outputs: Sequence[SegDataSample], - mode: str = 'val') -> None: + def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[SegDataSample]) -> None: """Run after every ``self.interval`` validation iterations. Args: runner (:obj:`Runner`): The runner of the validation process. batch_idx (int): The index of the current batch in the val loop. data_batch (dict): Data from dataloader. - outputs (Sequence[:obj:`SegDataSample`]): Outputs from model. - mode (str): mode (str): Current mode of runner. Defaults to 'val'. + outputs (Sequence[:obj:`SegDataSample`]]): A batch of data samples + that contain annotations and predictions. """ - if self.draw is False or mode == 'train': + if self.draw is False: return - if self.every_n_inner_iters(batch_idx, self.interval): - for output in outputs: - img_path = output.img_path - img_bytes = fileio.get( - img_path, backend_args=self.backend_args) - img = mmcv.imfrombytes(img_bytes, channel_order='rgb') - window_name = f'{mode}_{osp.basename(img_path)}' - - self._visualizer.add_datasample( - window_name, - img, - data_sample=output, - show=self.show, - wait_time=self.wait_time, - step=runner.iter) + # There is no guarantee that the same batch of images + # is visualized for each evaluation. + total_curr_iter = runner.iter + batch_idx + + # Visualize only the first data + img_path = outputs[0].img_path + img_bytes = get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + window_name = f'val_{osp.basename(img_path)}' + + if total_curr_iter % self.interval == 0: + self._visualizer.add_datasample( + window_name, + img, + data_sample=outputs[0], + show=self.show, + wait_time=self.wait_time, + step=total_curr_iter) + + def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[SegDataSample]) -> None: + """Run after every testing iterations. + + Args: + runner (:obj:`Runner`): The runner of the testing process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`SegDataSample`]): A batch of data samples + that contain annotations and predictions. + """ + if self.draw is False: + return + + for data_sample in outputs: + self._test_index += 1 + + img_path = data_sample.img_path + window_name = f'test_{osp.basename(img_path)}' + + img_path = data_sample.img_path + img_bytes = get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + + self._visualizer.add_datasample( + window_name, + img, + data_sample=data_sample, + show=self.show, + wait_time=self.wait_time, + step=self._test_index) diff --git a/tests/test_engine/test_visualization_hook.py b/tests/test_engine/test_visualization_hook.py index 274b0e547f..022e27c77c 100644 --- a/tests/test_engine/test_visualization_hook.py +++ b/tests/test_engine/test_visualization_hook.py @@ -58,6 +58,7 @@ def test_after_val_iter(self): def test_after_test_iter(self): runner = Mock() - runner.iter = 3 hook = SegVisualizationHook(draw=True, interval=1) + assert hook._test_index == 0 hook.after_test_iter(runner, 1, self.data_batch, self.outputs) + assert hook._test_index == len(self.outputs)