[Model] mPLUG-Owl3 Added (#403)
* finish inference; need modify prompt_full

* add demo.py

* finish owl3

* remove demo.py

* remove mPLUG-Owl3 result

* finish lint

* change generate_inner return content

* add image process

* [Fix] delete log

* [Fix] lint

* [Fix] tools

* fix config

---------

Co-authored-by: kennymckormick <[email protected]>
SYuan03 and kennymckormick committed Sep 7, 2024
1 parent 1a5ec34 commit d1f1f65
Showing 4 changed files with 128 additions and 1 deletion.
3 changes: 2 additions & 1 deletion vlmeval/config.py
@@ -33,6 +33,7 @@
'flamingov2': partial(OpenFlamingo, name='v2', mpt_pth='anas-awadalla/mpt-7b', ckpt_pth='openflamingo/OpenFlamingo-9B-vitl-mpt7b'),
'VisualGLM_6b': partial(VisualGLM, model_path='THUDM/visualglm-6b'),
'mPLUG-Owl2': partial(mPLUG_Owl2, model_path='MAGAer13/mplug-owl2-llama2-7b'),
'mPLUG-Owl3': partial(mPLUG_Owl3, model_path='mPLUG/mPLUG-Owl3-7B-240728'),
'emu2_chat': partial(Emu, model_path='BAAI/Emu2-Chat'),
'OmniLMM_12B': partial(OmniLMM12B, model_path='openbmb/OmniLMM-12B', root=OmniLMM_ROOT),
'MGM_7B': partial(Mini_Gemini, model_path='YanweiLi/MGM-7B-HD', root=Mini_Gemini_ROOT),
@@ -244,7 +245,7 @@
'Qwen2-VL-2B-Instruct-GPTQ-Int4': partial(Qwen2VLChat, model_path='Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4', min_pixels=1280*28*28, max_pixels=5120*28*28),
'Qwen2-VL-2B-Instruct-GPTQ-Int8': partial(Qwen2VLChat, model_path='Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8', min_pixels=1280*28*28, max_pixels=5120*28*28),
}

phi3_series = {
'Phi-3-Vision': partial(Phi3Vision, model_path='microsoft/Phi-3-vision-128k-instruct'),
'Phi-3.5-Vision': partial(Phi3_5Vision, model_path='microsoft/Phi-3.5-vision-instruct')
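The new registry entry maps the name 'mPLUG-Owl3' to a partial constructor, so the model can be built by name at evaluation time. A minimal sketch of exercising it, assuming the registry is exposed as vlmeval.config.supported_VLM and that the base class provides generate() (neither is shown in this hunk):

from vlmeval.config import supported_VLM  # registry name assumed, not visible in this hunk

model = supported_VLM['mPLUG-Owl3']()  # resolves the partial: mPLUG_Owl3(model_path='mPLUG/mPLUG-Owl3-7B-240728')
message = [
    dict(type='image', value='example.jpg'),          # placeholder image path
    dict(type='text', value='Describe this image.'),
]
print(model.generate(message))  # dispatches to mPLUG_Owl3.generate_inner defined below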
1 change: 1 addition & 0 deletions vlmeval/tools.py
@@ -70,6 +70,7 @@
'4.37.0': [x for x in llava_series if 'next' not in x] + list(internvl_series) + [
'TransCore_M', 'emu2_chat', 'MiniCPM-V', 'MiniCPM-V-2', 'OmniLMM_12B',
'cogvlm-grounding-generalist', 'cogvlm-chat', 'cogvlm2-llama3-chat-19B',
'mPLUG-Owl3'
] + list(xtuner_series) + list(yivl_series) + list(deepseekvl_series) + list(cambrian_series),
'4.40.0': [
'idefics2_8b', 'Bunny-llama3-8B', 'MiniCPM-Llama3-V-2_5', '360VL-70B', 'Phi-3-Vision',
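tools.py keeps, per transformers release, the list of models known to work with it; this change adds 'mPLUG-Owl3' to the 4.37.0 group. A minimal sketch of consulting such a table before running, using a hypothetical helper that is not part of the diff:

import transformers

RECOMMENDED = {'mPLUG-Owl3': '4.37.0'}  # excerpt of the version table above

def warn_if_mismatched(model_name):
    wanted = RECOMMENDED.get(model_name)
    if wanted is not None and transformers.__version__ != wanted:
        print(f'{model_name} is validated with transformers=={wanted}, '
              f'but transformers=={transformers.__version__} is installed.')

warn_if_mismatched('mPLUG-Owl3')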
1 change: 1 addition & 0 deletions vlmeval/vlm/__init__.py
@@ -42,3 +42,4 @@
from .omchat import OmChat
from .rbdash import RBDash
from .xgen_mm import XGenMM
from .mplug_owl3 import mPLUG_Owl3
124 changes: 124 additions & 0 deletions vlmeval/vlm/mplug_owl3.py
@@ -0,0 +1,124 @@
import torch
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
from transformers import AutoTokenizer, AutoModel


class mPLUG_Owl3(BaseModel):
    # No separate model module is required, but the dependencies must be met.
    # https://github.com/X-PLUG/mPLUG-Owl/blob/main/mPLUG-Owl3/requirements.txt
    INSTALL_REQ = True
    INTERLEAVE = True
    INSTALL_REQ_TXT = 'https://github.com/X-PLUG/mPLUG-Owl/blob/main/mPLUG-Owl3/requirements.txt'

    def __init__(self, model_path=None, **kwargs):
        assert model_path is not None
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True
        )

        self.model = AutoModel.from_pretrained(
            model_path,
            attn_implementation='flash_attention_2',
            torch_dtype=torch.half,
            trust_remote_code=True
        )
        self.model.eval().cuda()
        self.processor = self.model.init_processor(self.tokenizer)
        self.logger = get_logger('mPLUG_Owl3')
        if self.INSTALL_REQ:
            self.logger.info(
                f'Please remember to meet the requirements first\n'
                f'Here: {self.INSTALL_REQ_TXT}'
            )

    def use_custom_prompt(self, dataset):
        assert dataset is not None
        if listinstr(['MMMU'], dataset):
            return False
        if DATASET_TYPE(dataset) == 'MCQ' or dataset == 'MMVet':
            return True
        return False

    # Currently the same as mPLUG_Owl2
    def build_prompt(self, line, dataset=None):
        assert dataset is None or isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)
        tgt_path = self.dump_image(line, dataset)
        question = line['question']
        if dataset == 'MMVet':
            prompt = question + '\nAnswer the question directly. '
        elif DATASET_TYPE(dataset) == 'MCQ':
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = ''
            for key, item in options.items():
                options_prompt += f'{key}. {item}\n'

            hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
            prompt = f'Hint: {hint}\n' if hint is not None else ''
            prompt += f'{question}\n'
            prompt += (
                f'{options_prompt}\nAnswer with the option’s letter from the given choices directly. '
                if len(options) else 'Answer the question directly. '
            )
        else:
            raise NotImplementedError

        message = [dict(type='text', value=prompt)]
        message.extend([dict(type='image', value=s) for s in tgt_path])
        return message

    def preproc_image(self, fname):
        from PIL import Image
        image = Image.open(fname).convert('RGB')
        # Cap the longest side at max_size (448 * 16 = 7168 px) while keeping the aspect ratio
        max_size = 448 * 16
        if max(image.size) > max_size:
            w, h = image.size
            if w > h:
                new_w = max_size
                new_h = int(h * max_size / w)
            else:
                new_h = max_size
                new_w = int(w * max_size / h)
            image = image.resize((new_w, new_h), resample=Image.BICUBIC)
        return image

    def generate_inner(self, message, dataset=None):
        num_images = len([x for x in message if x['type'] == 'image'])
        assert num_images >= 0

        images = []
        prompt_full = ''

        # Append an '<|image|>' placeholder for every image, in the order the items appear
        for msg in message:
            if msg['type'] == 'image':
                images.append(msg['value'])
                prompt_full += '<|image|>'
            elif msg['type'] == 'text':
                prompt_full += msg['value']

        # The processor expects a chat-style list; the empty assistant turn marks where generation starts
        needed_messages = [
            {'role': 'user', 'content': prompt_full},
            {'role': 'assistant', 'content': ''}
        ]

        images = [self.preproc_image(fname) for fname in images]

        inputs = self.processor(needed_messages, images=images, videos=None)

        inputs.to('cuda')
        inputs.update({
            'tokenizer': self.tokenizer,
            'max_new_tokens': 1024,
            'decode_text': True,
        })

        g = self.model.generate(**inputs)
        return g[0]
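preproc_image above caps the longest image side at 448 * 16 = 7168 px while preserving the aspect ratio. A quick check of the arithmetic for an illustrative oversized input:

# Hypothetical 12000 x 3000 image: the width is the longest side, so it is clamped to 7168
w, h = 12000, 3000
max_size = 448 * 16                # 7168
new_w = max_size                   # 7168
new_h = int(h * max_size / w)      # int(3000 * 7168 / 12000) = 1792
assert (new_w, new_h) == (7168, 1792)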
