--- /dev/null
+++ b/model/lavis/runners/runner_iter.py
@@ -0,0 +1,292 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import logging
+import os
+import time
+
+import torch
+import torch.distributed as dist
+from model.lavis.common.dist_utils import download_cached_file, is_main_process, main_process
+from model.lavis.common.registry import registry
+from model.lavis.common.utils import is_url
+from model.lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split
+from model.lavis.runners.runner_base import RunnerBase
+from torch.utils.data.dataset import ChainDataset
+
+
+@registry.register_runner("runner_iter")
+class RunnerIter(RunnerBase):
+    """
+    Run training based on the number of iterations. This is common when
+    the training dataset size is large. Underhood logic is similar to
+    epoch-based training by considering every #iters_per_inner_epoch as an
+    inner epoch.
+
+    In iter-based runner, after every #iters_per_inner_epoch steps, we
+
+        1) do a validation epoch;
+        2) schedule the learning rate;
+        3) save the checkpoint.
+
+    We refer to every #iters_per_inner_epoch steps as an inner epoch.
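+
+    Example (hypothetical values): with max_iters=60000 and
+    iters_per_inner_epoch=5000, training runs 12 inner epochs of 5000 steps
+    each, validating and checkpointing after each one.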
+    """
+
+    def __init__(self, cfg, task, model, datasets, job_id):
+        super().__init__(cfg, task, model, datasets, job_id)
+
+        self.start_iters = 0
+
+        self.max_iters = int(self.config.run_cfg.get("max_iters", -1))
+        assert self.max_iters > 0, "max_iters must be greater than 0."
+
+        self.iters_per_inner_epoch = int(
+            self.config.run_cfg.get("iters_per_inner_epoch", -1)
+        )
+        assert (
+            self.iters_per_inner_epoch > 0
+        ), "iters_per_inner_epoch must be greater than 0."
+
+    @property
+    def max_epoch(self):
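+        # the number of inner epochs implied by max_iters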
+        return int(self.max_iters / self.iters_per_inner_epoch)
+
+    @property
+    def cur_epoch(self):
+        try:
+            return self.train_loader.epoch
+        except AttributeError:
+            # pipeline data (e.g., LAION) is streamed and has no concept of an epoch
+            return 0
+
+    def _progress(self, cur_iters):
+        return "{}_iters={}".format(self.cur_epoch, cur_iters)
+
+    def train(self):
+        start_time = time.time()
+        best_agg_metric = 0
+        best_iters = 0
+
+        self.log_config()
+
+        # resume from checkpoint if specified
+        if not self.evaluate_only and self.resume_ckpt_path is not None:
+            self._load_checkpoint(self.resume_ckpt_path)
+
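+        # each pass over this range is one inner epoch: train for
+        # iters_per_inner_epoch steps, then validate, step the LR schedule,
+        # and save a checkpoint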
+        for start_iters in range(
+            self.start_iters, self.max_iters, self.iters_per_inner_epoch
+        ):
+            end_iters = start_iters + self.iters_per_inner_epoch
+
+            # training phase
+            if not self.evaluate_only:
+                logging.info(
+                    "Start training, max_iters={}, in total {} inner epochs.".format(
+                        self.max_iters, self.max_epoch
+                    )
+                )
+
+                train_stats = self.train_iters(self.cur_epoch, start_iters)
+                self.log_stats(split_name="train", stats=train_stats)
+
+            # evaluation phase
+            if len(self.valid_splits) > 0:
+                for split_name in self.valid_splits:
+                    logging.info("Evaluating on {}.".format(split_name))
+
+                    val_log = self.eval_epoch(
+                        split_name=split_name, cur_epoch=self._progress(end_iters)
+                    )
+                    if val_log is not None:
+                        if is_main_process():
+                            assert (
+                                "agg_metrics" in val_log
+                            ), "No agg_metrics found in validation log."
+
+                            agg_metrics = val_log["agg_metrics"]
+                            if agg_metrics > best_agg_metric and split_name == "val":
+                                best_iters, best_agg_metric = end_iters, agg_metrics
+
+                                self._save_checkpoint(end_iters, is_best=True)
+
+                            val_log.update({"best_iters": best_iters})
+                            self.log_stats(val_log, split_name)
+
+            else:
+                # if no validation split is provided, we just save the checkpoint at the end of each inner epoch.
+                if not self.evaluate_only:
+                    self._save_checkpoint(end_iters, is_best=False)
+
+            if self.evaluate_only:
+                break
+            dist.barrier()
+
+        # testing phase
+        self.evaluate(cur_epoch=self.cur_epoch)
+
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        logging.info("Training time {}".format(total_time_str))
+
+    def train_iters(self, epoch, start_iters):
+        # train by iterations
+        self.model.train()
+
+        return self.task.train_iters(
+            epoch=epoch,
+            start_iters=start_iters,
+            iters_per_inner_epoch=self.iters_per_inner_epoch,
+            model=self.model,
+            data_loader=self.train_loader,
+            optimizer=self.optimizer,
+            scaler=self.scaler,
+            lr_scheduler=self.lr_scheduler,
+            cuda_enabled=self.cuda_enabled,
+            log_freq=self.log_freq,
+            accum_grad_iters=self.accum_grad_iters,
+        )
+
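+    # checkpoint saving runs on the main (rank-0) process only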
+    @main_process
+    def _save_checkpoint(self, cur_iters, is_best=False):
+        save_obj = {
+            "model": self.unwrap_dist_model(self.model).state_dict(),
+            "optimizer": self.optimizer.state_dict(),
+            "config": self.config.to_dict(),
+            "scaler": self.scaler.state_dict() if self.scaler else None,
+            "iters": cur_iters,
+        }
+        save_to = os.path.join(
+            self.output_dir,
+            "checkpoint_{}.pth".format("best" if is_best else cur_iters),
+        )
+        logging.info("Saving checkpoint at iters {} to {}.".format(cur_iters, save_to))
+        torch.save(save_obj, save_to)
+
+    def _load_checkpoint(self, url_or_filename):
+        """
+        Resume from a checkpoint.
+        """
+        if is_url(url_or_filename):
+            cached_file = download_cached_file(
+                url_or_filename, check_hash=False, progress=True
+            )
+            checkpoint = torch.load(cached_file, map_location=self.device)
+        elif os.path.isfile(url_or_filename):
+            checkpoint = torch.load(url_or_filename, map_location=self.device)
+        else:
+            raise RuntimeError("checkpoint url or path is invalid")
+
+        state_dict = checkpoint["model"]
+        self.unwrap_dist_model(self.model).load_state_dict(state_dict)
+
+        self.optimizer.load_state_dict(checkpoint["optimizer"])
+        if self.scaler and "scaler" in checkpoint:
+            self.scaler.load_state_dict(checkpoint["scaler"])
+
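+        # resume from the step after the saved iteration so the
+        # checkpointed step is not repeated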
+        self.start_iters = checkpoint["iters"] + 1
+        logging.info("Resume checkpoint from {}".format(url_or_filename))
+
+    @property
+    def dataloaders(self) -> dict:
+        """
+        A property that lazily creates dataloaders for each split on first access.
+
+        If no train_dataset_ratios is provided, concatenate map-style datasets
+        and chain wds.DataPipeline datasets separately. The training set becomes
+        a tuple (ConcatDataset, ChainDataset); both are optional, but at least
+        one is required. The resulting ConcatDataset and ChainDataset are
+        sampled evenly.
+
+        If train_dataset_ratios is provided, create a MultiIterLoader that
+        samples each dataset according to its ratio during training.
+
+        Multiple datasets for validation and test are currently not supported.
+
+        Returns:
+            dict: {split_name: (tuples of) dataloader}
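+
+        Example (hypothetical config values): with two training sources and
+        run_cfg.train_dataset_ratios set to sampling weights such as [2, 8],
+        MultiIterLoader draws batches from them at a 1:4 rate instead of
+        evenly.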
+        """
+        if self._dataloaders is None:
+            # reorganize datasets by split and concatenate/chain if necessary
+            dataset_ratios = self.config.run_cfg.get("train_dataset_ratios", None)
+
+            if dataset_ratios is None:
+                # concatenate map-style datasets and chain wds.DataPipe datasets separately
+                # training set becomes a tuple (ConcatDataset, ChainDataset), both are
+                # optional but at least one of them is required. The resultant ConcatDataset
+                # and ChainDataset will be sampled evenly.
+                logging.info(
+                    "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
+                )
+
+                datasets = reorg_datasets_by_split(self.datasets)
+                self.datasets = concat_datasets(datasets)
+
+            # print dataset statistics after concatenation/chaining
+            for split_name in self.datasets:
+                if isinstance(self.datasets[split_name], (tuple, list)):
+                    # mixed wds.DataPipeline and torch.utils.data.Dataset
+                    num_records = sum(
+                        0 if isinstance(d, ChainDataset) else len(d)
+                        for d in self.datasets[split_name]
+                    )
+
+                else:
+                    try:
+                        # a single map-style dataset
+                        num_records = len(self.datasets[split_name])
+                    except TypeError:
+                        # a single wds.DataPipeline or ChainDataset
+                        num_records = -1
+                        logging.info(
+                            "Only a single wds.DataPipeline dataset, no __len__ attribute."
+                        )
+
+                if num_records >= 0:
+                    logging.info(
+                        "Loaded {} records for {} split from the dataset.".format(
+                            num_records, split_name
+                        )
+                    )
+
+            # create dataloaders
+            split_names = sorted(self.datasets.keys())
+
+            datasets = [self.datasets[split] for split in split_names]
+            is_trains = [split in self.train_splits for split in split_names]
+
+            batch_sizes = [
+                self.config.run_cfg.batch_size_train
+                if split == "train"
+                else self.config.run_cfg.batch_size_eval
+                for split in split_names
+            ]
+
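+            # use each dataset's own collater when it defines one; otherwise
+            # fall back to the loader's default collation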
+            collate_fns = []
+            for dataset in datasets:
+                if isinstance(dataset, (tuple, list)):
+                    collate_fns.append([getattr(d, "collater", None) for d in dataset])
+                else:
+                    collate_fns.append(getattr(dataset, "collater", None))
+
+            dataloaders = self.create_loaders(
+                datasets=datasets,
+                num_workers=self.config.run_cfg.num_workers,
+                batch_sizes=batch_sizes,
+                is_trains=is_trains,
+                collate_fns=collate_fns,
+                dataset_ratios=dataset_ratios,
+            )
+
+            self._dataloaders = dict(zip(split_names, dataloaders))
+
+        return self._dataloaders