Diff of /lavis/tasks/vqa.py [000000] .. [dc40d0]

"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import json
import os

import lavis.common.dist_utils as dist_utils
from lavis.common.registry import registry
from lavis.common.vqa_tools.vqa import VQA
from lavis.common.vqa_tools.vqa_eval import VQAEval
from lavis.tasks.base_task import BaseTask


@registry.register_task("vqa")
class VQATask(BaseTask):
    def __init__(
        self,
        num_beams,
        max_len,
        min_len,
        evaluate,
        num_ans_candidates,
        inference_method="rank",
        prompt="",
    ):
        super().__init__()

        self.num_beams = num_beams
        self.max_len = max_len
        self.min_len = min_len

        self.evaluate = evaluate
        self.inference_method = inference_method
        self.num_ans_candidates = num_ans_candidates
        self.prompt = prompt

        self.answer_list = None

        self.ques_files = dict()
        self.anno_files = dict()

    @classmethod
    def setup_task(cls, cfg):
        run_cfg = cfg.run_cfg

        num_beams = run_cfg.get("num_beams", 3)
        max_len = run_cfg.get("max_len", 10)
        min_len = run_cfg.get("min_len", 1)

        evaluate = run_cfg.get("evaluate", False)

        inference_method = run_cfg.get("inference_method", "rank")
        num_ans_candidates = run_cfg.get("num_ans_candidates", 128)
        prompt = run_cfg.get("prompt", "")

        return cls(
            num_beams=num_beams,
            max_len=max_len,
            min_len=min_len,
            evaluate=evaluate,
            num_ans_candidates=num_ans_candidates,
            inference_method=inference_method,
            prompt=prompt,
        )

    def build_datasets(self, cfg):
        datasets = super().build_datasets(cfg)

        # get question file, annotation file and answer list in COCO format
        for dataset in datasets.values():
            for split in dataset:
                if (
                    hasattr(dataset[split], "coco_fmt_qust_file")
                    and dataset[split].coco_fmt_qust_file is not None
                ):
                    self.ques_files[split] = dataset[split].coco_fmt_qust_file
                    self.anno_files[split] = dataset[split].coco_fmt_anno_file

                try:
                    self.answer_list = dataset[split].answer_list
                except AttributeError:
                    # if answer_list is not provided, then set it to None
                    pass

        if len(self.ques_files) > 0:
            assert len(self.ques_files) == len(
                self.anno_files
            ), "Only support one split for evaluation."

        return datasets

    def valid_step(self, model, samples):
        answers = model.predict_answers(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
            prompt=self.prompt,
        )
        pred_qa_pairs = []

        # "text_output" holds the ground-truth answer string for each question
        gt_answers = samples["text_output"]
        for ques, answer, gt_answer in zip(samples["text_input"], answers, gt_answers):
            pred_qa_pairs.append(
                {"question": ques, "question_ans": gt_answer, "predict_ans": answer}
            )
            print("Question: ", ques)
            print("Predict_ans: ", answer)
            print("Question_ans: ", gt_answer)
            print("####################")

        return pred_qa_pairs

    def after_evaluation(self, val_result, split_name, **kwargs):
        result_file = self.save_result(
            val_result,
            result_dir=registry.get_path("result_dir"),
            filename=f"{split_name}_vqa_result",
            # valid_step stores results keyed by question text rather than
            # question_id, so deduplicate on "question"
            remove_duplicate="question",
        )

        metrics = self._report_metrics(result_file=result_file, split=split_name)

        return metrics

    @dist_utils.main_process
    def _report_metrics(self, result_file, split):
        """
        Use official VQA evaluation script to report metrics.
        """
        metrics = {}

        if split in self.ques_files and split in self.anno_files:
            vqa = VQA(self.anno_files[split], self.ques_files[split])
            vqa_result = vqa.loadRes(
                resFile=result_file, quesFile=self.ques_files[split]
            )

            # create vqaEval object by taking vqa and vqaRes
            # n is precision of accuracy (number of places after decimal), default is 2
            vqa_scorer = VQAEval(vqa, vqa_result, n=2)
            logging.info("Start VQA evaluation.")
            vqa_scorer.evaluate()

            # print accuracies
            overall_acc = vqa_scorer.accuracy["overall"]
            metrics["agg_metrics"] = overall_acc

            logging.info("Overall Accuracy is: %.02f\n" % overall_acc)
            logging.info("Per Answer Type Accuracy is the following:")

            for ans_type in vqa_scorer.accuracy["perAnswerType"]:
                logging.info(
                    "%s : %.02f"
                    % (ans_type, vqa_scorer.accuracy["perAnswerType"][ans_type])
                )
                metrics[ans_type] = vqa_scorer.accuracy["perAnswerType"][ans_type]

            with open(
                os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
            ) as f:
                f.write(json.dumps(metrics) + "\n")

        return metrics

@registry.register_task("gqa")
class GQATask(VQATask):
    def valid_step(self, model, samples):
        answers = model.predict_answers(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
            prompt=self.prompt,
        )
        pred_qa_pairs = []

        question_id = samples["question_id"]
        gt_answers = samples["answer"]

        for answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
            ques_id = int(ques_id.item())
            pred_qa_pairs.append(
                {"question_id": ques_id, "pred_ans": answer, "gt_ans": gt_answer}
            )

        return pred_qa_pairs

    @dist_utils.main_process
    def _report_metrics(self, result_file, split):
        """
        TODO: add other evaluation metrics for GQA
        """

        with open(result_file, "r") as f:
            results = json.load(f)
        acc = []
        vqa_tool = VQAEval()

        for res in results:
            if res["gt_ans"] is None:
                # prepare test results for leaderboard evaluation
                self._save_result_leaderboard(results)
                return

            gt_ans = res["gt_ans"]
            pred = res["pred_ans"]

            # if self.inference_method == "generate":
            pred = vqa_tool.processPunctuation(pred)
            pred = vqa_tool.processDigitArticle(pred)

            vqa_acc = 1 if pred == gt_ans else 0

            acc.append(vqa_acc)

        accuracy = sum(acc) / len(acc) * 100
        metrics = {"agg_metrics": accuracy, "acc": accuracy}

        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(metrics) + "\n")

        logging.info(metrics)

        return metrics


@registry.register_task("aok_vqa")
class AOKVQATask(VQATask):
    def valid_step(self, model, samples):
        answers = model.predict_answers(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
        )

        pred_qa_pairs = []

        question_id = samples["question_id"]
        gt_answers = samples["direct_answers"]

        for pred_answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
            pred_qa_pairs.append(
                {"question_id": ques_id, "pred_ans": pred_answer, "gt_ans": gt_answer}
            )

        return pred_qa_pairs

    @dist_utils.main_process
    def _report_metrics(self, result_file, split):
        """
        Implementing accuracy computation for AOKVQA, see
        https://github.com/allenai/aokvqa/blob/main/evaluation/eval_predictions.py#L45 for details.
        """
        # TODO add evaluation for multi-choice

        with open(result_file, "r") as f:
            results = json.load(f)
        acc = []

        for res in results:
            if res["gt_ans"] is None:
                # prepare test results for leaderboard evaluation
                self._save_result_leaderboard(results)
                return

            pred = res["pred_ans"]
            gt_ans = res["gt_ans"]

            num_match = sum([pred == gt for gt in gt_ans])
            vqa_acc = min(1.0, num_match / 3.0)

            acc.append(vqa_acc)

        accuracy = sum(acc) / len(acc) * 100
        metrics = {"agg_metrics": accuracy, "acc": accuracy}

        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(metrics) + "\n")

        logging.info(metrics)

        return metrics

    @dist_utils.main_process
    def _save_result_leaderboard(self, results):
        """
        Saving the results in the format required for leaderboard evaluation.

        [TODO] add support for multi-choice.
        """
        result_leaderboard = dict()
        for res in results:
            result_leaderboard[res["question_id"]] = {
                "direct_answer": res["pred_ans"],
                "multiple_choice": "",
            }

        result_file = registry.get_path("result_dir") + "_leaderboard.json"

        with open(result_file, "w") as f:
            json.dump(result_leaderboard, f)

        logging.info(f"Saved results for leaderboard evaluation at {result_file}")