--- a
+++ b/model/lavis/runners/runner_base.py
@@ -0,0 +1,745 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import json
+import logging
+import os
+import time
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+import wandb
+from model.lavis.common.dist_utils import (
+    download_cached_file,
+    get_rank,
+    get_world_size,
+    is_main_process,
+    main_process,
+)
+from model.lavis.common.registry import registry
+from model.lavis.common.utils import is_url
+from model.lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split
+from model.lavis.datasets.datasets.dataloader_utils import (
+    IterLoader,
+    MultiIterLoader,
+    PrefetchLoader,
+)
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data.dataset import ChainDataset
+
+
+@registry.register_runner("runner_base")
+class RunnerBase:
+    """
+    A runner class to train and evaluate a model given a task and datasets.
+
+    The runner uses pytorch distributed data parallel by default. Future releases
+    will support other distributed frameworks.
+    """
+
+    def __init__(self, cfg, task, model, datasets, job_id):
+        self.config = cfg
+        self.job_id = job_id
+        self.run_name = self.config.run_cfg.run_name
+
+        self.task = task
+        self.datasets = datasets
+        if 'coco_caption' in self.datasets:
+            self.gt_map = {}
+            for split_name in ["train", "val", "test"]:
+                self.gt_map[split_name] = {int(elem['image'].split(".")[0].split("_")[-1]): "/".join(elem['caption']) if split_name != "train" else elem['caption'] for elem in self.datasets['coco_caption'][split_name].annotation}
+
+        self._model = model
+
+        self._wrapped_model = model
+        self._device = None
+        self._optimizer = None
+        self._scaler = None
+        self._dataloaders = None
+        self._lr_sched = None
+
+        self.start_epoch = 0
+
+        # self.setup_seeds()
+        self.setup_output_dir()
+
+    def generate_html_table(self, data, columns):
+        html = '<table border="1" cellpadding="5" cellspacing="0">'
+        html += "<tr>"
+        for col in columns:
+            html += f"<th>{col}</th>"
+        html += "</tr>"
+
+        for row in data:
+            html += "<tr>"
+            for cell in row:
+                html += f"<td>{cell}</td>"
+            html += "</tr>"
+        html += "</table>"
+
+        return html
+
+    @property
+    def device(self):
+        if self._device is None:
+            self._device = torch.device(self.config.run_cfg.device)
+
+        return self._device
+
+    @property
+    def use_distributed(self):
+        return self.config.run_cfg.distributed
+
+    @property
+    def model(self):
+        """
+        A property to get the DDP-wrapped model on the device.
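+        The model is first moved to the configured device; when run_cfg.distributed
+        is set, the DistributedDataParallel wrapper is returned, otherwise the plain model.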
+        """
+        # move model to device
+        if self._model.device != self.device:
+            self._model = self._model.to(self.device)
+
+        # distributed training wrapper
+        if self.use_distributed:
+            if self._wrapped_model is None:
+                self._wrapped_model = DDP(
+                    self._model, device_ids=[self.config.run_cfg.gpu]
+                )
+        else:
+            self._wrapped_model = self._model
+
+        return self._wrapped_model
+
+    @property
+    def optimizer(self):
+        # TODO make optimizer class and configurations
+        if self._optimizer is None:
+            num_parameters = 0
+            p_wd, p_non_wd = [], []
+            for n, p in self.model.named_parameters():
+                if not p.requires_grad:
+                    continue  # frozen weights
+                if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
+                    p_non_wd.append(p)
+                else:
+                    p_wd.append(p)
+                num_parameters += p.data.nelement()
+            logging.info("number of trainable parameters: %d" % num_parameters)
+            optim_params = [
+                {
+                    "params": p_wd,
+                    "weight_decay": float(self.config.run_cfg.weight_decay),
+                },
+                {"params": p_non_wd, "weight_decay": 0},
+            ]
+            beta2 = self.config.run_cfg.get("beta2", 0.999)
+            self._optimizer = torch.optim.AdamW(
+                optim_params,
+                lr=float(self.config.run_cfg.init_lr),
+                weight_decay=float(self.config.run_cfg.weight_decay),
+                betas=(0.9, beta2),
+            )
+
+        return self._optimizer
+
+    @property
+    def scaler(self):
+        amp = self.config.run_cfg.get("amp", False)
+
+        if amp:
+            if self._scaler is None:
+                self._scaler = torch.cuda.amp.GradScaler()
+
+        return self._scaler
+
+    @property
+    def lr_scheduler(self):
+        """
+        A property that lazily creates the learning rate scheduler when first needed.
+        """
+        if self._lr_sched is None:
+            lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched)
+
+            # max_epoch = self.config.run_cfg.max_epoch
+            max_epoch = self.max_epoch
+            # min_lr = self.config.run_cfg.min_lr
+            min_lr = self.min_lr
+            # init_lr = self.config.run_cfg.init_lr
+            init_lr = self.init_lr
+
+            # optional parameters
+            decay_rate = self.config.run_cfg.get("lr_decay_rate", 1.)
+            warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1)
+            warmup_steps = self.config.run_cfg.get("warmup_steps", 0)
+
+            self._lr_sched = lr_sched_cls(
+                optimizer=self.optimizer,
+                max_epoch=max_epoch,
+                min_lr=min_lr,
+                init_lr=init_lr,
+                decay_rate=decay_rate,
+                warmup_start_lr=warmup_start_lr,
+                warmup_steps=warmup_steps,
+            )
+
+        return self._lr_sched
+
+    @property
+    def dataloaders(self) -> dict:
+        """
+        A property that lazily creates dataloaders by split when first needed.
+
+        If no train_dataset_ratio is provided, concatenate map-style datasets and
+        chain wds.DataPipe datasets separately. Training set becomes a tuple
+        (ConcatDataset, ChainDataset), both are optional but at least one of them is
+        required. The resultant ConcatDataset and ChainDataset will be sampled evenly.
+
+        If train_dataset_ratio is provided, create a MultiIterLoader to sample
+        each dataset by ratios during training.
+
+        Multiple datasets are currently not supported for validation and test.
+
+        Returns:
+            dict: {split_name: (tuples of) dataloader}
+        """
+        if self._dataloaders is None:
+            # reorganize datasets by split and concatenate/chain if necessary
+            dataset_ratios = self.config.run_cfg.get("train_dataset_ratios", None)
+
+            # concatenate map-style datasets and chain wds.DataPipe datasets separately
+            # training set becomes a tuple (ConcatDataset, ChainDataset), both are
+            # optional but at least one of them is required. The resultant ConcatDataset
+            # and ChainDataset will be sampled evenly.
+            logging.info(
+                "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
+            )
+
+            datasets = reorg_datasets_by_split(self.datasets)
+            #self.datasets = concat_datasets(datasets)
+            # select first dataset for validation and test
+            self.datasets = {k: v[0] for k, v in datasets.items()}
+
+            # print dataset statistics after concatenation/chaining
+            for split_name in self.datasets:
+                if isinstance(self.datasets[split_name], tuple) or isinstance(
+                    self.datasets[split_name], list
+                ):
+                    # mixed wds.DataPipeline and torch.utils.data.Dataset
+                    num_records = sum(
+                        [
+                            len(d)
+                            if not type(d) in [ChainDataset]
+                            else 0
+                            for d in self.datasets[split_name]
+                        ]
+                    )
+
+                else:
+                    if hasattr(self.datasets[split_name], "__len__"):
+                        # a single map-style dataset
+                        num_records = len(self.datasets[split_name])
+                    else:
+                        # a single wds.DataPipeline
+                        num_records = -1
+                        logging.info(
+                            "Only a single wds.DataPipeline dataset, no __len__ attribute."
+                        )
+
+                if num_records >= 0:
+                    logging.info(
+                        "Loaded {} records for {} split from the dataset.".format(
+                            num_records, split_name
+                        )
+                    )
+
+            # create dataloaders
+            split_names = sorted(self.datasets.keys())
+
+            datasets = [self.datasets[split] for split in split_names]
+            is_trains = [split in self.train_splits for split in split_names]
+
+            batch_sizes = [
+                self.config.run_cfg.batch_size_train
+                if split == "train"
+                else self.config.run_cfg.batch_size_eval
+                for split in split_names
+            ]
+
+            collate_fns = []
+            for dataset in datasets:
+                if isinstance(dataset, tuple) or isinstance(dataset, list):
+                    collate_fns.append([getattr(d, "collater", None) for d in dataset])
+                else:
+                    collate_fns.append(getattr(dataset, "collater", None))
+
+            dataloaders = self.create_loaders(
+                datasets=datasets,
+                num_workers=self.config.run_cfg.num_workers,
+                batch_sizes=batch_sizes,
+                is_trains=is_trains,
+                collate_fns=collate_fns,
+                dataset_ratios=dataset_ratios,
+            )
+
+            self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}
+
+        return self._dataloaders
+
+    @property
+    def cuda_enabled(self):
+        return self.device.type == "cuda"
+
+    @property
+    def max_epoch(self):
+        return int(self.config.run_cfg.max_epoch)
+
+    @property
+    def log_freq(self):
+        log_freq = self.config.run_cfg.get("log_freq", 50)
+        return int(log_freq)
+
+    @property
+    def init_lr(self):
+        return float(self.config.run_cfg.init_lr)
+
+    @property
+    def min_lr(self):
+        return float(self.config.run_cfg.min_lr)
+
+    @property
+    def accum_grad_iters(self):
+        return int(self.config.run_cfg.get("accum_grad_iters", 1))
+
+    @property
+    def valid_splits(self):
+        valid_splits = self.config.run_cfg.get("valid_splits", [])
+
+        if len(valid_splits) == 0:
+            logging.info("No validation splits found.")
+
+        return valid_splits
+
+    @property
+    def test_splits(self):
+        test_splits = self.config.run_cfg.get("test_splits", [])
+
+        return test_splits
+
+    @property
+    def train_splits(self):
+        train_splits = self.config.run_cfg.get("train_splits", [])
+
+        if len(train_splits) == 0:
+            logging.info("Empty train splits.")
+
+        return train_splits
+
+    @property
+    def evaluate_only(self):
+        """
+        Set to True to skip training.
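+        Read from run_cfg.evaluate.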
+        """
+        return self.config.run_cfg.evaluate
+
+    @property
+    def use_dist_eval_sampler(self):
+        return self.config.run_cfg.get("use_dist_eval_sampler", True)
+
+    @property
+    def resume_ckpt_path(self):
+        return self.config.run_cfg.get("resume_ckpt_path", None)
+
+    @property
+    def train_loader(self):
+        train_dataloader = self.dataloaders["train"]
+
+        return train_dataloader
+
+    def setup_output_dir(self):
+        #lib_root = Path(registry.get_path("library_root"))
+
+        output_dir = Path("pretraining/outputs") / self.run_name
+        if os.path.exists(output_dir) and not self.evaluate_only:  # for evaluation we want to reuse the same output dir
+            output_dir = Path("pretraining/outputs") / "{}_{}".format(
+                self.run_name, datetime.datetime.now().strftime("%m%d_%H%M%S")
+            )
+        elif self.evaluate_only:
+            # replace "_eval" with "" in the output dir name
+            output_dir = Path("pretraining/outputs") / self.run_name.replace("_eval", "")
+        result_dir = output_dir / "result"
+
+        output_dir.mkdir(parents=True, exist_ok=True)
+        result_dir.mkdir(parents=True, exist_ok=True)
+
+        registry.register_path("result_dir", str(result_dir))
+        registry.register_path("output_dir", str(output_dir))
+
+        self.result_dir = result_dir
+        self.output_dir = output_dir
+
+    def validate(self, cur_epoch, best_agg_metric, best_epoch, wandb_run):
+        if len(self.valid_splits) > 0:
+            for split_name in self.valid_splits:
+                logging.info("Evaluating on {}.".format(split_name))
+
+                val_log, results, gts, loss = self.eval_epoch(
+                    split_name=split_name, cur_epoch=cur_epoch
+                )
+                if self.evaluate_only:
+                    # save free-text predictions and corresponding gt
+                    # create folder
+                    if not os.path.exists(self.output_dir / "predictions"):
+                        os.makedirs(self.output_dir / "predictions")
+                        os.makedirs(self.output_dir / "ground_truths")
+
+                    # save predictions
+                    with open(self.output_dir / "predictions" / "predictions_{}.txt".format(split_name), "w") as f:
+                        for i in range(len(results)):
+                            f.write('"' + results[i]['caption'] + '"\n')
+                    # save ground truths
+                    with open(self.output_dir / "ground_truths" / "ground_truths_{}.txt".format(split_name), "w") as f:
+                        for i in [elem['image_id'] for elem in results]:
+                            f.write('"' + gts[i][0] + '"\n')
+
+                if val_log is not None:
+                    if is_main_process():
+                        assert (
+                            "agg_metrics" in val_log
+                        ), "No agg_metrics found in validation log."
+
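+                        # track the best checkpoint by agg_metrics on the "val" split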
+                        agg_metrics = val_log["agg_metrics"]
+                        if not self.evaluate_only and split_name == "val":  # don't save for train_val
+                            if agg_metrics > best_agg_metric:
+                                best_epoch, best_agg_metric = cur_epoch, agg_metrics
+
+                                self._save_checkpoint(cur_epoch, is_best=True)
+                                logging.info("Saving best model at epoch {}".format(cur_epoch))
+
+                            else:
+                                self._save_checkpoint(cur_epoch, is_best=False, is_last=True)
+
+                        val_log.update({"best_epoch": best_epoch})
+                        self.log_stats(val_log, split_name)
+
+                        data = []
+                        for i in range(min(len(results), 5)):
+                            data.append([str(cur_epoch), results[i]["caption"], gts[results[i]["image_id"]]])
+                        html = '<table border="1" cellpadding="5" cellspacing="0">'
+                        html += "<tr>"
+                        for col in ["Epoch", "Predicted", "GT"]:
+                            html += f"<th>{col}</th>"
+                        html += "</tr>"
+
+                        for row in data:
+                            html += "<tr>"
+                            for cell in row:
+                                html += f"<td>{cell}</td>"
+                            html += "</tr>"
+                        html += "</table>"
+
+                        try:
+                            wandb_run.log({f"text_predictions_{split_name}": wandb.Html(html)})
+                        except Exception as e:
+                            pass
+
+                if loss is not None:
+                    self.log_stats({"loss": loss}, split_name, commit=False)
+                    if not self.evaluate_only:
+                        if loss < best_agg_metric and split_name == "val":
+                            best_epoch, best_agg_metric = cur_epoch, loss
+
+                            self._save_checkpoint(cur_epoch, is_best=True)
+
+                        else:
+                            self._save_checkpoint(cur_epoch, is_best=False)
+
+        else:
+            # if no validation split is provided, we just save the checkpoint at the end of each epoch.
+            if not self.evaluate_only:
+                self._save_checkpoint(cur_epoch, is_best=False)
+
+        return best_agg_metric, best_epoch
+
+    def train(self, wandb_run):
+        start_time = time.time()
+        best_agg_metric = 0
+        best_epoch = 0
+
+        self.log_config()
+
+        # resume from checkpoint if specified
+        if not self.evaluate_only and self.resume_ckpt_path is not None:
+            self._load_checkpoint(self.resume_ckpt_path)
+
+        for cur_epoch in range(self.start_epoch, self.max_epoch):
+            custom_epochs = self.datasets['mimic_cxr']['train'].custom_epochs_per_epoch if 'mimic_cxr' in self.datasets else self.datasets['train'].custom_epochs_per_epoch
+            for custom_epoch in range(custom_epochs):
+                if 'mimic_cxr' in self.datasets:  # before first epoch
+                    self.datasets['mimic_cxr']['train'].set_custom_epoch(custom_epoch)
+                else:
+                    self.datasets['train'].set_custom_epoch(custom_epoch)  # resorted after creating dataloader
+                # training phase
+                if not self.evaluate_only:
+                    logging.info("Start training")
+                    train_stats = self.train_epoch(cur_epoch)
+                    self.log_stats(split_name="train", stats=train_stats)
+
+                # evaluation phase
+                best_agg_metric, best_epoch = self.validate(cur_epoch, best_agg_metric, best_epoch, wandb_run)
+
+                if self.evaluate_only:
+                    break
+
+            #dist.barrier()
+
+        # testing phase
+        test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch
+        self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)
+
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        logging.info("Training time {}".format(total_time_str))
+
+    def evaluate(self, cur_epoch="best", skip_reload=False):
+        test_logs = dict()
+
+        if len(self.test_splits) > 0:
+            for split_name in self.test_splits:
+                test_logs[split_name] = self.eval_epoch(
+                    split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload
+                )
+
+            return test_logs
+
+    def train_epoch(self, epoch):
+        # train
+        self.model.train()
+
+        return self.task.train_epoch(
+            epoch=epoch,
+            model=self.model,
+            data_loader=self.train_loader,
+            optimizer=self.optimizer,
+            scaler=self.scaler,
+            lr_scheduler=self.lr_scheduler,
+            cuda_enabled=self.cuda_enabled,
+            log_freq=self.log_freq,
+            accum_grad_iters=self.accum_grad_iters,
+        )
+
+    @torch.no_grad()
+    def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
+        """
+        Evaluate the model on a given split.
+
+        Args:
+            split_name (str): name of the split to evaluate on.
+            cur_epoch (int): current epoch.
+            skip_reload (bool): whether to skip reloading the best checkpoint.
+                During training, we will reload the best checkpoint for validation.
+                During testing, we will use provided weights and skip reloading the best checkpoint.
+        """
+        data_loader = self.dataloaders.get(split_name, None)
+        assert data_loader, "data_loader for split {} is None.".format(split_name)
+
+        model = self.unwrap_dist_model(self.model)
+        if (not skip_reload and cur_epoch == "best") or self.evaluate_only:
+            model = self._reload_best_model(model)
+        model.eval()
+
+        self.task.before_evaluation(
+            model=model,
+            dataset=self.datasets[split_name],
+        )
+
+        results = self.task.evaluation(model, data_loader, cuda_enabled=self.cuda_enabled)
+
+        if results is not None and self.config.run_cfg.task != "image_text_pretrain_eval":
+            metrics, gts = data_loader.dataset.evaluator.evaluate(results)
+            return metrics, results, gts, None
+
+        else:
+            return None, None, None, results  # for pre-training, instead of predicting text we compute the val loss
+
+
+    def unwrap_dist_model(self, model):
+        if self.use_distributed:
+            return model.module
+        else:
+            return model
+
+    def create_loaders(
+        self,
+        datasets,
+        num_workers,
+        batch_sizes,
+        is_trains,
+        collate_fns,
+        dataset_ratios=None,
+    ):
+        """
+        Create dataloaders for training and validation.
+        """
+
+        def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
+            # create a single dataloader for each split
+            if isinstance(dataset, ChainDataset):
+                # wds.WebDataset instances are chained together
+                # webdataset.DataPipeline has its own sampler and collate_fn
+                loader = iter(
+                    DataLoader(
+                        dataset,
+                        batch_size=bsz,
+                        num_workers=num_workers,
+                        pin_memory=True,
+                    )
+                )
+            else:
+                # map-style datasets are concatenated together
+                # setup distributed sampler
+                if self.use_distributed:
+                    sampler = DistributedSampler(
+                        dataset,
+                        shuffle=is_train,
+                        num_replicas=get_world_size(),
+                        rank=get_rank(),
+                    )
+                    if not self.use_dist_eval_sampler:
+                        # e.g. retrieval evaluation
+                        sampler = sampler if is_train else None
+                else:
+                    sampler = None
+
+                loader = DataLoader(
+                    dataset,
+                    batch_size=bsz,
+                    num_workers=num_workers,
+                    pin_memory=True,
+                    sampler=sampler,
+                    shuffle=sampler is None and is_train,
+                    collate_fn=collate_fn,
+                    drop_last=True if is_train else False,
+                )
+                #loader = PrefetchLoader(loader)
+
+                if is_train:
+                    loader = IterLoader(loader, use_distributed=self.use_distributed)
+
+            return loader
+
+        loaders = []
+
+        for dataset, bsz, is_train, collate_fn in zip(
+            datasets, batch_sizes, is_trains, collate_fns
+        ):
+            if isinstance(dataset, list) or isinstance(dataset, tuple):
+                loader = MultiIterLoader(
+                    loaders=[
+                        _create_loader(d, num_workers, bsz, is_train, collate_fn[i])
+                        for i, d in enumerate(dataset)
+                    ],
+                    ratios=dataset_ratios,
+                )
+            else:
+                loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)
+
+            loaders.append(loader)
+
+        return loaders
+
+    @main_process
+    def _save_checkpoint(self, cur_epoch, is_best=False, is_last=False):
+        """
+        Save the checkpoint at the current epoch.
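+        Parameters that do not require gradients (frozen weights) are dropped from
+        the saved state dict; optimizer, scaler and config state are stored alongside the model.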
+        """
+        model_no_ddp = self.unwrap_dist_model(self.model)
+        param_grad_dic = {
+            k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
+        }
+        state_dict = model_no_ddp.state_dict()
+        for k in list(state_dict.keys()):
+            if k in param_grad_dic.keys() and not param_grad_dic[k]:
+                # delete parameters that do not require gradient
+                del state_dict[k]
+        save_obj = {
+            "model": state_dict,
+            "optimizer": self.optimizer.state_dict(),
+            "config": self.config.to_dict(),
+            "scaler": self.scaler.state_dict() if self.scaler else None,
+            "epoch": cur_epoch,
+        }
+        save_to = os.path.join(
+            self.output_dir,
+            "checkpoint_{}.pth".format("best" if is_best else ("last" if is_last else cur_epoch)),
+        )
+        logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
+        torch.save(save_obj, save_to)
+
+    def _reload_best_model(self, model):
+        """
+        Load the best checkpoint for evaluation.
+        """
+        checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")
+
+        logging.info("Loading checkpoint from {}.".format(checkpoint_path))
+        checkpoint = torch.load(checkpoint_path, map_location="cpu")
+        try:
+            model.load_state_dict(checkpoint["model"])
+        except RuntimeError as e:
+            logging.warning(
+                """
+                Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
+                Trying to load the model with strict=False.
+                """
+            )
+            model.load_state_dict(checkpoint["model"], strict=False)
+        return model
+
+    def _load_checkpoint(self, url_or_filename):
+        """
+        Resume from a checkpoint.
+        """
+        if is_url(url_or_filename):
+            cached_file = download_cached_file(
+                url_or_filename, check_hash=False, progress=True
+            )
+            checkpoint = torch.load(cached_file, map_location=self.device)
+        elif os.path.isfile(url_or_filename):
+            checkpoint = torch.load(url_or_filename, map_location=self.device)
+        else:
+            raise RuntimeError("checkpoint url or path is invalid")
+
+        state_dict = checkpoint["model"]
+        self.unwrap_dist_model(self.model).load_state_dict(state_dict, strict=False)  # opt_model does not need to be loaded (frozen)
+
+        self.optimizer.load_state_dict(checkpoint["optimizer"])
+        if self.scaler and "scaler" in checkpoint:
+            self.scaler.load_state_dict(checkpoint["scaler"])
+
+        self.start_epoch = checkpoint["epoch"] + 1
+        logging.info("Resume checkpoint from {}".format(url_or_filename))
+
+    @main_process
+    def log_stats(self, stats, split_name, commit=True):
+        if isinstance(stats, dict):
+            log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
+            with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
+                f.write(json.dumps(log_stats) + "\n")
+            if wandb is not None:
+                wandb_stats = {k: float(v) for k, v in log_stats.items()}
+                wandb.log(wandb_stats, commit=commit)
+
+        elif isinstance(stats, list):
+            pass
+
+    @main_process
+    def log_config(self):
+        with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
+            f.write(json.dumps(self.config.to_dict(), indent=4) + "\n")
\ No newline at end of file
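For reference, a minimal sketch of how this runner is typically driven from a training entry point. The cfg, task, model, datasets and job_id objects below are assumed to be built elsewhere (illustrative placeholders, not part of this patch); only registry.get_runner_class and the RunnerBase constructor / train() signature come from the code above.

    import wandb
    from model.lavis.common.registry import registry

    # cfg, task, model, datasets and job_id are produced by the entry point (placeholders here)
    runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base"))
    runner = runner_cls(cfg=cfg, task=task, model=model, datasets=datasets, job_id=job_id)

    wandb_run = wandb.init(name=cfg.run_cfg.run_name)  # train() logs prediction tables to this run
    runner.train(wandb_run)  # per epoch: train_epoch -> validate -> checkpointing, then evaluate() on the test splits
    # with run_cfg.evaluate=True the same call skips training and only runs validation/testing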