--- a
+++ b/lavis/tasks/captioning.py
@@ -0,0 +1,142 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import json
+import os
+
+from lavis.common.dist_utils import main_process
+from lavis.common.registry import registry
+from lavis.tasks.base_task import BaseTask
+
+
+@registry.register_task("captioning")
+class CaptionTask(BaseTask):
+    """Caption-generation task scored with COCO caption metrics.
+
+    Captions are produced by ``model.generate`` without nucleus sampling,
+    using ``num_beams``/``max_len``/``min_len``; ``report_metric`` decides
+    whether ``after_evaluation`` scores the saved results or returns the
+    placeholder ``{"agg_metrics": 0.0}``.
+    """
+
+    def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):
+        super().__init__()
+        # Decoding settings forwarded to ``model.generate`` in ``valid_step``.
+        self.num_beams, self.max_len, self.min_len = num_beams, max_len, min_len
+        self.evaluate = evaluate
+        self.report_metric = report_metric
+
+    @classmethod
+    def setup_task(cls, cfg):
+        """Build the task from ``cfg.run_cfg`` (``report_metric`` defaults to True)."""
+        run_cfg = cfg.run_cfg
+        return cls(
+            num_beams=run_cfg.num_beams,
+            max_len=run_cfg.max_len,
+            min_len=run_cfg.min_len,
+            evaluate=run_cfg.evaluate,
+            report_metric=run_cfg.get("report_metric", True),
+        )
+
+    def valid_step(self, model, samples):
+        """Generate captions for one batch of ``samples``.
+
+        Returns a list of ``{"caption": ..., "image_id": int}`` dicts, pairing
+        each generated caption positionally with ``samples["image_id"]``.
+        """
+        generated = model.generate(
+            samples,
+            use_nucleus_sampling=False,
+            num_beams=self.num_beams,
+            max_length=self.max_len,
+            min_length=self.min_len,
+        )
+        # zip truncates to the shorter of captions / image ids.
+        return [
+            {"caption": text, "image_id": int(image_id)}
+            for text, image_id in zip(generated, samples["image_id"])
+        ]
+
+    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
+        """Save ``val_result`` for ``split_name``/``epoch`` and return metrics.
+
+        The results go through ``save_result`` (``remove_duplicate="image_id"``)
+        into the registry's ``result_dir``; they are only scored when
+        ``report_metric`` is set, otherwise ``{"agg_metrics": 0.0}`` is returned.
+        """
+        result_file = self.save_result(
+            result=val_result,
+            result_dir=registry.get_path("result_dir"),
+            filename=f"{split_name}_epoch{epoch}",
+            remove_duplicate="image_id",
+        )
+        # Reporting disabled: skip scoring and return a constant placeholder.
+        if not self.report_metric:
+            return {"agg_metrics": 0.0}
+        return self._report_metrics(eval_result_file=result_file, split_name=split_name)
+
+    @main_process
+    def _report_metrics(self, eval_result_file, split_name):
+        """Score ``eval_result_file`` with ``coco_caption_eval`` and log the scores.
+
+        Appends ``{split_name: scores}`` as one JSON line to ``evaluate.txt`` in
+        ``output_dir`` and returns the scores plus ``agg_metrics``, defined as
+        ``CIDEr + Bleu_4``. NOTE(review): ``@main_process`` presumably limits this
+        to the main rank; confirm against ``lavis.common.dist_utils``.
+        """
+        # TODO: the ground-truth location is hard-coded under cache_root.
+        gt_root = os.path.join(registry.get_path("cache_root"), "coco_gt")
+        scores = coco_caption_eval(gt_root, eval_result_file, split_name).eval
+        agg_metrics = scores["CIDEr"] + scores["Bleu_4"]
+
+        log_path = os.path.join(registry.get_path("output_dir"), "evaluate.txt")
+        with open(log_path, "a") as log_file:
+            log_file.write(json.dumps({split_name: dict(scores)}) + "\n")
+
+        return {**scores, "agg_metrics": agg_metrics}
+
+
+# TODO: restructure -- these imports sit mid-module, after the task class that uses them.
+from pycocoevalcap.eval import COCOEvalCap
+from pycocotools.coco import COCO
+from torchvision.datasets.utils import download_url
+
+
+def coco_caption_eval(coco_gt_root, results_file, split):
+    """Score a caption result file against the COCO Karpathy ground truth.
+
+    Args:
+        coco_gt_root: directory the ground-truth JSON is downloaded into.
+        results_file: caption results file, loaded with ``COCO.loadRes``.
+        split: Karpathy split to score against, ``"val"`` or ``"test"``.
+
+    Returns:
+        The ``COCOEvalCap`` object after ``evaluate()``; scores are in ``.eval``.
+
+    Raises:
+        ValueError: if ``split`` is neither ``"val"`` nor ``"test"``.
+    """
+    supported = ("val", "test")
+    if split not in supported:
+        raise ValueError(f"Unsupported split {split!r}; expected one of {supported}.")
+    # download_url saves under the URL's basename, so deriving the URL and the
+    # local path from one filename keeps the two from drifting apart.
+    filename = f"coco_karpathy_{split}_gt.json"
+    base_url = "https://storage.googleapis.com/sfr-vision-language-research/datasets"
+    download_url(f"{base_url}/{filename}", coco_gt_root)
+    annotation_file = os.path.join(coco_gt_root, filename)
+
+    coco = COCO(annotation_file)
+    coco_result = coco.loadRes(results_file)
+    coco_eval = COCOEvalCap(coco, coco_result)
+    # SPICE takes a few minutes the first time, but speeds up due to caching.
+    coco_eval.evaluate()
+
+    for metric, score in coco_eval.eval.items():
+        print(f"{metric}: {score:.3f}")
+
+    return coco_eval