Skip to content

Commit

Permalink
[Fix] bugfix/avoid-runner-iter-in-vis-hook-test-mode (#3596)
Browse files Browse the repository at this point in the history
## Motivation

The current `SegVisualizationHook` implements the `_after_iter` method,
which is invoked during the validation and testing pipelines. However,
when in
[test_mode](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/engine/hooks/visualization_hook.py#L97),
the implementation attempts to access `runner.iter`. This attribute is
defined in the [`mmengine`
codebase](https://github.com/open-mmlab/mmengine/blob/main/mmengine/runner/runner.py#L538)
and is designed to return `train_loop.iter`. Accessing this property
during testing can be problematic, particularly in scenarios where the
model is being evaluated post-training, without initiating a training
loop. This can lead to a crash if the implementation tries to build a
training dataset for which the annotation file is unavailable at the
time of evaluation. Thus, it is crucial to avoid relying on this
property in test mode.

## Modification

To resolve this issue, the proposal is to replace the `_after_iter`
method with `after_val_iter` and `after_test_iter` methods, modifying
their behavior accordingly. Specifically, when in testing mode, the
implementation should utilize a `test_index` counter instead of
accessing `runner.iter`. This adjustment will circumvent the issue of
accessing `train_loop.iter` during test mode, ensuring the process does
not attempt to access or build a training dataset, thereby preventing
potential crashes due to missing annotation files.
  • Loading branch information
mmeendez8 committed Mar 22, 2024
1 parent b677081 commit b040e14
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 26 deletions.
82 changes: 57 additions & 25 deletions mmseg/engine/hooks/visualization_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Optional, Sequence

import mmcv
import mmengine.fileio as fileio
from mmengine.fileio import get
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmengine.visualization import Visualizer
Expand Down Expand Up @@ -61,37 +61,69 @@ def __init__(self,
'hook for visualization will not take '
'effect. The results will NOT be '
'visualized or stored.')
self._test_index = 0

def _after_iter(self,
runner: Runner,
batch_idx: int,
data_batch: dict,
outputs: Sequence[SegDataSample],
mode: str = 'val') -> None:
def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[SegDataSample]) -> None:
"""Run after every ``self.interval`` validation iterations.
Args:
runner (:obj:`Runner`): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (dict): Data from dataloader.
outputs (Sequence[:obj:`SegDataSample`]): Outputs from model.
mode (str): mode (str): Current mode of runner. Defaults to 'val'.
outputs (Sequence[:obj:`SegDataSample`]]): A batch of data samples
that contain annotations and predictions.
"""
if self.draw is False or mode == 'train':
if self.draw is False:
return

if self.every_n_inner_iters(batch_idx, self.interval):
for output in outputs:
img_path = output.img_path
img_bytes = fileio.get(
img_path, backend_args=self.backend_args)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
window_name = f'{mode}_{osp.basename(img_path)}'

self._visualizer.add_datasample(
window_name,
img,
data_sample=output,
show=self.show,
wait_time=self.wait_time,
step=runner.iter)
# There is no guarantee that the same batch of images
# is visualized for each evaluation.
total_curr_iter = runner.iter + batch_idx

# Visualize only the first data
img_path = outputs[0].img_path
img_bytes = get(img_path, backend_args=self.backend_args)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
window_name = f'val_{osp.basename(img_path)}'

if total_curr_iter % self.interval == 0:
self._visualizer.add_datasample(
window_name,
img,
data_sample=outputs[0],
show=self.show,
wait_time=self.wait_time,
step=total_curr_iter)

def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
outputs: Sequence[SegDataSample]) -> None:
"""Run after every testing iterations.
Args:
runner (:obj:`Runner`): The runner of the testing process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (dict): Data from dataloader.
outputs (Sequence[:obj:`SegDataSample`]): A batch of data samples
that contain annotations and predictions.
"""
if self.draw is False:
return

for data_sample in outputs:
self._test_index += 1

img_path = data_sample.img_path
window_name = f'test_{osp.basename(img_path)}'

img_path = data_sample.img_path
img_bytes = get(img_path, backend_args=self.backend_args)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')

self._visualizer.add_datasample(
window_name,
img,
data_sample=data_sample,
show=self.show,
wait_time=self.wait_time,
step=self._test_index)
3 changes: 2 additions & 1 deletion tests/test_engine/test_visualization_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def test_after_val_iter(self):

def test_after_test_iter(self):
runner = Mock()
runner.iter = 3
hook = SegVisualizationHook(draw=True, interval=1)
assert hook._test_index == 0
hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
assert hook._test_index == len(self.outputs)

0 comments on commit b040e14

Please sign in to comment.