Diff of /lavis/tasks/captioning.py [000000] .. [dc40d0]

"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
import os

from lavis.common.dist_utils import main_process
from lavis.common.registry import registry
from lavis.tasks.base_task import BaseTask


@registry.register_task("captioning")
class CaptionTask(BaseTask):
    def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):
        super().__init__()

        self.num_beams = num_beams
        self.max_len = max_len
        self.min_len = min_len
        self.evaluate = evaluate

        self.report_metric = report_metric

    @classmethod
    def setup_task(cls, cfg):
        run_cfg = cfg.run_cfg

        num_beams = run_cfg.num_beams
        max_len = run_cfg.max_len
        min_len = run_cfg.min_len
        evaluate = run_cfg.evaluate

        report_metric = run_cfg.get("report_metric", True)

        return cls(
            num_beams=num_beams,
            max_len=max_len,
            min_len=min_len,
            evaluate=evaluate,
            report_metric=report_metric,
        )

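    # setup_task() reads its generation settings from the run section of the
    # training/eval config. A minimal sketch of the expected fields, with
    # illustrative values that are not taken from this diff:
    #
    #   run:
    #     task: captioning
    #     num_beams: 3
    #     max_len: 30
    #     min_len: 8
    #     evaluate: False
    #     report_metric: True
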
    def valid_step(self, model, samples):
        results = []

        captions = model.generate(
            samples,
            use_nucleus_sampling=False,
            num_beams=self.num_beams,
            max_length=self.max_len,
            min_length=self.min_len,
        )

        img_ids = samples["image_id"]
        for caption, img_id in zip(captions, img_ids):
            results.append({"caption": caption, "image_id": int(img_id)})

        return results

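    # Each valid_step() call thus yields records of the form
    # {"caption": "<generated text>", "image_id": <int>}; save_result() in
    # after_evaluation() writes them to disk and deduplicates on "image_id"
    # (via remove_duplicate) before scoring.
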
    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
        eval_result_file = self.save_result(
            result=val_result,
            result_dir=registry.get_path("result_dir"),
            filename="{}_epoch{}".format(split_name, epoch),
            remove_duplicate="image_id",
        )

        if self.report_metric:
            metrics = self._report_metrics(
                eval_result_file=eval_result_file, split_name=split_name
            )
        else:
            metrics = {"agg_metrics": 0.0}

        return metrics

    @main_process
    def _report_metrics(self, eval_result_file, split_name):

        # TODO better way to define this
        coco_gt_root = os.path.join(registry.get_path("cache_root"), "coco_gt")
        coco_val = coco_caption_eval(coco_gt_root, eval_result_file, split_name)

        agg_metrics = coco_val.eval["CIDEr"] + coco_val.eval["Bleu_4"]
        log_stats = {split_name: {k: v for k, v in coco_val.eval.items()}}

        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(log_stats) + "\n")

        coco_res = {k: v for k, v in coco_val.eval.items()}
        coco_res["agg_metrics"] = agg_metrics

        return coco_res


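# _report_metrics() returns the per-metric COCO scores plus a single scalar,
# "agg_metrics" = CIDEr + Bleu_4; when report_metric is disabled,
# after_evaluation() falls back to agg_metrics = 0.0. Each evaluation also
# appends one JSON line (keyed by split name) to evaluate.txt in output_dir.
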
# TODO better structure for this.
from pycocoevalcap.eval import COCOEvalCap
from pycocotools.coco import COCO
from torchvision.datasets.utils import download_url


def coco_caption_eval(coco_gt_root, results_file, split):
    urls = {
        "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json",
        "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json",
    }
    filenames = {
        "val": "coco_karpathy_val_gt.json",
        "test": "coco_karpathy_test_gt.json",
    }

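    # Only "val" and "test" ground-truth files are provided, so any other
    # split value raises a KeyError below. download_url() is expected to skip
    # the download when the annotation file already exists in coco_gt_root.
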
    download_url(urls[split], coco_gt_root)
    annotation_file = os.path.join(coco_gt_root, filenames[split])

    # create coco object and coco_result object
    coco = COCO(annotation_file)
    coco_result = coco.loadRes(results_file)

    # create coco_eval object by taking coco and coco_result
    coco_eval = COCOEvalCap(coco, coco_result)

    # to evaluate only on the subset of images that have results, uncomment:
    #   coco_eval.params['image_id'] = coco_result.getImgIds()
    # leave it commented out when evaluating the full validation set

    # evaluate results
    # SPICE will take a few minutes the first time, but speeds up due to caching
    coco_eval.evaluate()

    # print output evaluation scores
    for metric, score in coco_eval.eval.items():
        print(f"{metric}: {score:.3f}")

    return coco_eval
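

# Minimal standalone sketch (illustrative, not part of the LAVIS runner):
# score a caption result file against the Karpathy "val" ground truth.
# The paths are assumptions for the example; during training/evaluation they
# are derived from the registry ("cache_root", "result_dir") as in CaptionTask.
if __name__ == "__main__":
    scores = coco_caption_eval(
        coco_gt_root="cache/coco_gt",           # assumed cache location
        results_file="result/val_epoch0.json",  # assumed output of save_result()
        split="val",
    )
    print(scores.eval["CIDEr"], scores.eval["Bleu_4"])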