"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import datetime
import json
import logging
import os
import time
from pathlib import Path
import torch
import torch.distributed as dist
import wandb
from model.lavis.common.dist_utils import (
download_cached_file,
get_rank,
get_world_size,
is_main_process,
main_process,
)
from model.lavis.common.registry import registry
from model.lavis.common.utils import is_url
from model.lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split
from model.lavis.datasets.datasets.dataloader_utils import (
IterLoader,
MultiIterLoader,
PrefetchLoader,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data.dataset import ChainDataset
@registry.register_runner("runner_base")
class RunnerBase:
"""
A runner class to train and evaluate a model given a task and datasets.
The runner uses PyTorch DistributedDataParallel by default. Future releases
will support other distributed frameworks.
"""
def __init__(self, cfg, task, model, datasets, job_id):
self.config = cfg
self.job_id = job_id
self.run_name = self.config.run_cfg.run_name
self.task = task
self.datasets = datasets
if 'coco_caption' in self.datasets:
# map COCO image id -> ground-truth caption(s): kept as-is for train, "/"-joined for val/test
self.gt_map = {}
for split_name in ["train", "val", "test"]:
self.gt_map[split_name] = {
int(elem['image'].split(".")[0].split("_")[-1]):
elem['caption'] if split_name == "train" else "/".join(elem['caption'])
for elem in self.datasets['coco_caption'][split_name].annotation
}
self._model = model
self._wrapped_model = model
self._device = None
self._optimizer = None
self._scaler = None
self._dataloaders = None
self._lr_sched = None
self.start_epoch = 0
# self.setup_seeds()
self.setup_output_dir()
def generate_html_table(self, data, columns):
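"""
Render rows as a minimal HTML table string (used below for wandb text logging).
Illustrative example:
    generate_html_table([["0", "predicted text", "gt text"]], ["Epoch", "Predicted", "GT"])
    -> '<table ...><tr><th>Epoch</th><th>Predicted</th><th>GT</th></tr><tr><td>0</td>...</tr></table>'
"""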
html = '<table border="1" cellpadding="5" cellspacing="0">'
html += "<tr>"
for col in columns:
html += f"<th>{col}</th>"
html += "</tr>"
for row in data:
html += "<tr>"
for cell in row:
html += f"<td>{cell}</td>"
html += "</tr>"
html += "</table>"
return html
@property
def device(self):
if self._device is None:
self._device = torch.device(self.config.run_cfg.device)
return self._device
@property
def use_distributed(self):
return self.config.run_cfg.distributed
@property
def model(self):
"""
A property that moves the model to the configured device and wraps it in
DistributedDataParallel when distributed training is enabled.
"""
# move model to device
if self._model.device != self.device:
self._model = self._model.to(self.device)
# distributed training wrapper
if self.use_distributed:
if self._wrapped_model is None:
self._wrapped_model = DDP(
self._model, device_ids=[self.config.run_cfg.gpu]
)
else:
self._wrapped_model = self._model
return self._wrapped_model
@property
def optimizer(self):
# TODO make optimizer class and configurations
if self._optimizer is None:
num_parameters = 0
p_wd, p_non_wd = [], []
for n, p in self.model.named_parameters():
if not p.requires_grad:
continue # frozen weights
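# 1-D tensors and bias/LayerNorm/BatchNorm parameters are exempt from weight decay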
if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
p_non_wd.append(p)
else:
p_wd.append(p)
num_parameters += p.data.nelement()
logging.info("number of trainable parameters: %d" % num_parameters)
optim_params = [
{
"params": p_wd,
"weight_decay": float(self.config.run_cfg.weight_decay),
},
{"params": p_non_wd, "weight_decay": 0},
]
beta2 = self.config.run_cfg.get("beta2", 0.999)
self._optimizer = torch.optim.AdamW(
optim_params,
lr=float(self.config.run_cfg.init_lr),
weight_decay=float(self.config.run_cfg.weight_decay),
betas=(0.9, beta2),
)
return self._optimizer
@property
def scaler(self):
amp = self.config.run_cfg.get("amp", False)
if amp:
if self._scaler is None:
self._scaler = torch.cuda.amp.GradScaler()
return self._scaler
@property
def lr_scheduler(self):
"""
A property that lazily creates the learning rate scheduler from the registry on first access.
"""
if self._lr_sched is None:
lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched)
# max_epoch = self.config.run_cfg.max_epoch
max_epoch = self.max_epoch
# min_lr = self.config.run_cfg.min_lr
min_lr = self.min_lr
# init_lr = self.config.run_cfg.init_lr
init_lr = self.init_lr
# optional parameters
decay_rate = self.config.run_cfg.get("lr_decay_rate", 1.)
warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1)
warmup_steps = self.config.run_cfg.get("warmup_steps", 0)
self._lr_sched = lr_sched_cls(
optimizer=self.optimizer,
max_epoch=max_epoch,
min_lr=min_lr,
init_lr=init_lr,
decay_rate=decay_rate,
warmup_start_lr=warmup_start_lr,
warmup_steps=warmup_steps,
)
return self._lr_sched
@property
def dataloaders(self) -> dict:
"""
A property that lazily creates dataloaders per split on first access.
If no train_dataset_ratio is provided, map-style datasets are concatenated and
wds.DataPipeline datasets are chained separately. The training set then becomes a
tuple (ConcatDataset, ChainDataset); both are optional, but at least one is
required, and the two are sampled evenly.
If train_dataset_ratio is provided, a MultiIterLoader is created to sample
each dataset by the given ratios during training.
Multiple datasets are currently not supported for validation and test.
Returns:
dict: {split_name: (tuples of) dataloader}
"""
if self._dataloaders is None:
# reorganize datasets by split and concatenate/chain if necessary
dataset_ratios = self.config.run_cfg.get("train_dataset_ratios", None)
# concatenate map-style datasets and chain wds.DataPipe datasets separately
# training set becomes a tuple (ConcatDataset, ChainDataset), both are
# optional but at least one of them is required. The resultant ConcatDataset
# and ChainDataset will be sampled evenly.
if dataset_ratios is None:
logging.info(
"dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
)
datasets = reorg_datasets_by_split(self.datasets)
#self.datasets = concat_datasets(datasets)
# select first dataset for validation and test
self.datasets = {k: v[0] for k, v in datasets.items()}
# print dataset statistics after concatenation/chaining
for split_name in self.datasets:
if isinstance(self.datasets[split_name], (tuple, list)):
# mixed wds.DataPipeline and torch.utils.data.Dataset
num_records = sum(
len(d) if not isinstance(d, ChainDataset) else 0
for d in self.datasets[split_name]
)
else:
if hasattr(self.datasets[split_name], "__len__"):
# a single map-style dataset
num_records = len(self.datasets[split_name])
else:
# a single wds.DataPipeline
num_records = -1
logging.info(
"Only a single wds.DataPipeline dataset, no __len__ attribute."
)
if num_records >= 0:
logging.info(
"Loaded {} records for {} split from the dataset.".format(
num_records, split_name
)
)
# create dataloaders
split_names = sorted(self.datasets.keys())
datasets = [self.datasets[split] for split in split_names]
is_trains = [split in self.train_splits for split in split_names]
batch_sizes = [
self.config.run_cfg.batch_size_train
if split == "train"
else self.config.run_cfg.batch_size_eval
for split in split_names
]
collate_fns = []
for dataset in datasets:
if isinstance(dataset, (tuple, list)):
collate_fns.append([getattr(d, "collater", None) for d in dataset])
else:
collate_fns.append(getattr(dataset, "collater", None))
dataloaders = self.create_loaders(
datasets=datasets,
num_workers=self.config.run_cfg.num_workers,
batch_sizes=batch_sizes,
is_trains=is_trains,
collate_fns=collate_fns,
dataset_ratios=dataset_ratios,
)
self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}
return self._dataloaders
@property
def cuda_enabled(self):
return self.device.type == "cuda"
@property
def max_epoch(self):
return int(self.config.run_cfg.max_epoch)
@property
def log_freq(self):
log_freq = self.config.run_cfg.get("log_freq", 50)
return int(log_freq)
@property
def init_lr(self):
return float(self.config.run_cfg.init_lr)
@property
def min_lr(self):
return float(self.config.run_cfg.min_lr)
@property
def accum_grad_iters(self):
return int(self.config.run_cfg.get("accum_grad_iters", 1))
@property
def valid_splits(self):
valid_splits = self.config.run_cfg.get("valid_splits", [])
if len(valid_splits) == 0:
logging.info("No validation splits found.")
return valid_splits
@property
def test_splits(self):
test_splits = self.config.run_cfg.get("test_splits", [])
return test_splits
@property
def train_splits(self):
train_splits = self.config.run_cfg.get("train_splits", [])
if len(train_splits) == 0:
logging.info("Empty train splits.")
return train_splits
@property
def evaluate_only(self):
"""
Set to True to skip training.
"""
return self.config.run_cfg.evaluate
@property
def use_dist_eval_sampler(self):
return self.config.run_cfg.get("use_dist_eval_sampler", True)
@property
def resume_ckpt_path(self):
return self.config.run_cfg.get("resume_ckpt_path", None)
@property
def train_loader(self):
train_dataloader = self.dataloaders["train"]
return train_dataloader
def setup_output_dir(self):
#lib_root = Path(registry.get_path("library_root"))
output_dir = Path("pretraining/outputs") / self.run_name
if os.path.exists(output_dir) and not self.evaluate_only:  # for evaluation we want to reuse the same output dir
output_dir = Path("pretraining/outputs") / "{}_{}".format(
self.run_name, datetime.datetime.now().strftime("%m%d_%H%M%S")
)
elif self.evaluate_only:
# replace "_eval" with "" in the output dir name
output_dir = Path("pretraining/outputs") / self.run_name.replace("_eval", "")
result_dir = output_dir / "result"
output_dir.mkdir(parents=True, exist_ok=True)
result_dir.mkdir(parents=True, exist_ok=True)
registry.register_path("result_dir", str(result_dir))
registry.register_path("output_dir", str(output_dir))
self.result_dir = result_dir
self.output_dir = output_dir
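# Resulting layout for a run named <run_name> (sketch):
#   pretraining/outputs/<run_name>/          <- output_dir: checkpoints and log.txt
#   pretraining/outputs/<run_name>/result/   <- result_dir, registered for evaluation results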
def validate(self, cur_epoch, best_agg_metric, best_epoch, wandb_run):
if len(self.valid_splits) > 0:
for split_name in self.valid_splits:
logging.info("Evaluating on {}.".format(split_name))
val_log, results, gts, loss = self.eval_epoch(
split_name=split_name, cur_epoch=cur_epoch
)
if self.evaluate_only:
# save free-text predictions and corresponding gt
# create output folders (idempotent)
os.makedirs(self.output_dir / "predictions", exist_ok=True)
os.makedirs(self.output_dir / "ground_truths", exist_ok=True)
# save predictions
with open(self.output_dir / "predictions" / "predictions_{}.txt".format(split_name), "w") as f:
for i in range(len(results)):
f.write('"' + results[i]['caption'] + '"\n')
# save ground truths
with open(self.output_dir / "ground_truths" / "ground_truths_{}.txt".format(split_name), "w") as f:
for i in [elem['image_id'] for elem in results]:
f.write('"' + gts[i][0] + '"\n')
if val_log is not None:
if is_main_process():
assert (
"agg_metrics" in val_log
), "No agg_metrics found in validation log."
agg_metrics = val_log["agg_metrics"]
if not self.evaluate_only and split_name == "val":  # don't save for train_val
if agg_metrics > best_agg_metric:
best_epoch, best_agg_metric = cur_epoch, agg_metrics
self._save_checkpoint(cur_epoch, is_best=True)
logging.info("Saving best model at epoch {}".format(cur_epoch))
else:
self._save_checkpoint(cur_epoch, is_best=False, is_last=True)
val_log.update({"best_epoch": best_epoch})
self.log_stats(val_log, split_name)
# log a few example predictions vs. ground truth to wandb as an HTML table
data = []
for i in range(min(len(results), 5)):
data.append([str(cur_epoch), results[i]["caption"], gts[results[i]["image_id"]]])
html = self.generate_html_table(data, ["Epoch", "Predicted", "GT"])
try:
wandb_run.log({f"text_predictions_{split_name}": wandb.Html(html)})
except Exception:
pass  # wandb logging is best-effort; never fail validation over it
if loss is not None:
self.log_stats({"loss":loss}, split_name, commit=False)
if not self.evaluate_only:
if loss < best_agg_metric and split_name == "val":
best_epoch, best_agg_metric = cur_epoch, loss
self._save_checkpoint(cur_epoch, is_best=True)
else:
self._save_checkpoint(cur_epoch, is_best=False)
else:
# if no validation split is provided, we just save the checkpoint at the end of each epoch.
if not self.evaluate_only:
self._save_checkpoint(cur_epoch, is_best=False)
return best_agg_metric, best_epoch
def train(self, wandb_run):
start_time = time.time()
best_agg_metric = 0
best_epoch = 0
self.log_config()
# resume from checkpoint if specified
if not self.evaluate_only and self.resume_ckpt_path is not None:
self._load_checkpoint(self.resume_ckpt_path)
for cur_epoch in range(self.start_epoch, self.max_epoch):
custom_epochs = (
self.datasets['mimic_cxr']['train'].custom_epochs_per_epoch
if 'mimic_cxr' in self.datasets
else self.datasets['train'].custom_epochs_per_epoch
)
for custom_epoch in range(custom_epochs):
if 'mimic_cxr' in self.datasets:  # before first epoch
self.datasets['mimic_cxr']['train'].set_custom_epoch(custom_epoch)
else:
self.datasets['train'].set_custom_epoch(custom_epoch)  # re-sorted after creating the dataloader
# training phase
if not self.evaluate_only:
logging.info("Start training")
train_stats = self.train_epoch(cur_epoch)
self.log_stats(split_name="train", stats=train_stats)
# evaluation phase
best_agg_metric, best_epoch = self.validate(cur_epoch, best_agg_metric, best_epoch, wandb_run)
if self.evaluate_only:
break
#dist.barrier()
# testing phase
test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch
self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logging.info("Training time {}".format(total_time_str))
def evaluate(self, cur_epoch="best", skip_reload=False):
test_logs = dict()
if len(self.test_splits) > 0:
for split_name in self.test_splits:
test_logs[split_name] = self.eval_epoch(
split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload
)
return test_logs
def train_epoch(self, epoch):
# train
self.model.train()
return self.task.train_epoch(
epoch=epoch,
model=self.model,
data_loader=self.train_loader,
optimizer=self.optimizer,
scaler=self.scaler,
lr_scheduler=self.lr_scheduler,
cuda_enabled=self.cuda_enabled,
log_freq=self.log_freq,
accum_grad_iters=self.accum_grad_iters,
)
@torch.no_grad()
def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
"""
Evaluate the model on a given split.
Args:
split_name (str): name of the split to evaluate on.
cur_epoch (int or str): current epoch, or "best" when evaluating the best checkpoint.
skip_reload (bool): whether to skip reloading the best checkpoint.
During training, we reload the best checkpoint for validation.
During testing, we use the provided weights and skip reloading the best checkpoint.
Returns:
(metrics, results, gts, None) for text-generation tasks, or
(None, None, None, val_loss) when only the validation loss is computed (pre-training).
"""
data_loader = self.dataloaders.get(split_name, None)
assert data_loader, "data_loader for split {} is None.".format(split_name)
model = self.unwrap_dist_model(self.model)
if (not skip_reload and cur_epoch == "best") or self.evaluate_only:
model = self._reload_best_model(model)
model.eval()
self.task.before_evaluation(
model=model,
dataset=self.datasets[split_name],
)
results = self.task.evaluation(model, data_loader, cuda_enabled=self.cuda_enabled)
if results is not None and self.config.run_cfg.task != "image_text_pretrain_eval":
metrics, gts = data_loader.dataset.evaluator.evaluate(results)
return metrics, results, gts, None
else:
return None, None, None, results # for pre-training instead of predicting text we compute the val loss
def unwrap_dist_model(self, model):
if self.use_distributed:
return model.module
else:
return model
def create_loaders(
self,
datasets,
num_workers,
batch_sizes,
is_trains,
collate_fns,
dataset_ratios=None,
):
"""
Create dataloaders for training and validation.
"""
def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
# create a single dataloader for each split
if isinstance(dataset, ChainDataset):
# wds.WebDataset instances are chained together
# webdataset.DataPipeline has its own sampler and collate_fn
loader = iter(
DataLoader(
dataset,
batch_size=bsz,
num_workers=num_workers,
pin_memory=True,
)
)
else:
# map-style datasets are concatenated together
# setup distributed sampler
if self.use_distributed:
sampler = DistributedSampler(
dataset,
shuffle=is_train,
num_replicas=get_world_size(),
rank=get_rank(),
)
if not self.use_dist_eval_sampler:
# e.g. retrieval evaluation
sampler = sampler if is_train else None
else:
sampler = None
loader = DataLoader(
dataset,
batch_size=bsz,
num_workers=num_workers,
pin_memory=True,
sampler=sampler,
shuffle=sampler is None and is_train,
collate_fn=collate_fn,
drop_last=is_train,
)
#loader = PrefetchLoader(loader)
if is_train:
loader = IterLoader(loader, use_distributed=self.use_distributed)
return loader
loaders = []
for dataset, bsz, is_train, collate_fn in zip(
datasets, batch_sizes, is_trains, collate_fns
):
if isinstance(dataset, (list, tuple)):
loader = MultiIterLoader(
loaders=[
_create_loader(d, num_workers, bsz, is_train, collate_fn[i])
for i, d in enumerate(dataset)
],
ratios=dataset_ratios,
)
else:
loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)
loaders.append(loader)
return loaders
@main_process
def _save_checkpoint(self, cur_epoch, is_best=False, is_last=False):
"""
Save a checkpoint at the current epoch. Only parameters that require gradients
are kept in the state dict; the file is named checkpoint_best.pth,
checkpoint_last.pth, or checkpoint_{epoch}.pth.
"""
model_no_ddp = self.unwrap_dist_model(self.model)
param_grad_dic = {
k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
}
state_dict = model_no_ddp.state_dict()
for k in list(state_dict.keys()):
if k in param_grad_dic.keys() and not param_grad_dic[k]:
# delete parameters that do not require gradient
del state_dict[k]
save_obj = {
"model": state_dict,
"optimizer": self.optimizer.state_dict(),
"config": self.config.to_dict(),
"scaler": self.scaler.state_dict() if self.scaler else None,
"epoch": cur_epoch,
}
save_to = os.path.join(
self.output_dir,
"checkpoint_{}.pth".format("best" if is_best else ("last" if is_last else cur_epoch)),
)
logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
torch.save(save_obj, save_to)
def _reload_best_model(self, model):
"""
Load the best checkpoint for evaluation.
"""
checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")
logging.info("Loading checkpoint from {}.".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path, map_location="cpu")
try:
model.load_state_dict(checkpoint["model"])
except RuntimeError as e:
logging.warning(
"""
Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
Trying to load the model with strict=False.
"""
)
model.load_state_dict(checkpoint["model"], strict=False)
return model
def _load_checkpoint(self, url_or_filename):
"""
Resume from a checkpoint.
"""
if is_url(url_or_filename):
cached_file = download_cached_file(
url_or_filename, check_hash=False, progress=True
)
checkpoint = torch.load(cached_file, map_location=self.device)
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location=self.device)
else:
raise RuntimeError("checkpoint url or path is invalid")
state_dict = checkpoint["model"]
self.unwrap_dist_model(self.model).load_state_dict(state_dict, strict=False)  # opt_model does not need to be loaded (frozen)
self.optimizer.load_state_dict(checkpoint["optimizer"])
if self.scaler and "scaler" in checkpoint:
self.scaler.load_state_dict(checkpoint["scaler"])
self.start_epoch = checkpoint["epoch"] + 1
logging.info("Resume checkpoint from {}".format(url_or_filename))
@main_process
def log_stats(self, stats, split_name, commit=True):
if isinstance(stats, dict):
log_stats = {f"{split_name}_{k}": v for k, v in stats.items()}
with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
f.write(json.dumps(log_stats) + "\n")
if wandb is not None:
wandb_stats = {k: float(v) for k, v in log_stats.items()}
wandb.log(wandb_stats, commit=commit)
elif isinstance(stats, list):
pass
@main_process
def log_config(self):
with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
f.write(json.dumps(self.config.to_dict(), indent=4) + "\n")