Skip to content

Commit

Permalink
[Fix] fix import error raised by ldm (#3338)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis committed Sep 20, 2023
1 parent 56a40d7 commit 9c45a94
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
16 changes: 14 additions & 2 deletions mmseg/models/backbones/vpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,19 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ldm.modules.diffusionmodules.util import timestep_embedding
from ldm.util import instantiate_from_config
from mmengine.model import BaseModule
from mmengine.runner import CheckpointLoader, load_checkpoint

from mmseg.registry import MODELS
from mmseg.utils import ConfigType, OptConfigType

try:
from ldm.modules.diffusionmodules.util import timestep_embedding
from ldm.util import instantiate_from_config
has_ldm = True
except ImportError:
has_ldm = False


def register_attention_control(model, controller):
"""Registers a control function to manage attention within a model.
Expand Down Expand Up @@ -205,6 +210,10 @@ def __init__(self,
max_attn_size=None,
attn_selector='up_cross+down_cross'):
super().__init__()

assert has_ldm, 'To use UNetWrapper, please install required ' \
'packages via `pip install -r requirements/optional.txt`.'

self.unet = unet
self.attention_store = AttentionStore(
base_size=base_size // 8, max_size=max_attn_size)
Expand Down Expand Up @@ -321,6 +330,9 @@ def __init__(self,

super().__init__(init_cfg=init_cfg)

assert has_ldm, 'To use VPD model, please install required packages' \
' via `pip install -r requirements/optional.txt`.'

if pad_shape is not None:
if not isinstance(pad_shape, (list, tuple)):
pad_shape = (pad_shape, pad_shape)
Expand Down
19 changes: 6 additions & 13 deletions tools/analysis_tools/visualization_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
import numpy as np
import torch
import torch.nn.functional as F
from mmengine import Config
from mmengine.model import revert_sync_batchnorm
from PIL import Image
from pytorch_grad_cam import GradCAM, LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image

from mmengine import Config
from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import register_all_modules

Expand Down Expand Up @@ -56,21 +56,15 @@ def main():
default='prediction.png',
help='Path to output prediction file')
parser.add_argument(
'--cam-file',
default='vis_cam.png',
help='Path to output cam file')
'--cam-file', default='vis_cam.png', help='Path to output cam file')
parser.add_argument(
'--target-layers',
default='backbone.layer4[2]',
help='Target layers to visualize CAM')
parser.add_argument(
'--category-index',
default='7',
help='Category to visualize CAM')
'--category-index', default='7', help='Category to visualize CAM')
parser.add_argument(
'--device',
default='cuda:0',
help='Device used for inference')
'--device', default='cuda:0', help='Device used for inference')
args = parser.parse_args()

# build the model from a config file and a checkpoint file
Expand Down Expand Up @@ -116,8 +110,7 @@ def main():
# Grad CAM(Class Activation Maps)
# Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
targets = [
SemanticSegmentationTarget(category, mask_float,
(height, width))
SemanticSegmentationTarget(category, mask_float, (height, width))
]
with GradCAM(
model=model,
Expand Down

0 comments on commit 9c45a94

Please sign in to comment.