Skip to content

Commit

Permalink
[Fix] Fix visualizor (#3154)
Browse files Browse the repository at this point in the history
## Motivation

**Current visualize result**


![rs-dev](https://github.com/open-mmlab/mmsegmentation/assets/15952744/147ea3f7-f632-457b-b257-031199320825)

**Fixed the visualization result**



![rs-fix](https://github.com/open-mmlab/mmsegmentation/assets/15952744/98a86025-5a1e-4c2b-83e0-653dd659ba79)


## Modification

remove mmengine `draw_binary_masks` api
  • Loading branch information
xiexinch committed Jul 3, 2023
1 parent cc74c5c commit 8806b4e
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
9 changes: 4 additions & 5 deletions demo/inference_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
"outputs": [],
"source": [
"import torch\n",
"import mmcv\n",
"import matplotlib.pyplot as plt\n",
"from mmengine.model.utils import revert_sync_batchnorm\n",
"from mmseg.apis import init_model, inference_model, show_result_pyplot"
Expand All @@ -48,7 +47,7 @@
"outputs": [],
"source": [
"# build the model from a config file and a checkpoint file\n",
"model = init_model(config_file, checkpoint_file, device='cuda:0')"
"model = init_model(config_file, checkpoint_file, device='cpu')"
]
},
{
Expand All @@ -71,8 +70,8 @@
"outputs": [],
"source": [
"# show the results\n",
"vis_result = show_result_pyplot(model, img, result)\n",
"plt.imshow(mmcv.bgr2rgb(vis_result))"
"vis_result = show_result_pyplot(model, img, result, show=False)\n",
"plt.imshow(vis_result)"
]
},
{
Expand All @@ -99,7 +98,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.10.11"
},
"pycharm": {
"stem_cell": {
Expand Down
2 changes: 1 addition & 1 deletion mmseg/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def show_result_pyplot(model: BaseSegmentor,
if hasattr(model, 'module'):
model = model.module
if isinstance(img, str):
image = mmcv.imread(img)
image = mmcv.imread(img, channel_order='rgb')
else:
image = img
if save_dir is not None:
Expand Down
14 changes: 7 additions & 7 deletions mmseg/visualization/local_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,14 @@ def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,

colors = [palette[label] for label in labels]

self.set_image(image)

# draw semantic masks
mask = np.zeros_like(image, dtype=np.uint8)
for label, color in zip(labels, colors):
self.draw_binary_masks(
sem_seg == label, colors=[color], alphas=self.alpha)
mask[sem_seg[0] == label, :] = color

return self.get_image()
color_seg = (image * (1 - self.alpha) + mask * self.alpha).astype(
np.uint8)
self.set_image(color_seg)
return color_seg

def set_dataset_meta(self,
classes: Optional[List] = None,
Expand Down Expand Up @@ -226,6 +226,6 @@ def add_datasample(
self.show(drawn_img, win_name=name, wait_time=wait_time)

if out_file is not None:
mmcv.imwrite(mmcv.bgr2rgb(drawn_img), out_file)
mmcv.imwrite(mmcv.rgb2bgr(drawn_img), out_file)
else:
self.add_image(name, drawn_img, step)

0 comments on commit 8806b4e

Please sign in to comment.