--- a
+++ b/model/lavis/runners/runner_iter.py
@@ -0,0 +1,292 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import logging
+import os
+import time
+
+import torch
+import torch.distributed as dist
+from model.lavis.common.dist_utils import download_cached_file, is_main_process, main_process
+from model.lavis.common.registry import registry
+from model.lavis.common.utils import is_url
+from model.lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split
+from model.lavis.runners.runner_base import RunnerBase
+from torch.utils.data.dataset import ChainDataset
+
+
+@registry.register_runner("runner_iter")
+class RunnerIter(RunnerBase):
+    """
+    Run training based on the number of iterations. This is common when
+    the training dataset size is large. Under the hood, the logic is similar
+    to epoch-based training: every #iters_per_inner_epoch steps are treated
+    as one inner epoch.
+
+    In the iter-based runner, after every #iters_per_inner_epoch steps, we
+
+    1) do a validation epoch;
+    2) schedule the learning rate;
+    3) save the checkpoint.
+
+    We refer to every #iters_per_inner_epoch steps as an inner epoch.
+    """
+
+    def __init__(self, cfg, task, model, datasets, job_id):
+        super().__init__(cfg, task, model, datasets, job_id)
+
+        self.start_iters = 0
+
+        self.max_iters = int(self.config.run_cfg.get("max_iters", -1))
+        assert self.max_iters > 0, "max_iters must be greater than 0."
+
+        self.iters_per_inner_epoch = int(
+            self.config.run_cfg.get("iters_per_inner_epoch", -1)
+        )
+        assert (
+            self.iters_per_inner_epoch > 0
+        ), "iters_per_inner_epoch must be greater than 0."
+
+    @property
+    def max_epoch(self):
+        return int(self.max_iters / self.iters_per_inner_epoch)
+
+    @property
+    def cur_epoch(self):
+        try:
+            return self.train_loader.epoch
+        except AttributeError:
+            # pipeline data (e.g. LAION) is streaming and has no concept of an epoch
+            return 0
+
+    def _progress(self, cur_iters):
+        return "{}_iters={}".format(self.cur_epoch, cur_iters)
+
+    def train(self):
+        start_time = time.time()
+        best_agg_metric = 0
+        best_iters = 0
+
+        self.log_config()
+
+        # resume from checkpoint if specified
+        if not self.evaluate_only and self.resume_ckpt_path is not None:
+            self._load_checkpoint(self.resume_ckpt_path)
+
+        for start_iters in range(
+            self.start_iters, self.max_iters, self.iters_per_inner_epoch
+        ):
+            end_iters = start_iters + self.iters_per_inner_epoch
+
+            # training phase
+            if not self.evaluate_only:
+                logging.info(
+                    "Start training, max_iters={}, in total {} inner epochs.".format(
+                        self.max_iters, int(self.max_iters / self.iters_per_inner_epoch)
+                    )
+                )
+
+                train_stats = self.train_iters(self.cur_epoch, start_iters)
+                self.log_stats(split_name="train", stats=train_stats)
+
+            # evaluation phase
+            if len(self.valid_splits) > 0:
+                for split_name in self.valid_splits:
+                    logging.info("Evaluating on {}.".format(split_name))
+
+                    val_log = self.eval_epoch(
+                        split_name=split_name, cur_epoch=self._progress(end_iters)
+                    )
+                    if val_log is not None:
+                        if is_main_process():
+                            assert (
+                                "agg_metrics" in val_log
+                            ), "No agg_metrics found in validation log."
+
+                            agg_metrics = val_log["agg_metrics"]
+                            if agg_metrics > best_agg_metric and split_name == "val":
+                                best_iters, best_agg_metric = end_iters, agg_metrics
+
+                                self._save_checkpoint(end_iters, is_best=True)
+
+                            val_log.update({"best_iters": best_iters})
+                            self.log_stats(val_log, split_name)
+
+            else:
+                # if no validation split is provided, we just save the checkpoint at the end of each inner epoch.
+                if not self.evaluate_only:
+                    self._save_checkpoint(end_iters, is_best=False)
+
+            if self.evaluate_only:
+                break
+            dist.barrier()
+
+        # testing phase
+        self.evaluate(cur_epoch=self.cur_epoch)
+
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        logging.info("Training time {}".format(total_time_str))
+
+    def train_iters(self, epoch, start_iters):
+        # train by iterations
+        self.model.train()
+
+        return self.task.train_iters(
+            epoch=epoch,
+            start_iters=start_iters,
+            iters_per_inner_epoch=self.iters_per_inner_epoch,
+            model=self.model,
+            data_loader=self.train_loader,
+            optimizer=self.optimizer,
+            scaler=self.scaler,
+            lr_scheduler=self.lr_scheduler,
+            cuda_enabled=self.cuda_enabled,
+            log_freq=self.log_freq,
+            accum_grad_iters=self.accum_grad_iters,
+        )
+
+    @main_process
+    def _save_checkpoint(self, cur_iters, is_best=False):
+        save_obj = {
+            "model": self.unwrap_dist_model(self.model).state_dict(),
+            "optimizer": self.optimizer.state_dict(),
+            "config": self.config.to_dict(),
+            "scaler": self.scaler.state_dict() if self.scaler else None,
+            "iters": cur_iters,
+        }
+        save_to = os.path.join(
+            self.output_dir,
+            "checkpoint_{}.pth".format("best" if is_best else cur_iters),
+        )
+        logging.info("Saving checkpoint at iters {} to {}.".format(cur_iters, save_to))
+        torch.save(save_obj, save_to)
+
+    def _load_checkpoint(self, url_or_filename):
+        """
+        Resume from a checkpoint.
+        """
+        if is_url(url_or_filename):
+            cached_file = download_cached_file(
+                url_or_filename, check_hash=False, progress=True
+            )
+            checkpoint = torch.load(cached_file, map_location=self.device)
+        elif os.path.isfile(url_or_filename):
+            checkpoint = torch.load(url_or_filename, map_location=self.device)
+        else:
+            raise RuntimeError("checkpoint url or path is invalid")
+
+        state_dict = checkpoint["model"]
+        self.unwrap_dist_model(self.model).load_state_dict(state_dict)
+
+        self.optimizer.load_state_dict(checkpoint["optimizer"])
+        if self.scaler and "scaler" in checkpoint:
+            self.scaler.load_state_dict(checkpoint["scaler"])
+
+        self.start_iters = checkpoint["iters"] + 1
+        logging.info("Resume checkpoint from {}".format(url_or_filename))
+
+    @property
+    def dataloaders(self) -> dict:
+        """
+        A property to get and create dataloaders by split, on demand.
+
+        If no train_dataset_ratios is provided, concatenate map-style datasets and
+        chain wds.DataPipe datasets separately. The training set becomes a tuple
+        (ConcatDataset, ChainDataset); both are optional, but at least one of them
+        is required. The resultant ConcatDataset and ChainDataset will be sampled
+        evenly.
+
+        If train_dataset_ratios is provided, create a MultiIterLoader to sample
+        each dataset by ratios during training.
+
+        Multiple datasets for validation and test are currently not supported.
+
+        Returns:
+            dict: {split_name: (tuples of) dataloader}
+        """
+        if self._dataloaders is None:
+            # reorganize datasets by split and concatenate/chain if necessary
+            dataset_ratios = self.config.run_cfg.get("train_dataset_ratios", None)
+
+            if dataset_ratios is None:
+                # concatenate map-style datasets and chain wds.DataPipe datasets separately
+                # training set becomes a tuple (ConcatDataset, ChainDataset), both are
+                # optional but at least one of them is required. The resultant ConcatDataset
+                # and ChainDataset will be sampled evenly.
+                logging.info(
+                    "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
+                )
+
+                datasets = reorg_datasets_by_split(self.datasets)
+                self.datasets = concat_datasets(datasets)
+
+            # print dataset statistics after concatenation/chaining
+            for split_name in self.datasets:
+                if isinstance(self.datasets[split_name], tuple) or isinstance(
+                    self.datasets[split_name], list
+                ):
+                    # mixed wds.DataPipeline and torch.utils.data.Dataset
+                    num_records = sum(
+                        [
+                            len(d)
+                            if not type(d) in [ChainDataset]
+                            else 0
+                            for d in self.datasets[split_name]
+                        ]
+                    )
+
+                else:
+                    try:
+                        # a single map-style dataset
+                        num_records = len(self.datasets[split_name])
+                    except TypeError:
+                        # a single wds.DataPipeline or ChainDataset
+                        num_records = -1
+                        logging.info(
+                            "Only a single wds.DataPipeline dataset, no __len__ attribute."
+                        )
+
+                if num_records >= 0:
+                    logging.info(
+                        "Loaded {} records for {} split from the dataset.".format(
+                            num_records, split_name
+                        )
+                    )
+
+            # create dataloaders
+            split_names = sorted(self.datasets.keys())
+
+            datasets = [self.datasets[split] for split in split_names]
+            is_trains = [split in self.train_splits for split in split_names]
+
+            batch_sizes = [
+                self.config.run_cfg.batch_size_train
+                if split == "train"
+                else self.config.run_cfg.batch_size_eval
+                for split in split_names
+            ]
+
+            collate_fns = []
+            for dataset in datasets:
+                if isinstance(dataset, tuple) or isinstance(dataset, list):
+                    collate_fns.append([getattr(d, "collater", None) for d in dataset])
+                else:
+                    collate_fns.append(getattr(dataset, "collater", None))
+
+            dataloaders = self.create_loaders(
+                datasets=datasets,
+                num_workers=self.config.run_cfg.num_workers,
+                batch_sizes=batch_sizes,
+                is_trains=is_trains,
+                collate_fns=collate_fns,
+                dataset_ratios=dataset_ratios,
+            )
+
+            self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}
+
+        return self._dataloaders
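
Note (not part of the patch): a minimal standalone sketch of the inner-epoch bookkeeping that RunnerIter performs, with hypothetical values for max_iters and iters_per_inner_epoch; in the runner these come from run_cfg and the per-epoch work is delegated to task.train_iters, eval_epoch, and _save_checkpoint.

# Sketch only: mirrors how RunnerIter slices max_iters into inner epochs.
# The values below are hypothetical examples, not defaults from any config.
max_iters = 10_000
iters_per_inner_epoch = 2_000

max_epoch = int(max_iters / iters_per_inner_epoch)  # 5 inner epochs

for start_iters in range(0, max_iters, iters_per_inner_epoch):
    end_iters = start_iters + iters_per_inner_epoch
    # after each inner epoch the runner validates, schedules the learning
    # rate, and saves a checkpoint tagged with end_iters
    print("inner epoch {}: iters [{}, {})".format(
        start_iters // iters_per_inner_epoch, start_iters, end_iters))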