Avoid missing packages and attn_mask dtype error #992

Open: mvsoom wants to merge 1 commit into master

@mvsoom mvsoom commented Jul 25, 2024

I installed the repo for Visualized-BGE on CPU, following the instructions in FlagEmbedding/visual/README.md, and downloaded the weights from HF. Then I executed the example code from the README:

####### Use Visualized BGE doing multi-modal knowledge retrieval
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE

model = Visualized_BGE(model_name_bge = "BAAI/bge-base-en-v1.5", model_weight="path: Visualized_base_en_v1.5.pth")
model.eval()
with torch.no_grad():
    query_emb = model.encode(text="Are there sidewalks on both sides of the Mid-Hudson Bridge?")
    candi_emb_1 = model.encode(text="The Mid-Hudson Bridge, spanning the Hudson River between Poughkeepsie and Highland.", image="./imgs/wiki_candi_1.jpg")
    candi_emb_2 = model.encode(text="Golden_Gate_Bridge", image="./imgs/wiki_candi_2.jpg")
    candi_emb_3 = model.encode(text="The Mid-Hudson Bridge was designated as a New York State Historic Civil Engineering Landmark by the American Society of Civil Engineers in 1983. The bridge was renamed the \"Franklin Delano Roosevelt Mid-Hudson Bridge\" in 1994.")

sim_1 = query_emb @ candi_emb_1.T
sim_2 = query_emb @ candi_emb_2.T
sim_3 = query_emb @ candi_emb_3.T
print(sim_1, sim_2, sim_3) # tensor([[0.6932]]) tensor([[0.4441]]) tensor([[0.6415]])

I got two errors:

1. Missing packages: peft and sentencepiece, the former for BAAI/bge-base-en-v1.5 and the latter for BAAI/bge-m3. I added both to setup.py; after pip installing them, all was well. Note: I have no experience with setup.py-based installations, so it is best to check that this is correct.

2. Dtype mismatch: this happens when encoding text only, without an image.

RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype, but got attn_mask.dtype: c10::Half and  query.dtype: float instead.

This is solved by casting the mask to the dtype of the embeddings: extended_attention_mask = extended_attention_mask.to(embedding_output.dtype). I added this at modeling.py:205. After that, all is well, and the three similarity values at the end of the snippet above are reproduced.
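
To illustrate the dtype fix in isolation, here is a minimal sketch outside the repo. The tensor shapes, the all-zero mask, and the helper name align_mask_dtype are assumptions made for the example; only the variable names extended_attention_mask and embedding_output and the .to(...) cast come from the change itself. It assumes PyTorch 2.x, where scaled_dot_product_attention is available.

```python
import torch
import torch.nn.functional as F


def align_mask_dtype(extended_attention_mask, embedding_output):
    # Hypothetical helper: cast the additive mask to the dtype of the embeddings.
    return extended_attention_mask.to(embedding_output.dtype)


# Toy shapes, chosen only for illustration.
batch, heads, seq_len, head_dim = 1, 2, 4, 8

# float32 activations, as one would expect on CPU.
embedding_output = torch.randn(batch, heads, seq_len, head_dim)

# A half-precision additive mask (all zeros means "mask nothing").
extended_attention_mask = torch.zeros(batch, 1, 1, seq_len, dtype=torch.float16)

# Passing the half mask with a float32 query is expected to be rejected.
try:
    F.scaled_dot_product_attention(
        embedding_output, embedding_output, embedding_output,
        attn_mask=extended_attention_mask,
    )
except RuntimeError as err:
    print("dtype mismatch:", err)

# After the cast, mask and query agree and attention runs.
mask = align_mask_dtype(extended_attention_mask, embedding_output)
out = F.scaled_dot_product_attention(
    embedding_output, embedding_output, embedding_output, attn_mask=mask
)
print(mask.dtype, tuple(out.shape))
```

Casting to embedding_output.dtype rather than hard-coding torch.float32 keeps the mask consistent with the activations whether the model runs in full or half precision.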

Would be nice if you can merge this so I don't have to rely on my own fork for further work! Cheers
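
For the missing packages in point 1, a quick check like the following could be run before the README example. This is only a sketch using the standard library; the function name missing_packages and the REQUIRED list are assumptions for illustration and are not part of the repo.

```python
from importlib.util import find_spec


def missing_packages(names):
    """Return the names from `names` that cannot be found in this environment."""
    return [name for name in names if find_spec(name) is None]


# The two packages that were missing in my environment (see point 1).
REQUIRED = ["peft", "sentencepiece"]

missing = missing_packages(REQUIRED)
if missing:
    print("Missing packages:", ", ".join(missing))
    print("Try: pip install " + " ".join(missing))
else:
    print("All required packages are installed.")
```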
