You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
I'm using the pretrained hateful memes ViLBERT model. I'm getting the following error when I run my code.
[/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py](https://localhost:8080/#) in linear(input, weight, bias)
1670 if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
1671 return handle_torch_function(linear, tens_ops, input, weight, bias=bias)
-> 1672 if input.dim() == 2 and bias is not None:
1673 # fused op is marginally faster
1674 ret = torch.addmm(bias, input, weight.t())
AttributeError: 'NoneType' object has no attribute 'dim'
Here is my code. I get an error with vilbert, but everything works fine if I replace vilbert with visual_bert.
# Reproduction script: classify a single meme (image + text) with the
# pretrained ViLBERT hateful-memes model and print the predicted label index.
#
# NOTE: in this paste the snippet appears before the imports and the
# `Inference` class it relies on (`registry`, `torch`, `Inference`); it must be
# run after those are defined.

# `model` is not used below: `Inference._build_model` loads the same
# pretrained model itself.
model_cls = registry.get_model_class('vilbert')
model = model_cls.from_pretrained('vilbert.finetuned.hateful_memes.direct')

Vil = Inference('/root/.cache/torch/mmf/data/models/vilbert.finetuned.hateful_memes.direct')

path = ""  # intentionally left out
text = ""  # intentionally left out

# Preprocess the image once, then hand the tensor and its scale to forward().
im, im_scale = Vil._image_transform(path)
output = Vil.forward(path, {'text': text}, im, im_scale)

# Take the highest-scoring class along dim 1 of the scores tensor.
scores = output['scores']
pred_label = torch.argmax(scores, dim=1)
pred_l = pred_label.item()
print(pred_l)
import numpy as np
import requests
import torch
from mmf.common.report import Report
from mmf.common.sample import Sample, SampleList
from mmf.utils.build import build_encoder, build_model, build_processors
from mmf.utils.checkpoint import load_pretrained_model
from mmf.utils.general import get_current_device
from omegaconf import OmegaConf
from PIL import Image
import cv2
from mmf.common.registry import registry # nvishwam
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.layers import nms
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.structures.image_list import to_image_list
from maskrcnn_benchmark.utils.model_serialization import load_state_dict
from PIL import *
class Inference:
    """Classify an image/text pair with a pretrained MMF model.

    The pipeline has two stages:

    1. A maskrcnn_benchmark detection model extracts region features from the
       image (``get_detectron_features``).
    2. The MMF model named by ``MODEL_NAME`` / ``PRETRAINED_KEY`` consumes the
       processed text plus those region features (``forward``).
    """

    # Registry name and model-zoo key of the classifier that is loaded.
    MODEL_NAME = "vilbert"
    PRETRAINED_KEY = "vilbert.finetuned.hateful_memes.direct"

    # Locations of the detection model's config and weights.
    DETECTRON_CONFIG = "/content/model_data/detectron_model.yaml"
    DETECTRON_WEIGHTS = "/content/model_data/detectron_model.pth"

    def __init__(self, checkpoint_path: str = None):
        """Load the text processor, the classifier and the detection model.

        Args:
            checkpoint_path: Passed to ``load_pretrained_model`` to obtain the
                configuration from which the processors are built.
        """
        self.checkpoint = checkpoint_path
        self.processor, self.model = self._build_model()
        self.detection_model = self._get_detection_model()

    def _build_model(self):
        """Return ``(processors, model)`` for the pretrained classifier.

        Also stores the loaded checkpoint items in ``self.model_items`` and
        the full configuration in ``self.config``.
        """
        model_cls = registry.get_model_class(self.MODEL_NAME)
        model = model_cls.from_pretrained(self.PRETRAINED_KEY)

        self.model_items = load_pretrained_model(self.checkpoint)
        self.config = OmegaConf.create(self.model_items["full_config"])

        # Processors are taken from the first dataset listed in the config.
        dataset_name = list(self.config.dataset_config.keys())[0]
        processor = build_processors(
            self.config.dataset_config[dataset_name].processors
        )
        return processor, model

    def get_actual_image(self, image_path):
        """Return something ``PIL.Image.open`` can read for ``image_path``.

        Args:
            image_path: A local path, or a URL starting with ``http``.

        Returns:
            ``image_path`` unchanged for local paths; for URLs, the raw
            response stream of the download.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        if not image_path.startswith('http'):
            return image_path

        response = requests.get(image_path, stream=True)
        # Fail here rather than letting PIL try to parse an error page.
        response.raise_for_status()
        # The raw stream is not content-decoded by default; enable decoding so
        # a compressed transfer still yields image bytes.
        response.raw.decode_content = True
        return response.raw

    def _image_transform(self, image_path):
        """Load an image and prepare it for the detection model.

        The image is converted to RGB, its channel order is reversed, fixed
        per-channel values are subtracted, and it is rescaled so the shorter
        side is 800 pixels unless that would push the longer side past 1333.

        Args:
            image_path: Local path or ``http`` URL of the image.

        Returns:
            ``(img, im_scale)``: a float tensor with the channel axis first,
            and the scale factor that was applied to both axes.
        """
        source = self.get_actual_image(image_path)
        # convert() returns a fully loaded copy, so the underlying file can be
        # closed as soon as the conversion is done.
        with Image.open(source) as opened:
            img = opened.convert("RGB")

        im = np.array(img).astype(np.float32)
        im = im[:, :, ::-1]  # reverse the channel axis
        im -= np.array([102.9801, 115.9465, 122.7717])

        im_shape = im.shape
        im_size_min = np.min(im_shape[0:2])
        im_size_max = np.max(im_shape[0:2])
        im_scale = float(800) / float(im_size_min)
        # Prevent the biggest axis from being more than max_size
        if np.round(im_scale * im_size_max) > 1333:
            im_scale = float(1333) / float(im_size_max)

        im = cv2.resize(
            im,
            None,
            None,
            fx=im_scale,
            fy=im_scale,
            interpolation=cv2.INTER_LINEAR,
        )
        img = torch.from_numpy(im).permute(2, 0, 1)
        return img, im_scale

    def _get_detection_model(self):
        """Build the detection model, load its weights and move it to CUDA.

        Note:
            ``cfg`` is the config object imported from
            ``maskrcnn_benchmark.config``; this method merges the detectron
            config file into it and freezes it.
        """
        cfg.merge_from_file(self.DETECTRON_CONFIG)
        cfg.freeze()

        model = build_detection_model(cfg)
        # Deserialize on CPU first; the model is moved to the GPU afterwards.
        checkpoint = torch.load(
            self.DETECTRON_WEIGHTS, map_location=torch.device("cpu")
        )
        load_state_dict(model, checkpoint.pop("model"))

        model.to("cuda")
        model.eval()
        return model

    def _process_feature_extraction(self, output,
                                    im_scales,
                                    feat_name='fc6',
                                    conf_thresh=0.2):
        """Select the features of the most confident boxes for each image.

        For every image, class scores are soft-maxed, per-class NMS (overlap
        threshold 0.5) is applied over all classes except index 0, and each
        box keeps its best surviving class score. The features of the 100
        highest-scoring boxes are returned.

        Args:
            output: Detection model output; ``output[0]`` must provide
                ``"proposals"``, ``"scores"`` and ``feat_name``.
            im_scales: Per-image scale factors used to map boxes back to the
                original image size.
            feat_name: Key of the feature tensor in ``output[0]``.
            conf_thresh: Accepted for interface compatibility; not used.

        Returns:
            A list with one feature tensor per image.
        """
        proposals = output[0]["proposals"]
        batch_size = len(proposals)
        n_boxes_per_image = [len(boxes) for boxes in proposals]

        score_list = output[0]["scores"].split(n_boxes_per_image)
        score_list = [torch.nn.functional.softmax(x, -1) for x in score_list]
        feats = output[0][feat_name].split(n_boxes_per_image)
        cur_device = score_list[0].device

        feat_list = []
        for i in range(batch_size):
            # Undo the preprocessing rescale so boxes are in original pixels.
            dets = proposals[i].bbox / im_scales[i]
            scores = score_list[i]
            max_conf = torch.zeros(scores.shape[0], device=cur_device)

            for cls_ind in range(1, scores.shape[1]):
                cls_scores = scores[:, cls_ind]
                keep = nms(dets, cls_scores, 0.5)
                # Raise a kept box's confidence only if this class beats the
                # best score recorded for it so far.
                max_conf[keep] = torch.where(cls_scores[keep] > max_conf[keep],
                                             cls_scores[keep],
                                             max_conf[keep])

            keep_boxes = torch.argsort(max_conf, descending=True)[:100]
            feat_list.append(feats[i][keep_boxes])
        return feat_list

    def get_detectron_features(self, image_path, im, im_scale):
        """Run the detection model on one preprocessed image.

        Args:
            image_path: Not used; the image is supplied through ``im``.
            im: Image tensor as returned by ``_image_transform``.
            im_scale: Scale factor as returned by ``_image_transform``.

        Returns:
            The selected ``fc6`` region features for the image.
        """
        img_tensor, im_scales = [im], [im_scale]
        current_img_list = to_image_list(img_tensor, size_divisible=32)
        current_img_list = current_img_list.to('cuda')
        with torch.no_grad():
            output = self.detection_model(current_img_list)
        feat_list = self._process_feature_extraction(output, im_scales, 'fc6', 0.2)
        return feat_list[0]

    def forward(self, image_path: str, text: dict, im, im_scale, image_format: str = "path"):
        """Classify one image/text pair and return the model output.

        Args:
            image_path: Path or URL of the image; forwarded to
                ``get_detectron_features``.
            text: Input for the text processor, e.g. ``{'text': "..."}``.
            im: Image tensor as returned by ``_image_transform``.
            im_scale: Scale factor as returned by ``_image_transform``.
            image_format: Kept for backward compatibility; the image is
                already supplied via ``im``, so it is not opened again here.

        Returns:
            Whatever the wrapped model returns for the assembled sample list.
        """
        text_output = self.processor["text_processor"](text)
        image_output = self.get_detectron_features(image_path, im, im_scale)

        sample = Sample(text_output)
        sample.image_feature_0 = image_output
        sample.dataset_name = 'hateful_memes'
        # NOTE(review): this stores the preprocessed image tensor, not box
        # coordinates. ViLBERT may expect region-location inputs that are not
        # populated here - confirm against the model's expected sample fields.
        sample.image_location = im

        sample_list = SampleList([sample])
        sample_list = sample_list.to(get_current_device())
        self.model = self.model.to(get_current_device())
        output = self.model(sample_list)
        return output
The text was updated successfully, but these errors were encountered:
In your function `def _get_detection_model(self):`
what are the '/content/model_data/detectron_model.yaml' file and the '/content/model_data/detectron_model.pth' model?
I'm using the pretrained hateful memes ViLBERT model. I'm getting the following error when I run my code.
Here is my code. I get an error with vilbert, but everything works fine if I replace vilbert with visual_bert.
The text was updated successfully, but these errors were encountered: