Fix: some bugs when text input reaches max_tokens of language_model #11669

Open · wants to merge 2 commits into dev-3.x
8 changes: 5 additions & 3 deletions mmdet/datasets/transforms/text_transformers.py
@@ -60,7 +60,7 @@ def check_for_positive_overflow(gt_bboxes, gt_labels, text, tokenizer,
         keep_gt_labels.append(gt_labels[i])

     return gt_bboxes[keep_box_index], np.array(
-        keep_gt_labels, dtype=np.long), length
+        keep_gt_labels, dtype=np.long), length, keep_box_index


 def generate_senetence_given_labels(positive_label_list, negative_label_list,
@@ -164,7 +164,7 @@ def od_aug(self, results):
             if '/' in value:
                 text[key] = random.choice(value.split('/')).strip()

-        gt_bboxes, gt_labels, positive_caption_length = \
+        gt_bboxes, gt_labels, positive_caption_length, keep_box_index = \
             check_for_positive_overflow(gt_bboxes, gt_labels,
                                         text, self.tokenizer, self.max_tokens)

@@ -217,7 +217,7 @@ def od_aug(self, results):

                 negative_max_length -= len(tokenized)

-                if negative_max_length > 0:
+                if negative_max_length >= 0:
                     screened_negative_label_list.append(negative_label)
                 else:
                     break
@@ -232,6 +232,8 @@ def od_aug(self, results):

         results['gt_bboxes'] = gt_bboxes
         results['gt_bboxes_labels'] = gt_labels
+        if results.get('gt_ignore_flags', None) is not None:
+            results['gt_ignore_flags'] = results['gt_ignore_flags'][keep_box_index]

         results['text'] = pheso_caption
         results['tokens_positive'] = label_to_positions
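Note on mmdet/datasets/transforms/text_transformers.py: check_for_positive_overflow() now also returns keep_box_index, the indices of the ground-truth boxes whose label text still fits within max_tokens, and od_aug() reuses that list to filter gt_ignore_flags together with gt_bboxes and gt_bboxes_labels; the negative_max_length check is also relaxed to >= 0 so a negative label that exactly exhausts the remaining token budget is still kept. A minimal sketch (toy data, not the mmdet implementation) of why the same index list has to be applied to every per-box field:

import numpy as np

# Three ground-truth boxes with their labels and ignore flags.
gt_bboxes = np.array([[0, 0, 10, 10], [5, 5, 20, 20], [1, 1, 4, 4]])
gt_labels = np.array([2, 7, 2])
gt_ignore_flags = np.array([False, True, False])

# Indices of the boxes whose label text still fits within max_tokens;
# in the PR this is the extra value returned by check_for_positive_overflow.
keep_box_index = [0, 2]

gt_bboxes = gt_bboxes[keep_box_index]
gt_labels = gt_labels[keep_box_index]
# Without this step the ignore flags stay at length 3 while the boxes are
# now length 2; that misalignment is what the new lines in od_aug() avoid.
gt_ignore_flags = gt_ignore_flags[keep_box_index]

print(gt_bboxes.shape, gt_labels.shape, gt_ignore_flags.shape)  # (2, 4) (2,) (2,)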
21 changes: 13 additions & 8 deletions mmdet/models/detectors/grounding_dino.py
@@ -162,7 +162,8 @@ def get_tokens_and_prompts(
                 [caption_string],
                 padding='max_length'
                 if self.language_model.pad_to_max else 'longest',
-                return_tensors='pt')
+                return_tensors='pt',
+                add_special_tokens=False)
             entities = original_caption
         else:
             if not original_caption.endswith('.'):
@@ -175,7 +176,8 @@ def get_tokens_and_prompts(
                 [original_caption],
                 padding='max_length'
                 if self.language_model.pad_to_max else 'longest',
-                return_tensors='pt')
+                return_tensors='pt',
+                add_special_tokens=False)
             tokens_positive, noun_phrases = run_ner(original_caption)
             entities = noun_phrases
             caption_string = original_caption
@@ -224,7 +226,8 @@ def get_tokens_positive_and_prompts(
                 [original_caption],
                 padding='max_length'
                 if self.language_model.pad_to_max else 'longest',
-                return_tensors='pt')
+                return_tensors='pt',
+                add_special_tokens=False)
             positive_map_label_to_token, positive_map = \
                 self.get_positive_map(tokenized, tokens_positive)

@@ -281,7 +284,8 @@ def get_tokens_positive_and_prompts_chunked(
             caption_string, tokens_positive = self.to_plain_text_prompts(
                 original_caption_chunked[i])
             tokenized = self.language_model.tokenizer([caption_string],
-                                                      return_tensors='pt')
+                                                      return_tensors='pt',
+                                                      add_special_tokens=False)
             if tokenized.input_ids.shape[1] > self.language_model.max_tokens:
                 warnings.warn('Inputting a text that is too long will result '
                               'in poor prediction performance. '
@@ -439,7 +443,8 @@ def loss(self, batch_inputs: Tensor,
                     [text_prompt],
                     padding='max_length'
                     if self.language_model.pad_to_max else 'longest',
-                    return_tensors='pt')
+                    return_tensors='pt',
+                    add_special_tokens=False)
                 new_tokens_positive = [
                     token_positive[label.item()] for label in gt_label
                 ]
@@ -477,7 +482,7 @@ def loss(self, batch_inputs: Tensor,
                 positive_maps.append(positive_map)
                 new_text_prompts.append(caption_string)

-        text_dict = self.language_model(new_text_prompts)
+        text_dict = self.language_model(new_text_prompts, add_special_tokens=False)
         if self.text_feat_map is not None:
             text_dict['embedded'] = self.text_feat_map(text_dict['embedded'])

@@ -553,7 +558,7 @@ def predict(self, batch_inputs, batch_data_samples, rescale: bool = True):
             for b in range(len(text_prompts[0])):
                 text_prompts_once = [text_prompts[0][b]]
                 token_positive_maps_once = token_positive_maps[0][b]
-                text_dict = self.language_model(text_prompts_once)
+                text_dict = self.language_model(text_prompts_once, add_special_tokens=False)
                 # text feature map layer
                 if self.text_feat_map is not None:
                     text_dict['embedded'] = self.text_feat_map(
@@ -577,7 +582,7 @@ def predict(self, batch_inputs, batch_data_samples, rescale: bool = True):
             is_rec_tasks = [False] * len(results_list)
         else:
             # extract text feats
-            text_dict = self.language_model(list(text_prompts))
+            text_dict = self.language_model(list(text_prompts), add_special_tokens=False)
             # text feature map layer
             if self.text_feat_map is not None:
                 text_dict['embedded'] = self.text_feat_map(
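Note on mmdet/models/detectors/grounding_dino.py: every tokenizer call and every self.language_model(...) call now passes add_special_tokens=False (forwarded to the tokenizer through **kwargs in bert.py below), presumably so that no [CLS]/[SEP] pair is appended on top of a caption whose length was already measured against max_tokens without them. A minimal sketch (assumes the HuggingFace transformers package and the bert-base-uncased checkpoint; not part of this PR) of what the flag changes:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

caption = 'person. car. dog.'
with_special = tokenizer([caption], return_tensors='pt')
without_special = tokenizer([caption], return_tensors='pt',
                            add_special_tokens=False)

# BERT adds two extra tokens ([CLS] and [SEP]), which is enough to push a
# caption that was measured as exactly max_tokens over the limit.
print(with_special.input_ids.shape[1] - without_special.input_ids.shape[1])  # 2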
10 changes: 7 additions & 3 deletions mmdet/models/language_models/bert.py
@@ -55,10 +55,10 @@ def generate_masks_with_special_tokens_and_transfer_map(
                   device=input_ids.device).bool().unsqueeze(0).repeat(
                       bs, 1, 1))
     position_ids = torch.zeros((bs, num_token), device=input_ids.device)
-    previous_col = 0
+    previous_col = -1
     for i in range(idxs.shape[0]):
         row, col = idxs[i]
-        if (col == 0) or (col == num_token - 1):
+        if col == 0:
             attention_mask[row, col, col] = True
             position_ids[row, col] = 0
         else:
@@ -68,6 +68,9 @@ def generate_masks_with_special_tokens_and_transfer_map(
                 0, col - previous_col, device=input_ids.device)
         previous_col = col

+        if i + 1 != idxs.shape[0] and idxs[i + 1][0] != idxs[i][0]:
+            previous_col = -1
+
     return attention_mask, position_ids.to(torch.long)

@@ -143,7 +146,8 @@ def forward(self, captions: Sequence[str], **kwargs) -> dict:
             padding='max_length' if self.pad_to_max else 'longest',
             return_special_tokens_mask=True,
             return_tensors='pt',
-            truncation=True).to(device)
+            truncation=True,
+            add_special_tokens=kwargs.get("add_special_tokens", True)).to(device)
         input_ids = tokenized.input_ids
         if self.use_sub_sentence_represent:
             attention_mask, position_ids = \
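Note on mmdet/models/language_models/bert.py: generate_masks_with_special_tokens_and_transfer_map() now starts previous_col at -1 (so the first sub-sentence slice begins at column 0 when there is no leading [CLS]), drops the special case for a separator at the last position, and resets previous_col to -1 whenever the next special-token index belongs to a different sample in the batch; forward() forwards the add_special_tokens kwarg on to the tokenizer. A minimal sketch (toy data, not the mmdet implementation) of the per-sample reset added in the second hunk:

import torch

# (row, col) positions of the '.' separator tokens in a batch of two captions.
idxs = torch.tensor([[0, 3], [0, 7], [1, 2], [1, 5]])

previous_col = -1
for i in range(idxs.shape[0]):
    row, col = idxs[i]
    # Tokens previous_col + 1 .. col form one sub-sentence of sample `row`.
    print(f'sample {int(row)}: columns {previous_col + 1}..{int(col)}')
    previous_col = int(col)

    # The reset this PR adds: start over when the next separator belongs to a
    # different sample, since there is no [CLS] at column 0 to do it implicitly.
    if i + 1 != idxs.shape[0] and idxs[i + 1][0] != idxs[i][0]:
        previous_col = -1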