From b040e147adfa027bbc071b624bedf0ae84dfc922 Mon Sep 17 00:00:00 2001
From: Miguel Méndez
Date: Fri, 22 Mar 2024 11:04:17 +0100
Subject: [PATCH] [Fix] bugfix/avoid-runner-iter-in-vis-hook-test-mode (#3596)

## Motivation

The current `SegVisualizationHook` implements the `_after_iter` method, which
is invoked during both the validation and testing pipelines. However, in
[test_mode](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/engine/hooks/visualization_hook.py#L97)
the implementation accesses `runner.iter`. That property is defined in the
[`mmengine` codebase](https://github.com/open-mmlab/mmengine/blob/main/mmengine/runner/runner.py#L538)
and returns `train_loop.iter`. Accessing it during testing is problematic,
particularly when a model is evaluated after training without ever starting a
training loop: resolving `train_loop` lazily can build the training dataloader
and dataset, and if the training annotation file is unavailable at evaluation
time, the run crashes. The hook should therefore not rely on this property in
test mode.

## Modification

`_after_iter` is replaced with dedicated `after_val_iter` and `after_test_iter`
methods. In test mode the hook now steps its visualizations with an internal
`_test_index` counter instead of `runner.iter`, so testing never touches
`train_loop.iter`, never attempts to build the training dataset, and no longer
crashes when the training annotation files are missing. A config sketch showing
how to exercise the new test-mode path is given below.
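
The test-mode path only does real work when drawing is enabled on the hook, so
here is a minimal config sketch for exercising it. It assumes a config that
inherits the usual `default_hooks` layout from
`configs/_base_/default_runtime.py`; everything other than the `visualization`
entry is left to that base config.

```python
# Minimal sketch (assumes a config inheriting default_runtime.py, where
# default_hooks entries are merged by key).
# draw=True enables visualization in both after_val_iter and after_test_iter.
# interval only throttles the validation path; during testing every sample is
# drawn, stepped by the hook's internal _test_index counter instead of
# runner.iter, so no training loop (and no training dataset) is ever built.
default_hooks = dict(
    visualization=dict(type='SegVisualizationHook', draw=True, interval=1))
```

With such a config, `tools/test.py` can save or show per-sample visualizations
even when the training annotation files are not available on the machine used
for evaluation.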
""" - if self.draw is False or mode == 'train': + if self.draw is False: return - if self.every_n_inner_iters(batch_idx, self.interval): - for output in outputs: - img_path = output.img_path - img_bytes = fileio.get( - img_path, backend_args=self.backend_args) - img = mmcv.imfrombytes(img_bytes, channel_order='rgb') - window_name = f'{mode}_{osp.basename(img_path)}' - - self._visualizer.add_datasample( - window_name, - img, - data_sample=output, - show=self.show, - wait_time=self.wait_time, - step=runner.iter) + # There is no guarantee that the same batch of images + # is visualized for each evaluation. + total_curr_iter = runner.iter + batch_idx + + # Visualize only the first data + img_path = outputs[0].img_path + img_bytes = get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + window_name = f'val_{osp.basename(img_path)}' + + if total_curr_iter % self.interval == 0: + self._visualizer.add_datasample( + window_name, + img, + data_sample=outputs[0], + show=self.show, + wait_time=self.wait_time, + step=total_curr_iter) + + def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[SegDataSample]) -> None: + """Run after every testing iterations. + + Args: + runner (:obj:`Runner`): The runner of the testing process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`SegDataSample`]): A batch of data samples + that contain annotations and predictions. + """ + if self.draw is False: + return + + for data_sample in outputs: + self._test_index += 1 + + img_path = data_sample.img_path + window_name = f'test_{osp.basename(img_path)}' + + img_path = data_sample.img_path + img_bytes = get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + + self._visualizer.add_datasample( + window_name, + img, + data_sample=data_sample, + show=self.show, + wait_time=self.wait_time, + step=self._test_index) diff --git a/tests/test_engine/test_visualization_hook.py b/tests/test_engine/test_visualization_hook.py index 274b0e547f..022e27c77c 100644 --- a/tests/test_engine/test_visualization_hook.py +++ b/tests/test_engine/test_visualization_hook.py @@ -58,6 +58,7 @@ def test_after_val_iter(self): def test_after_test_iter(self): runner = Mock() - runner.iter = 3 hook = SegVisualizationHook(draw=True, interval=1) + assert hook._test_index == 0 hook.after_test_iter(runner, 1, self.data_batch, self.outputs) + assert hook._test_index == len(self.outputs)