diff --git a/demo/inference_demo.ipynb b/demo/inference_demo.ipynb index d26173fc62..455c5df4e1 100644 --- a/demo/inference_demo.ipynb +++ b/demo/inference_demo.ipynb @@ -21,7 +21,6 @@ "outputs": [], "source": [ "import torch\n", - "import mmcv\n", "import matplotlib.pyplot as plt\n", "from mmengine.model.utils import revert_sync_batchnorm\n", "from mmseg.apis import init_model, inference_model, show_result_pyplot" @@ -48,7 +47,7 @@ "outputs": [], "source": [ "# build the model from a config file and a checkpoint file\n", - "model = init_model(config_file, checkpoint_file, device='cuda:0')" + "model = init_model(config_file, checkpoint_file, device='cpu')" ] }, { @@ -71,8 +70,8 @@ "outputs": [], "source": [ "# show the results\n", - "vis_result = show_result_pyplot(model, img, result)\n", - "plt.imshow(mmcv.bgr2rgb(vis_result))" + "vis_result = show_result_pyplot(model, img, result, show=False)\n", + "plt.imshow(vis_result)" ] }, { @@ -99,7 +98,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.10.11" }, "pycharm": { "stem_cell": { diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 4aadffc798..81cd17d798 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -187,7 +187,7 @@ def show_result_pyplot(model: BaseSegmentor, if hasattr(model, 'module'): model = model.module if isinstance(img, str): - image = mmcv.imread(img) + image = mmcv.imread(img, channel_order='rgb') else: image = img if save_dir is not None: diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py index 504004dfcb..0d693e5820 100644 --- a/mmseg/visualization/local_visualizer.py +++ b/mmseg/visualization/local_visualizer.py @@ -108,14 +108,14 @@ def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData, colors = [palette[label] for label in labels] - self.set_image(image) - - # draw semantic masks + mask = np.zeros_like(image, dtype=np.uint8) for label, color in zip(labels, colors): - self.draw_binary_masks( - sem_seg == label, colors=[color], alphas=self.alpha) + mask[sem_seg[0] == label, :] = color - return self.get_image() + color_seg = (image * (1 - self.alpha) + mask * self.alpha).astype( + np.uint8) + self.set_image(color_seg) + return color_seg def set_dataset_meta(self, classes: Optional[List] = None, @@ -226,6 +226,6 @@ def add_datasample( self.show(drawn_img, win_name=name, wait_time=wait_time) if out_file is not None: - mmcv.imwrite(mmcv.bgr2rgb(drawn_img), out_file) + mmcv.imwrite(mmcv.rgb2bgr(drawn_img), out_file) else: self.add_image(name, drawn_img, step)