--- a
+++ b/scripts/train.py
@@ -0,0 +1,417 @@
+from collections import OrderedDict
+from argparse import Namespace
+import pickle
+import os
+import sys
+
+import pytorch_lightning as pl
+import torch
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
+from sybil.utils.helpers import get_dataset
+import sybil.utils.losses as losses
+import sybil.utils.metrics as metrics
+import sybil.utils.loading as loaders
+import sybil.models.sybil as model
+from sybil.parsing import parse_args
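+
+# NOTE: this script targets the pytorch_lightning 1.x API; the `*_epoch_end`
+# hooks and `Trainer.from_argparse_args` used below were removed in
+# Lightning 2.0, so pin `pytorch_lightning<2.0` to run it as written.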
+
+
+class SybilLightning(pl.LightningModule):
+    """
+    Lightning Module
+    Methods:
+        .log/.log_dict: log inputs to logger
+    Notes:
+        *_epoch_end methods return None
+        self can log additional data structures to logger
+        with self.logger.experiment.log_*
+        (* = 'text', 'image', 'audio', 'confusion_matrix', 'histogram')
+    """
+
+    def __init__(self, args):
+        super(SybilLightning, self).__init__()
+        if isinstance(args, dict):
+            args = Namespace(**args)
+        self.args = args
+        self.model = model.SybilNet(args)
+        self.save_prefix = "default"
+        self.save_hyperparameters(args)
+        self._list_of_metrics = [
+            metrics.get_classification_metrics,
+            metrics.get_survival_metrics,
+            metrics.get_risk_metrics,
+        ]
+
+    def set_finetune(self, finetune_flag):
+        # No-op placeholder.
+        return
+
+    def forward(self, x):
+        return self.model(x)
+
+    def step(self, batch, batch_idx, optimizer_idx, log_key_prefix=""):
+        model_output = self(batch["x"])
+        logging_dict, predictions_dict = OrderedDict(), OrderedDict()
+
+        if "exam" in batch:
+            predictions_dict["exam"] = batch["exam"]
+        if "y" in batch:
+            predictions_dict["golds"] = batch["y"]
+
+        if self.args.save_attention_scores:
+            attentions = {k: v for k, v in model_output.items() if "attention" in k}
+            predictions_dict.update(attentions)
+
+        loss_fns = self.get_loss_functions(self.args)
+        loss = 0
+        for loss_fn in loss_fns:
+            local_loss, local_log_dict, local_predictions_dict = loss_fn(
+                model_output, batch, self, self.args
+            )
+            loss += local_loss
+            logging_dict.update(local_log_dict)
+            predictions_dict.update(local_predictions_dict)
+        logging_dict = prefix_dict(logging_dict, log_key_prefix)
+        predictions_dict = prefix_dict(predictions_dict, log_key_prefix)
+        return loss, logging_dict, predictions_dict, model_output
+
+    def training_step(self, batch, batch_idx, optimizer_idx=None):
+        result = OrderedDict()
+        loss, logging_dict, predictions_dict, _ = self.step(
+            batch, batch_idx, optimizer_idx, log_key_prefix="train_"
+        )
+        logging_dict["train_loss"] = loss.detach()
+        self.log_dict(logging_dict, prog_bar=False, on_step=True, on_epoch=True)
+        result["logs"] = logging_dict
+        self.log_tensor_dict(predictions_dict, prog_bar=False, logger=False)
+        result.update(predictions_dict)
+        # Lightning expects a 'loss' key in the output dict; otherwise loss is
+        # None by default.
+        result["loss"] = loss
+        return result
+
+    def validation_step(self, batch, batch_idx, optimizer_idx=None):
+        result = OrderedDict()
+        loss, logging_dict, predictions_dict, _ = self.step(
+            batch, batch_idx, optimizer_idx, log_key_prefix="val_"
+        )
+        logging_dict["val_loss"] = loss.detach()
+        self.log_dict(logging_dict, prog_bar=True, sync_dist=True)
+        result["logs"] = logging_dict
+        if self.args.accelerator == "ddp":
+            predictions_dict = gather_predictions_dict(predictions_dict)
+        self.log_tensor_dict(predictions_dict, prog_bar=False, logger=False)
+        result.update(predictions_dict)
+        return result
+
+    def test_step(self, batch, batch_idx, optimizer_idx=None):
+        result = OrderedDict()
+        loss, logging_dict, predictions_dict, model_output = self.step(
+            batch, batch_idx, optimizer_idx, log_key_prefix="test_"
+        )
+        logging_dict["{}_loss".format(self.save_prefix)] = loss.detach()
+        result["logs"] = logging_dict
+
+        if self.args.accelerator == "ddp":
+            predictions_dict = gather_predictions_dict(predictions_dict)
+
+        self.log_tensor_dict(predictions_dict, prog_bar=False, logger=False)
+        result.update(predictions_dict)
+        return result
+
+    def training_epoch_end(self, outputs):
+        if len(outputs) == 0:
+            return
+        outputs = gather_step_outputs(outputs)
+        # The loss is already logged via the progress bar
+        # (get_progress_bar_dict()); logging it twice creates issues.
+        del outputs["loss"]
+        epoch_metrics = compute_epoch_metrics(
+            self._list_of_metrics, outputs, self.args, self.device, key_prefix="train_"
+        )
+        for k, v in outputs["logs"].items():
+            epoch_metrics[k] = v.mean()
+        self.log_dict(epoch_metrics, prog_bar=True, logger=True)
+
+    def validation_epoch_end(self, outputs):
+        if len(outputs) == 0:
+            return
+        outputs = gather_step_outputs(outputs)
+        epoch_metrics = compute_epoch_metrics(
+            self._list_of_metrics, outputs, self.args, self.device, key_prefix="val_"
+        )
+        for k, v in outputs["logs"].items():
+            epoch_metrics[k] = v.mean()
+        self.log_dict(epoch_metrics, prog_bar=True, logger=True)
+
+    def test_epoch_end(self, outputs):
+        self.save_prefix = "test"
+        if len(outputs) == 0:
+            return
+        outputs = gather_step_outputs(outputs)
+        epoch_metrics = compute_epoch_metrics(
+            self._list_of_metrics, outputs, self.args, self.device, key_prefix="test_"
+        )
+
+        for k, v in outputs["logs"].items():
+            epoch_metrics[k] = v.mean()
+
+        self.log_dict(epoch_metrics, prog_bar=True, logger=True)
+
+        # Dump metrics for use by the dispatcher
+        metrics_dict = {
+            k[len(self.save_prefix) :]: v.mean().item()
+            for k, v in outputs.items()
+            if "loss" in k
+        }
+        metrics_dict.update(
+            {
+                k[len(self.save_prefix) :]: v.mean().item()
+                for k, v in epoch_metrics.items()
+            }
+        )
+        metrics_filename = "{}.{}.metrics".format(
+            self.args.results_path, self.save_prefix
+        )
+        with open(metrics_filename, "wb") as f:
+            pickle.dump(metrics_dict, f)
+        if self.args.save_predictions and self.global_rank == 0:
+            predictions_dict = {
+                k: v.cpu() if isinstance(v, torch.Tensor) else v
+                for k, v in outputs.items()
+            }
+            predictions_filename = "{}.{}.predictions".format(
+                self.args.results_path, self.save_prefix
+            )
+            with open(predictions_filename, "wb") as f:
+                pickle.dump(predictions_dict, f)
+
+    def configure_optimizers(self):
+        """
+        Set up the optimizer and learning-rate scheduler based on args.
+        """
+        params = [param for param in self.model.parameters() if param.requires_grad]
+        if self.args.optimizer == "adam":
+            optimizer = torch.optim.Adam(
+                params, lr=self.args.lr, weight_decay=self.args.weight_decay
+            )
+        elif self.args.optimizer == "adagrad":
+            optimizer = torch.optim.Adagrad(
+                params, lr=self.args.lr, weight_decay=self.args.weight_decay
+            )
+        elif self.args.optimizer == "sgd":
+            optimizer = torch.optim.SGD(
+                params,
+                lr=self.args.lr,
+                weight_decay=self.args.weight_decay,
+                momentum=self.args.momentum,
+            )
+        else:
+            raise Exception("Optimizer {} not supported!".format(self.args.optimizer))
+
+        scheduler = {
+            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
+                optimizer,
+                patience=self.args.patience,
+                factor=self.args.lr_decay,
+                mode="min" if "loss" in self.args.tuning_metric else "max",
+            ),
+            "monitor": "val_{}".format(self.args.tuning_metric),
+            "interval": "epoch",
+            "frequency": 1,
+        }
+        return [optimizer], [scheduler]
+
+    def log_tensor_dict(
+        self,
+        output,
+        prog_bar=False,
+        logger=True,
+        on_step=None,
+        on_epoch=None,
+        sync_dist=False,
+    ):
+        dict_of_tensors = {
+            k: v.float() for k, v in output.items() if isinstance(v, torch.Tensor)
+        }
+        self.log_dict(
+            dict_of_tensors,
+            prog_bar=prog_bar,
+            logger=logger,
+            on_step=on_step,
+            on_epoch=on_epoch,
+            sync_dist=sync_dist,
+        )
+
+    def get_loss_functions(self, args):
+        loss_fns = [losses.get_survival_loss]
+        if args.use_annotations:
+            loss_fns.append(losses.get_annotation_loss)
+
+        return loss_fns
+
+
+def prefix_dict(d, prefix):
+    r = OrderedDict()
+    for k, v in d.items():
+        r[prefix + k] = v
+    return r
+
+
+def gather_predictions_dict(predictions):
+    gathered_preds = {
+        k: concat_all_gather(v) if isinstance(v, torch.Tensor) else v
+        for k, v in predictions.items()
+    }
+    return gathered_preds
+
+
+def gather_step_outputs(outputs):
+    output_dict = OrderedDict()
+    if isinstance(outputs[-1], list):
+        outputs = outputs[0]
+
+    for k in outputs[-1].keys():
+        if k == "logs":
+            output_dict[k] = gather_step_outputs([output["logs"] for output in outputs])
+        elif (
+            isinstance(outputs[-1][k], torch.Tensor) and len(outputs[-1][k].shape) == 0
+        ):
+            output_dict[k] = torch.stack([output[k] for output in outputs])
+        elif isinstance(outputs[-1][k], torch.Tensor):
+            output_dict[k] = torch.cat([output[k] for output in outputs], dim=0)
+        else:
+            output_dict[k] = [output[k] for output in outputs]
+    return output_dict
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs an all_gather operation on the provided tensor.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
+    """
+    tensors_gather = [
+        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
+    ]
+    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+    output = torch.cat(tensors_gather, dim=0)
+    return output
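+
+# Illustrative sketch (not executed): under DDP with a world size of 2, a
+# per-rank prediction tensor of shape [B, ...] comes back from
+# gather_predictions_dict with shape [2 * B, ...] on every rank, e.g.
+#
+#   preds = {"probs": torch.rand(8, 2)}        # on each of the 2 ranks
+#   gathered = gather_predictions_dict(preds)  # gathered["probs"].shape == (16, 2)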
+
+
+def compute_epoch_metrics(list_of_metrics, result_dict, args, device, key_prefix=""):
+    stats_dict = OrderedDict()
+
+    # Remove the key prefix from keys, e.g. convert "val_probs" -> "probs",
+    # for standard handling in the metric functions.
+    result_dict_wo_key_prefix = {}
+
+    for k, v in result_dict.items():
+        if isinstance(v, list) and isinstance(v[-1], torch.Tensor):
+            v = torch.cat(v, dim=-1)
+        if isinstance(v, torch.Tensor):
+            v = v.cpu().numpy()
+        if k == "meta":
+            continue
+        if key_prefix != "" and k.startswith(key_prefix):
+            k_wo_prefix = k[len(key_prefix) :]
+            result_dict_wo_key_prefix[k_wo_prefix] = v
+        else:
+            result_dict_wo_key_prefix[k] = v
+
+    for k, v in result_dict["logs"].items():
+        if k.startswith(key_prefix):
+            result_dict_wo_key_prefix[k[len(key_prefix) :]] = v
+
+    for metric_func in list_of_metrics:
+        stats_wo_prefix = metric_func(result_dict_wo_key_prefix, args)
+        for k, v in stats_wo_prefix.items():
+            stats_dict[key_prefix + k] = torch.tensor(v, device=device)
+
+    return stats_dict
+
+
+def train(args):
+    if not args.turn_off_checkpointing:
+        checkpoint_callback = pl.callbacks.ModelCheckpoint(
+            dirpath=args.save_dir,
+            save_top_k=1,
+            verbose=True,
+            monitor="val_{}".format(args.tuning_metric)
+            if args.tuning_metric is not None
+            else None,
+            save_last=True,
+            mode="min"
+            if args.tuning_metric is not None and "loss" in args.tuning_metric
+            else "max",
+        )
+        args.callbacks = [checkpoint_callback]
+    trainer = pl.Trainer.from_argparse_args(args)
+    # Remove callbacks from args for safe pickling later
+    args.callbacks = None
+    args.num_nodes = trainer.num_nodes
+    args.num_processes = trainer.num_processes
+    args.world_size = args.num_nodes * args.num_processes
+    args.global_rank = trainer.global_rank
+    args.local_rank = trainer.local_rank
+
+    train_dataset = loaders.get_train_dataset_loader(
+        args, get_dataset(args.dataset, "train", args)
+    )
+    dev_dataset = loaders.get_eval_dataset_loader(
+        args, get_dataset(args.dataset, "dev", args), False
+    )
+
+    args.censoring_distribution = metrics.get_censoring_dist(train_dataset.dataset)
+    module = SybilLightning(args)
+
+    # Print args
+    for key, value in sorted(vars(args).items()):
+        print("{} -- {}".format(key.upper(), value))
+
+    if args.snapshot is not None:
+        module = module.load_from_checkpoint(
+            checkpoint_path=args.snapshot, strict=False
+        )
+        module.args = args
+
+    trainer.fit(module, train_dataset, dev_dataset)
+    args.model_path = trainer.checkpoint_callback.best_model_path
+    print("Saving args to {}".format(args.results_path))
+    with open(args.results_path, "wb") as f:
+        pickle.dump(vars(args), f)
+
+
+def test(args):
+    trainer = pl.Trainer.from_argparse_args(args)
+    # Remove callbacks from args for safe pickling later
+    args.callbacks = None
+    args.num_nodes = trainer.num_nodes
+    args.num_processes = trainer.num_processes
+    args.world_size = args.num_nodes * args.num_processes
+    args.global_rank = trainer.global_rank
+    args.local_rank = trainer.local_rank
+
+    # The train loader is only needed to estimate the censoring distribution.
+    train_dataset = loaders.get_train_dataset_loader(
+        args, get_dataset(args.dataset, "train", args)
+    )
+    test_dataset = loaders.get_eval_dataset_loader(
+        args, get_dataset(args.dataset, "test", args), False
+    )
+
+    args.censoring_distribution = metrics.get_censoring_dist(train_dataset.dataset)
+    module = SybilLightning(args)
+    module = module.load_from_checkpoint(checkpoint_path=args.snapshot, strict=False)
+    module.args = args
+
+    # Print args
+    for key, value in sorted(vars(args).items()):
+        print("{} -- {}".format(key.upper(), value))
+
+    trainer.test(module, test_dataset)
+
+    print("Saving args to {}".format(args.results_path))
+    with open(args.results_path, "wb") as f:
+        pickle.dump(vars(args), f)
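+
+# Illustrative sketch: test_epoch_end above dumps its metrics next to
+# args.results_path, so a finished test run can be inspected with, e.g.,
+#
+#   with open("<results_path>.test.metrics", "rb") as f:
+#       test_metrics = pickle.load(f)
+#
+# where "<results_path>" stands for whatever args.results_path was set to.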
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    if args.train:
+        train(args)
+    elif args.test:
+        test(args)
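+
+# Example invocation (a sketch only; the real flag names and defaults are
+# defined in sybil.parsing.parse_args, and the values shown are illustrative):
+#
+#   python scripts/train.py --train --dataset nlst --optimizer adam --lr 1e-4 \
+#       --weight_decay 1e-2 --tuning_metric c_index \
+#       --save_dir ./checkpoints --results_path ./results/run.args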