Skip to content

Commit

Permalink
[Fix] Fix inferencer (#3333)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiexinch committed Sep 18, 2023
1 parent 913fe3e commit f1fa61a
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions mmseg/apis/mmseg_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ class MMSegInferencer(BaseInferencer):

preprocess_kwargs: set = set()
forward_kwargs: set = {'mode', 'out_dir'}
visualize_kwargs: set = {'show', 'wait_time', 'img_out_dir', 'opacity'}
visualize_kwargs: set = {
'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis'
}
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

def __init__(self,
Expand Down Expand Up @@ -137,6 +139,7 @@ def __call__(self,
inputs: InputsType,
return_datasamples: bool = False,
batch_size: int = 1,
return_vis: bool = False,
show: bool = False,
wait_time: int = 0,
out_dir: str = '',
Expand Down Expand Up @@ -188,11 +191,13 @@ def __call__(self,
wait_time=wait_time,
img_out_dir=img_out_dir,
pred_out_dir=pred_out_dir,
return_vis=return_vis,
**kwargs)

def visualize(self,
inputs: list,
preds: List[dict],
return_vis: bool = False,
show: bool = False,
wait_time: int = 0,
img_out_dir: str = '',
Expand All @@ -213,12 +218,12 @@ def visualize(self,
Returns:
List[np.ndarray]: Visualization results.
"""
if self.visualizer is None or (not show and img_out_dir == ''):
if not show and img_out_dir == '' and not return_vis:
return None

if getattr(self, 'visualizer') is None:
if self.visualizer is None:
raise ValueError('Visualization needs the "visualizer" term'
'defined in the config, but got None')
'defined in the config, but got None.')

self.visualizer.set_dataset_meta(**self.model.dataset_meta)
self.visualizer.alpha = opacity

Expand Down Expand Up @@ -250,10 +255,11 @@ def visualize(self,
draw_gt=False,
draw_pred=True,
out_file=out_file)
results.append(self.visualizer.get_image())
if return_vis:
results.append(self.visualizer.get_image())
self.num_visualized_imgs += 1

return results
return results if return_vis else None

def postprocess(self,
preds: PredType,
Expand Down

0 comments on commit f1fa61a

Please sign in to comment.