Diff of /lavis/tasks/vqa.py [000000] .. [dc40d0]

--- a
+++ b/lavis/tasks/vqa.py
@@ -0,0 +1,319 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import json
+import os
+
+import lavis.common.dist_utils as dist_utils
+from lavis.common.registry import registry
+from lavis.common.vqa_tools.vqa import VQA
+from lavis.common.vqa_tools.vqa_eval import VQAEval
+from lavis.tasks.base_task import BaseTask
+
+
+@registry.register_task("vqa")
+class VQATask(BaseTask):
+    def __init__(
+        self,
+        num_beams,
+        max_len,
+        min_len,
+        evaluate,
+        num_ans_candidates,
+        inference_method="rank",
+        prompt="",
+    ):
+        super().__init__()
+
+        self.num_beams = num_beams
+        self.max_len = max_len
+        self.min_len = min_len
+
+        self.evaluate = evaluate
+        self.inference_method = inference_method
+        self.num_ans_candidates = num_ans_candidates
+        self.prompt = prompt
+
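+        # candidate answer list; filled from the dataset in build_datasets() when available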
+        self.answer_list = None
+
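+        # COCO-format question/annotation files per split, consumed by the official VQA evaluator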
+        self.ques_files = dict()
+        self.anno_files = dict()
+
+    @classmethod
+    def setup_task(cls, cfg):
+        run_cfg = cfg.run_cfg
+
+        num_beams = run_cfg.get("num_beams", 3)
+        max_len = run_cfg.get("max_len", 10)
+        min_len = run_cfg.get("min_len", 1)
+
+        evaluate = run_cfg.get("evaluate", False)
+
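+        # "rank" scores answers from a fixed candidate list; "generate" decodes free-form answers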
+        inference_method = run_cfg.get("inference_method", "rank")
+        num_ans_candidates = run_cfg.get("num_ans_candidates", 128)
+        prompt = run_cfg.get("prompt", "")
+
+        return cls(
+            num_beams=num_beams,
+            max_len=max_len,
+            min_len=min_len,
+            evaluate=evaluate,
+            num_ans_candidates=num_ans_candidates,
+            inference_method=inference_method,
+            prompt=prompt,
+        )
+
+    def build_datasets(self, cfg):
+        datasets = super().build_datasets(cfg)
+
+        # get question file, annotation file and answer list in COCO format
+        for dataset in datasets.values():
+            for split in dataset:
+                if (
+                    hasattr(dataset[split], "coco_fmt_qust_file")
+                    and dataset[split].coco_fmt_qust_file is not None
+                ):
+                    self.ques_files[split] = dataset[split].coco_fmt_qust_file
+                    self.anno_files[split] = dataset[split].coco_fmt_anno_file
+
+                try:
+                    self.answer_list = dataset[split].answer_list
+                except AttributeError:
+                    # answer_list is optional; leave it as None when the dataset does not provide one
+                    pass
+
+        if len(self.ques_files) > 0:
+            assert len(self.ques_files) == len(
+                self.anno_files
+            ), "Question files and annotation files must be provided for the same splits."
+
+        return datasets
+
+    def valid_step(self, model, samples):
+        answers = model.predict_answers(
+            samples=samples,
+            answer_list=self.answer_list,
+            inference_method=self.inference_method,
+            num_beams=self.num_beams,
+            max_len=self.max_len,
+            min_len=self.min_len,
+            num_ans_candidates=self.num_ans_candidates,
+            prompt=self.prompt,
+        )
+        pred_qa_pairs = []
+
+        # "text_output" holds the ground-truth answer text for this dataset
+        gt_answers = samples["text_output"]
+        for ques, answer, gt_answer in zip(samples["text_input"], answers, gt_answers):
+            pred_qa_pairs.append(
+                {"question": ques, "gt_ans": gt_answer, "pred_ans": answer}
+            )
+            logging.info("Question: %s", ques)
+            logging.info("Predicted answer: %s", answer)
+            logging.info("Ground-truth answer: %s", gt_answer)
+
+        return pred_qa_pairs
+
+    def after_evaluation(self, val_result, split_name, **kwargs):
+        # de-duplicate by question_id only when the results carry that field
+        # (the GQA / A-OKVQA subclasses do; this task's valid_step does not)
+        dedup_key = "question_id" if val_result and "question_id" in val_result[0] else ""
+
+        result_file = self.save_result(
+            val_result,
+            result_dir=registry.get_path("result_dir"),
+            filename=f"{split_name}_vqa_result",
+            remove_duplicate=dedup_key,
+        )
+
+        metrics = self._report_metrics(result_file=result_file, split=split_name)
+
+        return metrics
+
+    @dist_utils.main_process
+    def _report_metrics(self, result_file, split):
+        """
+        Use official VQA evaluation script to report metrics.
+        """
+        metrics = {}
+
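+        # official VQA metrics can only be computed when COCO-format
+        # question/annotation files were registered for this split in build_datasets()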
+        if split in self.ques_files and split in self.anno_files:
+            vqa = VQA(self.anno_files[split], self.ques_files[split])
+            vqa_result = vqa.loadRes(
+                resFile=result_file, quesFile=self.ques_files[split]
+            )
+
+            # create vqaEval object by taking vqa and vqaRes
+            # n is precision of accuracy (number of places after decimal), default is 2
+            vqa_scorer = VQAEval(vqa, vqa_result, n=2)
+            logging.info("Start VQA evaluation.")
+            vqa_scorer.evaluate()
+
+            # print accuracies
+            overall_acc = vqa_scorer.accuracy["overall"]
+            metrics["agg_metrics"] = overall_acc
+
+            logging.info("Overall Accuracy is: %.02f\n" % overall_acc)
+            logging.info("Per Answer Type Accuracy is the following:")
+
+            for ans_type in vqa_scorer.accuracy["perAnswerType"]:
+                logging.info(
+                    "%s : %.02f"
+                    % (ans_type, vqa_scorer.accuracy["perAnswerType"][ans_type])
+                )
+                metrics[ans_type] = vqa_scorer.accuracy["perAnswerType"][ans_type]
+
+            with open(
+                os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
+            ) as f:
+                f.write(json.dumps(metrics) + "\n")
+
+        return metrics
+
+
+@registry.register_task("gqa")
+class GQATask(VQATask):
+    def valid_step(self, model, samples):
+        answers = model.predict_answers(
+            samples=samples,
+            answer_list=self.answer_list,
+            inference_method=self.inference_method,
+            num_beams=self.num_beams,
+            max_len=self.max_len,
+            min_len=self.min_len,
+            num_ans_candidates=self.num_ans_candidates,
+            prompt=self.prompt,
+        )
+        pred_qa_pairs = []
+
+        question_id = samples["question_id"]
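+        # GQA provides a single ground-truth answer string per question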
+        gt_answers = samples["answer"]
+
+        for answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
+            ques_id = int(ques_id.item())
+            pred_qa_pairs.append({"question_id": ques_id, "pred_ans": answer, "gt_ans": gt_answer})
+
+        return pred_qa_pairs
+
+    @dist_utils.main_process
+    def _report_metrics(self, result_file, split):
+        """
+        TODO: add other evaluation metrics for GQA
+        """
+
+        with open(result_file, "r") as f:
+            results = json.load(f)
+        acc = []
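+        # VQAEval is instantiated without annotations; only its answer-normalization
+        # helpers (processPunctuation / processDigitArticle) are used here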
+        vqa_tool = VQAEval()
+
+        for res in results:
+            if res["gt_ans"] is None:
+                # prepare test results for leaderboard evaluation
+                self._save_result_leaderboard(results)
+                return
+
+            gt_ans = res["gt_ans"]
+            pred = res["pred_ans"]
+
+            # normalize punctuation, digits and articles before exact-match comparison
+            pred = vqa_tool.processPunctuation(pred)
+            pred = vqa_tool.processDigitArticle(pred)
+
+            vqa_acc = 1 if pred == gt_ans else 0
+
+            acc.append(vqa_acc)
+
+        accuracy = sum(acc) / len(acc) * 100
+        metrics = {"agg_metrics": accuracy, "acc": accuracy}
+
+        with open(
+            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
+        ) as f:
+            f.write(json.dumps(metrics) + "\n")
+
+        logging.info(metrics)
+
+        return metrics
+
+
+@registry.register_task("aok_vqa")
+class AOKVQATask(VQATask):
+    def valid_step(self, model, samples):
+        answers = model.predict_answers(
+            samples=samples,
+            answer_list=self.answer_list,
+            inference_method=self.inference_method,
+            num_beams=self.num_beams,
+            max_len=self.max_len,
+            min_len=self.min_len,
+            num_ans_candidates=self.num_ans_candidates,
+        )
+
+        pred_qa_pairs = []
+
+        question_id = samples["question_id"]
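+        # A-OKVQA provides a list of annotator-provided direct answers per question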
+        gt_answers = samples["direct_answers"]
+
+        for pred_answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
+            pred_qa_pairs.append(
+                {"question_id": ques_id, "pred_ans": pred_answer, "gt_ans": gt_answer}
+            )
+
+        return pred_qa_pairs
+
+    @dist_utils.main_process
+    def _report_metrics(self, result_file, split):
+        """
+        Implementing accuracy computation for AOKVQA, see
+        https://github.com/allenai/aokvqa/blob/main/evaluation/eval_predictions.py#L45 for details.
+        """
+        # TODO add evaluation for multi-choice
+
+        with open(result_file, "r") as f:
+            results = json.load(f)
+        acc = []
+
+        for res in results:
+            if res["gt_ans"] is None:
+                # prepare test results for leaderboard evaluation
+                self._save_result_leaderboard(results)
+                return
+
+            pred = res["pred_ans"]
+            gt_ans = res["gt_ans"]
+
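+            # VQA-style soft accuracy: full credit once the prediction matches
+            # at least three of the annotator answers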
+            num_match = sum([pred == gt for gt in gt_ans])
+            vqa_acc = min(1.0, num_match / 3.0)
+
+            acc.append(vqa_acc)
+
+        accuracy = sum(acc) / len(acc) * 100
+        metrics = {"agg_metrics": accuracy, "acc": accuracy}
+
+        with open(
+            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
+        ) as f:
+            f.write(json.dumps(metrics) + "\n")
+
+        logging.info(metrics)
+
+        return metrics
+
+    @dist_utils.main_process
+    def _save_result_leaderboard(self, results):
+        """
+        Saving the results in the format required for leaderboard evaluation.
+
+        [TODO] add support for multi-choice.
+        """
+        result_leaderboard = dict()
+        for res in results:
+            result_leaderboard[res["question_id"]] = {
+                "direct_answer": res["pred_ans"],
+                "multiple_choice": "",
+            }
+
+        result_file = registry.get_path("result_dir") + "_leaderboard.json"
+
+        with open(result_file, "w") as f:
+            json.dump(result_leaderboard, f)
+
+        logging.info(f"Saved results for leaderboard evaluation at {result_file}")