
[Bug]: InternVL2-26B inference error: Attempted to assign 7 x 256 = 1792 multimodal tokens to 506 placeholders #7996

Closed
1 task done
SovereignRemedy opened this issue Aug 29, 2024 · 20 comments · Fixed by #8028
Labels
bug Something isn't working

Comments

@SovereignRemedy

Your current environment

The output of `python collect_env.py`
Not available: the machine is network-isolated, so the script could not be downloaded.

Python 3.8
GPUs: 8 × A10
Model: InternVL2-26B
vllm                              0.5.5
vllm-flash-attn                   2.6.1
torch                             2.4.0
torchvision                       0.19.0


🐛 Describe the bug

from dataclasses import dataclass
from typing import Literal

from PIL import Image


VLM_IMAGES_DIR = "vision_model_images"


@dataclass(frozen=True)
class ImageAsset:
    name: Literal["stop_sign", "cherry_blossom"]

    @property
    def pil_image(self) -> Image.Image:
        # Loads a local file; `name` is not used, so every asset maps to image.jpg.
        image_path = "image.jpg"
        return Image.open(image_path)


"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on vision language models.

For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

# Input image and question
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
question = "What is the content of this image?"


# InternVL
def run_internvl(question):
    model_name = "/home/tdj/model/InternVL2-26B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        gpu_memory_utilization=0.9,
        tensor_parallel_size=8,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    messages = [{"role": "user", "content": f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Stop tokens for InternVL
    # Model variants may have different stop tokens;
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return llm, prompt, stop_token_ids


model_example_map = {
    "internvl_chat": run_internvl,
}


def main():
    model = "internvl_chat"
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    llm, prompt, stop_token_ids = model_example_map[model](question)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(
        temperature=0.2, max_tokens=64, stop_token_ids=stop_token_ids
    )

    # Single inference
    inputs = {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    }
    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    main()

Here is my error stack trace:

(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/vllm/executor/multiproc_worker_utils.py", line 223, in _run_worker_process
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     output = executor(*args, **kwargs)
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/vllm/worker/worker_base.py", line 69, in start_worker_execution_loop
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     output = self.execute_model(execute_model_req=None)
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/vllm/worker/worker_base.py", line 322, in execute_model
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     output = self.model_runner.execute_model(
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/vllm/worker/model_runner.py", line 1415, in execute_model
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     hidden_or_intermediate_states = model_executable(
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/vllm/model_executor/models/internvl.py", line 459, in forward
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     inputs_embeds = merge_multimodal_embeddings(
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/vllm/model_executor/models/utils.py", line 77, in merge_multimodal_embeddings
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     raise ValueError(
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226] ValueError: Attempted to assign 7 x 256 = 1792 multimodal tokens to 506 placeholders
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]
(VllmWorkerProcess pid=85388) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/vllm/model_executor/models/utils.py", line 77, in merge_multimodal_embeddings
(VllmWorkerProcess pid=85388) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     raise ValueError(
(VllmWorkerProcess pid=85388) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226] ValueError: Attempted to assign 7 x 256 = 1792 multimodal tokens to 506 placeholders
(VllmWorkerProcess pid=85388) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]
(VllmWorkerProcess pid=85392) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226] Exception in worker VllmWorkerProcess while processing method start_worker_execution_loop: Attempted to assign 7 x 256 = 1792 multimodal tokens to 506 placeholders, Traceback (most recent call last):
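The check that fails lives in `merge_multimodal_embeddings` (vllm/model_executor/models/utils.py). Below is a simplified, pure-Python sketch of that kind of check, not vLLM's actual code: it counts the image placeholder tokens in the current input and compares them with the number of image embedding rows (here 7 patches × 256 tokens). The function name, the placeholder id and the 6 leading text tokens are made-up values for illustration.

```python
from typing import Sequence


def check_placeholder_count(
    input_ids: Sequence[int],
    placeholder_id: int,
    num_patches: int,
    tokens_per_patch: int,
) -> None:
    """Raise if the image embeddings do not exactly fill the placeholder slots.

    Simplified illustration; the real vLLM function works on tensors.
    """
    # Placeholder tokens actually present in the tokens of this forward pass.
    num_placeholders = sum(1 for token in input_ids if token == placeholder_id)
    # Embedding rows produced for the image (patches x tokens per patch).
    total_tokens = num_patches * tokens_per_patch
    if num_placeholders != total_tokens:
        raise ValueError(
            f"Attempted to assign {num_patches} x {tokens_per_patch} = {total_tokens} "
            f"multimodal tokens to {num_placeholders} placeholders"
        )


PLACEHOLDER_ID = 99  # hypothetical token id
# 6 text tokens followed by only 506 placeholders (a truncated view of the prompt).
partial_ids = [1] * 6 + [PLACEHOLDER_ID] * 506
try:
    check_placeholder_count(partial_ids, PLACEHOLDER_ID, num_patches=7, tokens_per_patch=256)
except ValueError as err:
    print(err)  # Attempted to assign 7 x 256 = 1792 multimodal tokens to 506 placeholders
```

In other words, the error means the model saw fewer placeholder slots in its input than the image produced embeddings for.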

If you have any questions, please feel free to contact me. I ran it exactly according to the official demo, using a local image.
See #6321.
Is only the 2B model supported?

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
SovereignRemedy added the bug label ("Something isn't working") on Aug 29, 2024
@DarkLight1337
Member

DarkLight1337 commented Aug 29, 2024

Try setting a larger context length (--max-model-len should be greater than the number of multimodal tokens plus text tokens).
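As a rough sketch of this rule: the image contributes patches × tokens-per-patch placeholder tokens, and the text prompt and the tokens to be generated come on top of that. The 7 patches of 256 tokens come from the error message in this issue, 64 matches `max_tokens` in the script, and the 20 text tokens are an assumed figure; `min_max_model_len` is a hypothetical helper, not a vLLM function.

```python
def min_max_model_len(
    num_patches: int,
    tokens_per_patch: int,
    num_text_tokens: int,
    max_new_tokens: int,
) -> int:
    """Smallest context length that fits image tokens, text tokens and the generation budget."""
    image_tokens = num_patches * tokens_per_patch
    return image_tokens + num_text_tokens + max_new_tokens


needed = min_max_model_len(num_patches=7, tokens_per_patch=256, num_text_tokens=20, max_new_tokens=64)
print(needed)  # 1876
print(4096 >= needed)  # True
```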

@DarkLight1337
Member

Also, make sure your InternVL2 is up-to-date.

@ywang96
Member

ywang96 commented Aug 30, 2024

I can actually repro this error - it seems to me that there's something changed about this model that introduced this bug.

@Isotr0py do you have bandwidth to take a look at this issue?

@SovereignRemedy
Author

> Try setting a larger context length (--max-model-len should be greater than the number of multimodal tokens plus text tokens).
>
> Also, make sure your InternVL2 is up-to-date.

@DarkLight1337
Thank you for your reply. I'm already using the latest model and have set max-model-len=81920, but I still get the same error.

@Isotr0py
Contributor

OK, I will take a look at this later today.

@MasterJanus

> OK, I will take a look at this later today.

I encountered a similar problem with InternVL2-8B.

Error message: (screenshot)

Python 3.8
GPU: 1 × A100
vllm 0.5.5
vllm-flash-attn 2.6.1
torch 2.4.0
torchvision 0.19.0

@Isotr0py
Contributor

Isotr0py commented Aug 30, 2024

@SovereignRemedy @MasterJanus This may be caused by the chunked prefill. You can set --enable-chunked-prefill=False or enable_chunked_prefill=False to fix this issue.
Update: You can also increase the max_num_batched_tokens to fix this.

WARNING 08-30 16:43:54 arg_utils.py:850] Chunked prefill is enabled by default for models with max_model_len > 32K. Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.
INFO 08-30 16:43:54 config.py:970] Chunked prefill is enabled with max_num_batched_tokens=512.

BTW, inference with max_model_len=4096 works for me because chunked prefill is disabled by default at small model_len.
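A toy simulation of this diagnosis, assuming chunked prefill simply cuts the prompt into consecutive chunks of at most `max_num_batched_tokens` tokens: count the image placeholders that land in each chunk. The prompt layout is made up for illustration (6 text tokens, chosen so the first chunk matches the 506 in the traceback, then 7 × 256 placeholders, then 20 text tokens), as are the helper names and the token ids.

```python
from typing import List


def split_into_chunks(tokens: List[int], budget: int) -> List[List[int]]:
    """Cut a prompt into consecutive prefill chunks of at most `budget` tokens."""
    return [tokens[i : i + budget] for i in range(0, len(tokens), budget)]


def placeholders_per_chunk(tokens: List[int], placeholder_id: int, budget: int) -> List[int]:
    """Number of image placeholder tokens that fall into each chunk."""
    return [
        sum(token == placeholder_id for token in chunk)
        for chunk in split_into_chunks(tokens, budget)
    ]


PLACEHOLDER_ID = 99  # hypothetical token id
TEXT_ID = 1  # hypothetical id standing in for any text token
# Hypothetical prompt: 6 text tokens, 7 x 256 image placeholders, 20 text tokens.
prompt = [TEXT_ID] * 6 + [PLACEHOLDER_ID] * (7 * 256) + [TEXT_ID] * 20

print(placeholders_per_chunk(prompt, PLACEHOLDER_ID, budget=512))   # [506, 512, 512, 262]
print(placeholders_per_chunk(prompt, PLACEHOLDER_ID, budget=4096))  # [1792]
```

With a 512-token budget the first chunk carries only 506 placeholders while the image still has 1792 embedding rows, which is consistent with the mismatch in the traceback; with 4096 the whole image fits in a single chunk.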

@ywang96
Member

ywang96 commented Aug 30, 2024

> @SovereignRemedy @MasterJanus This may be caused by the chunked prefill. You can set --enable-chunked-prefill=False or enable_chunked_prefill=False to fix this issue.
>
> WARNING 08-30 16:43:54 arg_utils.py:850] Chunked prefill is enabled by default for models with max_model_len > 32K. Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.
> INFO 08-30 16:43:54 config.py:970] Chunked prefill is enabled with max_num_batched_tokens=512.
>
> BTW, inference with max_model_len=4096 works for me because chunked prefill is disabled by default at small model_len.

@Isotr0py Thanks for the investigation! I guess we never ran into this issue before because most VL models have small context windows. We should definitely make sure chunked prefill is disabled when serving a VLM (until we figure out how to make VLMs compatible with chunked prefill).

@Isotr0py
Contributor

@ywang96 After deeper investigation, it seems that chunked prefill does not conflict with VLMs: when I increase max_num_batched_tokens to 4096, the model also works.

In fact, the root issue is that the default max_num_batched_tokens=512 for chunked prefill is too small for VLMs. So I think a proper solution is to increase the default max_num_batched_tokens when serving a VLM.

@DarkLight1337
Member

I keep forgetting about chunked prefill. Indeed, we should handle this case.

@DarkLight1337
Member

> @ywang96 After deeper investigation, it seems that chunked prefill does not conflict with VLMs: when I increase max_num_batched_tokens to 4096, the model also works.
>
> In fact, the root issue is that the default max_num_batched_tokens=512 for chunked prefill is too small for VLMs. So I think a proper solution is to increase the default max_num_batched_tokens when serving a VLM.

Let me open a quick PR to increase the default value for VLM.

@ywang96
Member

ywang96 commented Aug 30, 2024

> In fact, the root issue is that the default max_num_batched_tokens=512 for chunked prefill is too small for VLMs. So I think a proper solution is to increase the default max_num_batched_tokens when serving a VLM.

I think the tricky part is how to set this number properly and dynamically; I'm okay with setting an arbitrary default value for VLMs for now.
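One possible way to derive the number from the model instead of hard-coding it, shown as a hypothetical heuristic (the function name, the text margin and the rounding are assumptions, not vLLM behavior): size the chunk budget for the largest image a prompt can carry, plus some room for text. The 13 patches below assume InternVL2's dynamic tiling with up to 12 tiles plus a thumbnail; check `max_dynamic_patch` and `use_thumbnail` in the model's config before relying on that.

```python
import math


def vlm_chunk_budget(
    max_patches_per_image: int,
    tokens_per_patch: int,
    images_per_prompt: int = 1,
    text_margin: int = 256,
    round_to: int = 512,
) -> int:
    """Hypothetical heuristic for max_num_batched_tokens when serving a VLM.

    Covers the worst-case image tokens of one prompt plus a margin for text,
    rounded up to a multiple of `round_to`. A very long text before the image
    can still push part of the image into the next chunk.
    """
    worst_case = images_per_prompt * max_patches_per_image * tokens_per_patch + text_margin
    return math.ceil(worst_case / round_to) * round_to


print(vlm_chunk_budget(max_patches_per_image=13, tokens_per_patch=256))  # 3584
```

Under these assumptions the result (3584) stays below the 4096 that Isotr0py reported as working.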

@SovereignRemedy
Author

> @ywang96 After deeper investigation, it seems that chunked prefill does not conflict with VLMs: when I increase max_num_batched_tokens to 4096, the model also works.
>
> In fact, the root issue is that the default max_num_batched_tokens=512 for chunked prefill is too small for VLMs. So I think a proper solution is to increase the default max_num_batched_tokens when serving a VLM.

@Isotr0py @DarkLight1337

Thank you for your answers. With the startup parameters above, inference no longer reports an error, but the output contains garbled characters, which I find strange. Is it a problem with my images?
(screenshot of the garbled output)

My code runs on four A10 GPUs:

from dataclasses import dataclass
from typing import Literal


from PIL import Image


VLM_IMAGES_DIR = "vision_model_images"


@dataclass(frozen=True)
class ImageAsset:
    name: Literal["stop_sign", "cherry_blossom"]

    @property
    def pil_image(self) -> Image.Image:

        # image_path = "/home/tdj/model/InternVL2-26B/examples/image1.jpg"
        image_path = "image.jpg"
        return Image.open(image_path)


from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

# Input image and question
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
question = "这个照片是什么内容?"  # "What is the content of this photo?"


# InternVL
def run_internvl(question):
    model_name = "/home/tdj/model/InternVL2-26B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        gpu_memory_utilization=0.9,
        tensor_parallel_size=4,
        max_num_batched_tokens=8192,
        max_model_len=4096,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    messages = [{"role": "user", "content": f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return llm, prompt, stop_token_ids


model_example_map = {
    "internvl_chat": run_internvl,
}


def main():
    model = "internvl_chat"
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    llm, prompt, stop_token_ids = model_example_map[model](question)

    sampling_params = SamplingParams(
        temperature=0.2, max_tokens=64, stop_token_ids=stop_token_ids
    )

    # Single inference
    inputs = {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    }
    print("!!!loaded model successfully!!!")
    outputs = llm.generate(inputs, sampling_params=sampling_params)
    print("!!!start generate output!!!")
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    main()

@Isotr0py
Contributor

> Thank you for your answers. With the startup parameters above, inference no longer reports an error, but the output contains garbled characters, which I find strange. Is it a problem with my images?

No, this is a bug in the InternLM2 backbone under tensor parallelism. I'm working on a fix (see #8017 for tracking).

@Root970103

I ran into a strange issue with InternVL2-8B. I increased --max_num_batched_tokens to 4096 and set tensor_parallel_size=2 to enable tensor parallelism, but I got completely different results from the single-GPU deployment.

  • two-GPU deployment
    python -m vllm.entrypoints.openai.api_server --model OpenGVLab/InternVL2-8B --port 9005 --trust-remote-code -tp 2 --max_num_batched_tokens 4096
    (screenshot of the output)

  • single-GPU deployment
    python -m vllm.entrypoints.openai.api_server --model OpenGVLab/InternVL2-8B --port 9005 --trust-remote-code --max-model-len 10000
    (screenshot of the output)

@Isotr0py
Contributor

Isotr0py commented Sep 2, 2024

@Root970103 I have created #8055 to fix this. Please take a look :)

@Root970103

> @Root970103 I have created #8055 to fix this. Please take a look :)

Thanks for the reply! I will try again.

@sayakpaul

Facing something similar for https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf/ on 4 H100s (vllm installed from source today):

[rank0]: ValueError: Attempted to assign 2340 + 2144 + 1850 + 2160 + 2832 + 2438 + 2340 + 2830 + 2536 + 1948 = 23418 multimodal tokens to 23516 placeholders

@DarkLight1337
Member

> Facing something similar for https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf/ on 4 H100s (vllm installed from source today):
>
> [rank0]: ValueError: Attempted to assign 2340 + 2144 + 1850 + 2160 + 2832 + 2438 + 2340 + 2830 + 2536 + 1948 = 23418 multimodal tokens to 23516 placeholders

Can you open a separate issue for this since it's for a different model?

@sayakpaul

#8421

7 participants