diff --git a/mmpretrain/datasets/__init__.py b/mmpretrain/datasets/__init__.py
index b7b6be47dce..1deb59ec6a8 100644
--- a/mmpretrain/datasets/__init__.py
+++ b/mmpretrain/datasets/__init__.py
@@ -46,6 +46,7 @@
     from .refcoco import RefCOCO
     from .scienceqa import ScienceQA
     from .textvqa import TextVQA
+    from .visdial import VisDial
     from .visual_genome import VisualGenomeQA
     from .vizwiz import VizWiz
     from .vsr import VSR
@@ -54,5 +55,5 @@
         'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
         'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
         'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
-        'VSR', 'VizWiz', 'OCRVQA'
+        'VSR', 'VizWiz', 'OCRVQA', 'VisDial'
     ])
diff --git a/mmpretrain/datasets/visdial.py b/mmpretrain/datasets/visdial.py
new file mode 100644
index 00000000000..66f3379f8f8
--- /dev/null
+++ b/mmpretrain/datasets/visdial.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+from mmengine.fileio import get_file_backend
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class VisDial(BaseDataset):
+    """VisDial dataset.
+
+    Args:
+        data_root (str): The root directory for ``data_prefix`` and
+            ``ann_file``.
+        data_prefix (str): The directory of images.
+        ann_file (str, optional): Annotation file path for training and
+            validation. Defaults to an empty string.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
+    """
+
+    def __init__(self,
+                 data_root: str,
+                 data_prefix: str,
+                 ann_file: str = '',
+                 **kwargs):
+        super().__init__(
+            data_root=data_root,
+            data_prefix=dict(img_path=data_prefix),
+            ann_file=ann_file,
+            **kwargs,
+        )
+
+    def load_data_list(self) -> List[dict]:
+        """Load data list."""
+        annotations = mmengine.load(self.ann_file)['data']
+
+        dialogs = annotations['dialogs']
+        answers = annotations['answers']
+        questions = annotations['questions']
+
+        data_list = []
+
+        for dialog in dialogs:
+            image_id = dialog['image_id']
+            caption = dialog['caption']
+
+            # Accumulate the dialog history round by round: ``histories[i]``
+            # holds the caption plus every question-answer pair before
+            # round ``i``.
+            histories = ['Caption:' + caption + '.']
+
+            for i in range(1, len(dialog['dialog'])):
+                previous_idx = i - 1
+                question_id = dialog['dialog'][previous_idx]['question']
+                answer_id = dialog['dialog'][previous_idx]['answer']
+
+                history = ' Question:{question}? Answer:{answer}.'.format(
+                    question=questions[question_id],
+                    answer=answers[answer_id])
+
+                histories.append(histories[previous_idx] + history)
+
+            # Build one data sample per dialog round with its question,
+            # ground-truth answer and answer options.
+            for dialog_id, dialog_round in enumerate(dialog['dialog']):
+                question_id = dialog_round['question']
+                answer_id = dialog_round['answer']
+                answer_options = [
+                    answers[option_id]
+                    for option_id in dialog_round['answer_options']
+                ]
+
+                data_info = dict(image_id=image_id)
+
+                img_prefix = self.data_prefix['img_path']
+                file_backend = get_file_backend(img_prefix)
+
+                # Image files are named ``<dirname>_<12-digit image_id>.jpg``
+                # under the image directory.
+                data_info['img_path'] = file_backend.join_path(
+                    img_prefix,
+                    img_prefix.split('/')[-1] + '_' +
+                    str(image_id).zfill(12) + '.jpg')
+
+                data_info['dialog_history'] = histories[dialog_id]
+                data_info['question'] = questions[question_id] + '?'
+                data_info['answer'] = answers[answer_id]
+                data_info['answer_options'] = answer_options
+                data_info['gt_answer_index'] = answer_options.index(
+                    data_info['answer'])
+
+                data_list.append(data_info)
+
+        return data_list
diff --git a/mmpretrain/evaluation/metrics/__init__.py b/mmpretrain/evaluation/metrics/__init__.py
index 7f5a4f36b41..450a0ecba8d 100644
--- a/mmpretrain/evaluation/metrics/__init__.py
+++ b/mmpretrain/evaluation/metrics/__init__.py
@@ -7,6 +7,7 @@
 from .retrieval import RetrievalAveragePrecision, RetrievalRecall
 from .scienceqa import ScienceQAMetric
 from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric
+from .visual_dialog import SparseGTMetrics
 from .visual_grounding_eval import VisualGroundingMetric
 from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
 from .vqa import ReportVQA, VQAAcc
@@ -16,5 +17,5 @@
     'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
     'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
     'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave',
-    'RetrievalAveragePrecision'
+    'RetrievalAveragePrecision', 'SparseGTMetrics'
 ]
diff --git a/mmpretrain/evaluation/metrics/visual_dialog.py b/mmpretrain/evaluation/metrics/visual_dialog.py
new file mode 100644
index 00000000000..acfe0c5b2b7
--- /dev/null
+++ b/mmpretrain/evaluation/metrics/visual_dialog.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional
+
+import torch
+from mmengine.evaluator import BaseMetric
+
+from mmpretrain.evaluation.metrics.vqa import (_process_digit_article,
+                                               _process_punctuation)
+from mmpretrain.registry import METRICS
+
+
+@METRICS.register_module()
+class SparseGTMetrics(BaseMetric):
+    """Visual Dialog accuracy metric.
+
+    Compute the sparse ground-truth retrieval metrics for Visual Dialog:
+    R@1, R@5, R@10, mean rank and mean reciprocal rank (MRR).
+
+    Args:
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, ``self.default_prefix``
+            will be used instead. Defaults to None.
+    """
+    default_prefix = 'Visual Dialog'
+
+    def __init__(self,
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None) -> None:
+        super().__init__(collect_device=collect_device, prefix=prefix)
+
+    def process(self, data_batch, data_samples) -> None:
+        """Process one batch of data samples.
+
+        The processed results should be stored in ``self.results``, which
+        will be used to compute the metrics when all batches have been
+        processed.
+
+        Args:
+            data_batch: A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+        """
+        for sample in data_samples:
+            answer_options = sample.get('answer_options')
+
+            # Assign a fixed, seeded random ranking to the answer options,
+            # then move the predicted answer (if it matches one of the
+            # options) to rank 1.
+            G = torch.Generator()
+            G.manual_seed(0)
+            rank = 1 + torch.randperm(len(answer_options), generator=G)
+
+            pred_answer = sample.get('pred_answer')
+
+            if pred_answer in answer_options:
+                answer_index = answer_options.index(pred_answer)
+                rank[answer_index] = 1
+
+            # Record the rank assigned to the ground-truth answer.
+            gt_index = sample.get('gt_answer_index')
+            gt_rank = rank[gt_index].item()
+
+            self.results.append(gt_rank)
+
+    def compute_metrics(self, results: List) -> dict:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (List): The processed results of each batch.
+
+        Returns:
+            dict: The computed metrics. The keys are the names of the
+            metrics, and the values are corresponding results.
+        """
+        ranks = torch.tensor(results).float()
+
+        R1 = (ranks <= 1).float().mean()
+        R5 = (ranks <= 5).float().mean()
+        R10 = (ranks <= 10).float().mean()
+        Mean = ranks.mean()
+        MRR = ranks.reciprocal().mean()
+
+        metrics = {
+            'R@1': R1.item(),
+            'R@5': R5.item(),
+            'R@10': R10.item(),
+            'Mean': Mean.item(),
+            'MRR': MRR.item()
+        }
+        return metrics
+
+    def _process_answer(self, answer) -> str:
+        answer = _process_punctuation(answer)
+        answer = _process_digit_article(answer)
+        return answer
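
Usage note (not part of the diff): the snippet below is a minimal evaluation-config sketch showing how the new VisDial dataset and SparseGTMetrics metric could be wired together. The data root, image directory, annotation file name and pipeline transforms are assumptions for illustration only, not files shipped with this patch; the model is expected to write its predicted answer string into the `pred_answer` field of each data sample so the metric can rank it among `answer_options`.

# Hypothetical config sketch; paths, file names and pipeline are assumptions.
val_dataloader = dict(
    batch_size=8,
    num_workers=4,
    dataset=dict(
        type='VisDial',
        data_root='data/visdial',                     # assumed data root
        data_prefix='VisualDialog_val2018',           # assumed image directory
        ann_file='annotations/visdial_1.0_val.json',  # assumed annotation file
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='PackInputs',
                 algorithm_keys=['question', 'dialog_history',
                                 'answer_options', 'gt_answer_index']),
        ],
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)

# Ranks the ground-truth answer among the answer options and reports
# R@1, R@5, R@10, mean rank and MRR.
val_evaluator = dict(type='SparseGTMetrics')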