--- /dev/null
+++ b/custom_byol_bolts.py
@@ -0,0 +1,481 @@
+import math
+from argparse import ArgumentParser
+from copy import deepcopy
+from typing import Any
+
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.optim import Adam
+
+from pl_bolts.models.self_supervised import BYOL
+# from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate
+from pl_bolts.optimizers.lars_scheduling import LARSWrapper
+from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
+
+from models.resnet_simclr import ResNetSimCLR
+import re
+
+import time
+
+import yaml
+import logging
+import os
+import pickle
+from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper
+from clinical_ts.create_logger import create_logger
+from pytorch_lightning import Trainer, seed_everything
+from pytorch_lightning.loggers import TensorBoardLogger
+
+from online_evaluator import SSLOnlineEvaluator
+from ecg_datamodule import ECGDataModule
+
+logger = create_logger(__name__)
+method = "byol"
+
+
+def mean(res, key1, key2=None):
+    # average a (possibly nested) metric over the step outputs collected in res
+    if key2 is not None:
+        return torch.stack([x[key1][key2] for x in res]).mean()
+    return torch.stack([x[key1] for x in res if type(x) == dict and key1 in x.keys()]).mean()
+
+
+class MLP(nn.Module):
+    def __init__(self, input_dim=512, hidden_size=4096, output_dim=256):
+        super().__init__()
+        self.output_dim = output_dim
+        self.input_dim = input_dim
+        self.model = nn.Sequential(
+            nn.Linear(input_dim, hidden_size, bias=False),
+            nn.BatchNorm1d(hidden_size),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden_size, output_dim, bias=True))
+
+    def forward(self, x):
+        x = self.model(x)
+        return x
+
+
+class SiameseArm(nn.Module):
+    def __init__(self, encoder=None, out_dim=128, hidden_size=512):
+        super().__init__()
+
+        if encoder is None:
+            # fall back to the torchvision resnet50 encoder shipped with pl_bolts
+            from pl_bolts.utils.self_supervised import torchvision_ssl_encoder
+            encoder = torchvision_ssl_encoder('resnet50')
+        # Encoder
+        self.encoder = encoder
+        # Pooler
+        self.pooler = nn.AdaptiveAvgPool2d((1, 1))
+        # Projector (input dimension is inferred from the encoder head)
+        projector_dim = encoder.l1.in_features
+        self.projector = MLP(
+            input_dim=projector_dim, hidden_size=hidden_size, output_dim=out_dim)
+        # Predictor
+        self.predictor = MLP(
+            input_dim=out_dim, hidden_size=hidden_size, output_dim=out_dim)
+
+    def forward(self, x):
+        y = self.encoder(x)[0]
+        y = y.view(y.size(0), -1)
+        z = self.projector(y)
+        h = self.predictor(z)
+        return y, z, h
+
+
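+# Illustrative sketch (not executed anywhere in this module): SiameseArm expects an
+# encoder whose forward returns a tuple/list with the feature tensor first and which
+# exposes `l1.in_features` (as ResNetSimCLR does). Its outputs are
+#   y - flattened encoder representation (used for downstream evaluation),
+#   z - projection of y,
+#   h - prediction of z, regressed against the target network's projection in BYOL.
+# Hypothetical usage:
+#   arm = SiameseArm(encoder=ResNetSimCLR(**config["model"]), out_dim=config["model"]["out_dim"])
+#   y, z, h = arm(batch)   # batch: (batch_size, channels, sequence_length) ECG tensor
+
+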
+class BYOLMAWeightUpdate(pl.Callback):
+    def __init__(self, initial_tau=0.996):
+        """
+        Exponential-moving-average weight update for the BYOL target network.
+
+        Args:
+            initial_tau: starting tau; it is auto-updated towards 1.0 with every
+                training step following a cosine schedule.
+        """
+        super().__init__()
+        self.initial_tau = initial_tau
+        self.current_tau = initial_tau
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        # get networks
+        online_net = pl_module.online_network
+        target_net = pl_module.target_network
+
+        # update weights
+        self.update_weights(online_net, target_net)
+
+        # update tau after
+        self.current_tau = self.update_tau(pl_module, trainer)
+
+    def update_tau(self, pl_module, trainer):
+        max_steps = len(trainer.train_dataloader) * trainer.max_epochs
+        tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi *
+                                                     pl_module.global_step / max_steps) + 1) / 2
+        return tau
+
+    def update_weights(self, online_net, target_net):
+        # apply MA weight update; note that only parameters with 'weight' in their
+        # name are averaged, so biases keep their initial copy in the target network
+        for (name, online_p), (_, target_p) in zip(online_net.named_parameters(), target_net.named_parameters()):
+            if 'weight' in name:
+                target_p.data = self.current_tau * target_p.data + \
+                    (1 - self.current_tau) * online_p.data
+
+
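+# Illustrative sketch (assumption: this helper is not called anywhere, added only to
+# document the schedule): update_tau() moves tau from initial_tau at step 0 towards
+# 1.0 at the final step, so the target network tracks the online network ever more slowly.
+def _example_tau_schedule(initial_tau=0.996, max_steps=1000):
+    # tau at the first and last training step of the cosine schedule above
+    return [1 - (1 - initial_tau) * (math.cos(math.pi * k / max_steps) + 1) / 2
+            for k in (0, max_steps)]  # -> [0.996, 1.0] for the defaults
+
+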
+class CustomBYOL(pl.LightningModule):
+    def __init__(self,
+                 num_classes=5,
+                 learning_rate: float = 0.2,
+                 weight_decay: float = 1.5e-6,
+                 input_height: int = 32,
+                 batch_size: int = 32,
+                 num_workers: int = 0,
+                 warmup_epochs: int = 10,
+                 max_epochs: int = 1000,
+                 config=None,
+                 transformations=None,
+                 **kwargs):
+        """
+        Args:
+            num_classes: number of classes of the downstream task
+            learning_rate: the learning rate
+            weight_decay: optimizer weight decay
+            input_height: image input height
+            batch_size: the batch size
+            num_workers: number of workers
+            warmup_epochs: num of epochs for scheduler warm up
+            max_epochs: max epochs for scheduler
+            config: configuration dictionary (model and dataset settings)
+            transformations: data augmentation pipeline used for pretraining
+        """
+        super().__init__()
+        self.save_hyperparameters()
+
+        self.config = config
+        self.transformations = transformations
+        self.online_network = SiameseArm(
+            encoder=self.init_model(), out_dim=config["model"]["out_dim"])
+        self.target_network = deepcopy(self.online_network)
+        self.weight_callback = BYOLMAWeightUpdate()
+        self.log_dict = {}
+        self.epoch = 0
+
+    def init_model(self):
+        model = ResNetSimCLR(**self.config["model"])
+        return model
+
+    def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+        # Add callback for user automatically since it's key to BYOL weight update
+        self.weight_callback.on_train_batch_end(
+            self.trainer, self, outputs, batch, batch_idx, 0)
+
+    def forward(self, x):
+        y, _, _ = self.online_network(x)
+        return y
+
+    def cosine_similarity(self, a, b):
+        a = F.normalize(a, dim=-1)
+        b = F.normalize(b, dim=-1)
+        sim = (a * b).sum(-1).mean()
+        return sim
+
+    def shared_step(self, batch, batch_idx):
+        # each batch consists of two augmented views; the labels are not needed here
+        (img_1, _), (img_2, _) = batch
+
+        img_1 = self.to_device(img_1)
+        img_2 = self.to_device(img_2)
+
+        # Image 1 to image 2 loss
+        y1, z1, h1 = self.online_network(img_1)
+        with torch.no_grad():
+            y2, z2, h2 = self.target_network(img_2)
+        loss_a = -2 * self.cosine_similarity(h1, z2)
+
+        # Image 2 to image 1 loss
+        y1, z1, h1 = self.online_network(img_2)
+        with torch.no_grad():
+            y2, z2, h2 = self.target_network(img_1)
+        loss_b = -2 * self.cosine_similarity(h1, z2)
+
+        # Final loss
+        total_loss = loss_a + loss_b
+
+        return loss_a, loss_b, total_loss
+
+    def training_step(self, batch, batch_idx):
+        loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)
+
+        # # log results
+        # self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b,
+        #                'train_loss': total_loss})
+
+        return total_loss
+
+    def validation_step(self, batch, batch_idx, dataloader_idx):
+        if dataloader_idx != 0:
+            return {}
+
+        loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)
+
+        results = {
+            'val_loss': total_loss,
+            'val_1_2_loss': loss_a,
+            'val_2_1_loss': loss_b
+        }
+        return results
+
+    def validation_epoch_end(self, outputs):
+        # outputs[0] because we are using multiple datasets!
+        val_loss = mean(outputs[0], 'val_loss')
+        loss_a = mean(outputs[0], 'val_1_2_loss')
+        loss_b = mean(outputs[0], 'val_2_1_loss')
+
+        log = {
+            'val_loss': val_loss,
+            'val_1_2_loss': loss_a,
+            'val_2_1_loss': loss_b
+        }
+
+        return {'val_loss': val_loss, 'log': log, 'progress_bar': log}
+
+    def configure_optimizers(self):
+        optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate,
+                         weight_decay=self.hparams.weight_decay)
+        # optimizer = LARSWrapper(optimizer)
+        scheduler = LinearWarmupCosineAnnealingLR(
+            optimizer,
+            warmup_epochs=self.hparams.warmup_epochs,
+            max_epochs=self.hparams.max_epochs
+        )
+        return [optimizer], [scheduler]
+
+    def on_train_start(self):
+        # log configuration
+        config_str = re.sub(r"[,\}\{]", "<br/>", str(self.config))
+        config_str = re.sub(r"[\[\]\']", "", config_str)
+        transformation_str = re.sub(r"[\}]", "<br/>", str(["<br>" + str(
+            t) + ":<br/>" + str(t.get_params()) for t in self.transformations]))
+        transformation_str = re.sub(r"[,\"\{\'\[\]]", "", transformation_str)
+        self.logger.experiment.add_text(
+            "configuration", str(config_str), global_step=0)
+        self.logger.experiment.add_text("transformations", str(
+            transformation_str), global_step=0)
+        self.epoch = 0
+
+    def on_epoch_end(self):
+        self.epoch += 1
+
+    def get_representations(self, x):
+        return self.online_network(x)[0]
+
+    def get_model(self):
+        return self.online_network.encoder
+
+    def get_device(self):
+        return self.online_network.encoder.features[0][0].weight.device
+
+    def to_device(self, x):
+        return x.type(self.type()).to(self.get_device())
+
+    def type(self):
+        # note: intentionally shadows nn.Module.type; only used by to_device above
+        return self.online_network.encoder.features[0][0].weight.type()
+
+
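+# Illustrative sketch (assumption: this helper is not used anywhere, added only to
+# document the loss): the -2 * cosine_similarity terms in shared_step equal the BYOL
+# mean squared error between L2-normalised prediction and target projection, up to a
+# constant offset of 2 per view.
+def _example_byol_loss_equivalence():
+    h, z = torch.randn(8, 128), torch.randn(8, 128)  # hypothetical prediction / projection
+    h_n, z_n = F.normalize(h, dim=-1), F.normalize(z, dim=-1)
+    mse = ((h_n - z_n) ** 2).sum(-1).mean()           # BYOL regression loss
+    neg_cos = -2 * (h_n * z_n).sum(-1).mean()         # term computed in shared_step
+    return torch.allclose(mse, neg_cos + 2, atol=1e-5)  # True
+
+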
+def parse_args(parent_parser):
+    parser = ArgumentParser(parents=[parent_parser], add_help=False)
+    parser.add_argument('-t', '--trafos', nargs='+', help='add transformation to data augmentation pipeline',
+                        default=["GaussianNoise", "ChannelResize", "RandomResizedCrop"])
+    # GaussianNoise
+    parser.add_argument(
+        '--gaussian_scale', help='std param for gaussian noise transformation', default=0.005, type=float)
+    # RandomResizedCrop
+    parser.add_argument('--rr_crop_ratio_range', nargs='+',
+                        help='ratio range for random resized crop transformation', default=[0.5, 1.0], type=float)
+    parser.add_argument(
+        '--output_size', help='output size for random resized crop transformation', default=250, type=int)
+    # DynamicTimeWarp
+    parser.add_argument(
+        '--warps', help='number of warps for dynamic time warp transformation', default=3, type=int)
+    parser.add_argument(
+        '--radius', help='radius of warps of dynamic time warp transformation', default=10, type=int)
+    # TimeWarp
+    parser.add_argument(
+        '--epsilon', help='epsilon param for time warp', default=10, type=float)
+    # ChannelResize
+    parser.add_argument('--magnitude_range', nargs='+',
+                        help='range for scale param for ChannelResize transformation', default=[0.5, 2], type=float)
+    # Downsample
+    parser.add_argument(
+        '--downsample_ratio', help='downsample ratio for Downsample transformation', default=0.2, type=float)
+    # TimeOut
+    parser.add_argument('--to_crop_ratio_range', nargs='+',
+                        help='ratio range for timeout transformation', default=[0.2, 0.4], type=float)
+    # resume training
+    parser.add_argument('--resume', action='store_true')
+    parser.add_argument(
+        '--gpus', help='number of gpus to use; use cpu if gpus=0', type=int, default=1)
+    parser.add_argument(
+        '--num_nodes', default=1, help='number of cluster nodes', type=int)
+    parser.add_argument(
+        '--distributed_backend', help='sets backend type')
+    parser.add_argument('--batch_size', type=int)
+    parser.add_argument('--epochs', type=int)
+    parser.add_argument('--debug', action='store_true')
+    parser.add_argument('--warm_up', default=1, type=int, help='number of warm-up epochs for the lr scheduler')
+    parser.add_argument('--precision', type=int)
+    parser.add_argument('--datasets', dest="target_folders",
+                        nargs='+', help='used datasets for pretraining')
+    parser.add_argument('--log_dir', default="./experiment_logs")
+    parser.add_argument(
+        '--percentage', help='determines how much of the dataset shall be used during the pretraining', type=float, default=1.0)
+    parser.add_argument('--lr', type=float, help="learning rate")
+    parser.add_argument('--out_dim', type=int, help="output dimension of model")
+    parser.add_argument('--filter_cinc', default=False, action="store_true",
+                        help="only valid if cinc is selected: filter out the ptb data")
+    parser.add_argument('--base_model')
+    parser.add_argument('--widen', type=int, help="use wide xresnet1d50")
+    parser.add_argument('--run_callbacks', default=False, action="store_true",
+                        help="run callbacks which assess linear evaluation and fine-tuning metrics during pretraining")
+    parser.add_argument('--checkpoint_path', default="")
+    return parser
+
+
+def init_logger(config):
+    level = logging.INFO
+
+    if config['debug']:
+        level = logging.DEBUG
+
+    # remove all handlers to change basic configuration
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+    if not os.path.isdir(config['log_dir']):
+        os.makedirs(config['log_dir'], exist_ok=True)
+    logging.basicConfig(filename=os.path.join(config['log_dir'], 'info.log'), level=level,
+                        format='%(asctime)s %(name)s:%(lineno)s %(levelname)s: %(message)s ')
+    return logging.getLogger(__name__)
+
+
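+# Example invocation (hypothetical values and paths, for illustration only):
+#   python custom_byol_bolts.py --epochs 1000 --batch_size 32 --gpus 1 \
+#       --trafos GaussianNoise ChannelResize RandomResizedCrop \
+#       --datasets ./data/ptb_xl --log_dir ./experiment_logs --run_callbacks
+
+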
+def pretrain_routine(args):
+    t_params = {"gaussian_scale": args.gaussian_scale, "rr_crop_ratio_range": args.rr_crop_ratio_range, "output_size": args.output_size, "warps": args.warps, "radius": args.radius,
+                "epsilon": args.epsilon, "magnitude_range": args.magnitude_range, "downsample_ratio": args.downsample_ratio, "to_crop_ratio_range": args.to_crop_ratio_range,
+                "bw_cmax": 0.1, "em_cmax": 0.5, "pl_cmax": 0.2, "bs_cmax": 1}
+    transformations = args.trafos
+    checkpoint_config = os.path.join("checkpoints", "bolts_config.yaml")
+    config_file = checkpoint_config if args.resume and os.path.isfile(
+        checkpoint_config) else "bolts_config.yaml"
+    with open(config_file, "r") as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+    args_dict = vars(args)
+    # command line arguments override the yaml config unless they are None
+    for key in set(config.keys()).union(args_dict.keys()):
+        if key in args_dict and args_dict[key] is not None:
+            config[key] = args_dict[key]
+        elif key not in config:
+            config[key] = args_dict.get(key)
+    if args.target_folders is not None:
+        config["dataset"]["target_folders"] = args.target_folders
+    config["dataset"]["percentage"] = args.percentage if args.percentage is not None else config["dataset"]["percentage"]
+    config["dataset"]["filter_cinc"] = args.filter_cinc if args.filter_cinc is not None else config["dataset"]["filter_cinc"]
+    config["model"]["base_model"] = args.base_model if args.base_model is not None else config["model"]["base_model"]
+    config["model"]["widen"] = args.widen if args.widen is not None else config["model"]["widen"]
+    if args.out_dim is not None:
+        config["model"]["out_dim"] = args.out_dim
+    init_logger(config)
+    dataset = SimCLRDataSetWrapper(
+        config['batch_size'], **config['dataset'], transformations=transformations, t_params=t_params)
+    for i, t in enumerate(dataset.transformations):
+        logger.info(str(i) + ". Transformation: " +
+                    str(t) + ": " + str(t.get_params()))
+    date = time.asctime()
+    label_to_num_classes = {"label_all": 71, "label_diag": 44, "label_form": 19,
+                            "label_rhythm": 12, "label_diag_subclass": 23, "label_diag_superclass": 5}
+    ptb_num_classes = label_to_num_classes[config["eval_dataset"]["ptb_xl_label"]]
+    abr = {"Transpose": "Tr", "TimeOut": "TO", "DynamicTimeWarp": "DTW", "RandomResizedCrop": "RRC", "ChannelResize": "ChR", "GaussianNoise": "GN",
+           "TimeWarp": "TW", "ToTensor": "TT", "GaussianBlur": "GB", "BaselineWander": "BlW", "PowerlineNoise": "PlN", "EMNoise": "EM", "BaselineShift": "BlS"}
+    trs = re.sub(r"[,'\]\[]", "", str([abr[str(tr)] if abr[str(tr)] not in [
+        "TT", "Tr"] else '' for tr in dataset.transformations]))
+    name = str(date) + "_" + method + "_" + str(
+        time.time_ns())[-3:] + "_" + trs[1:]
+    tb_logger = TensorBoardLogger(args.log_dir, name=name, version='')
+    config["log_dir"] = os.path.join(args.log_dir, name)
+    print(config)
+    return config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger
+
+
+def aftertrain_routine(config, args, trainer, pl_model, datamodule, callbacks):
+    scores = {}
+    for ca in callbacks:
+        if isinstance(ca, SSLOnlineEvaluator):
+            scores[str(ca)] = {"macro": ca.best_macro}
+
+    results = {"config": config, "trafos": args.trafos, "scores": scores}
+
+    with open(os.path.join(config["log_dir"], "results.pkl"), 'wb') as handle:
+        pickle.dump(results, handle)
+
+    trainer.save_checkpoint(os.path.join(config["log_dir"], "checkpoints", "model.ckpt"))
+    with open(os.path.join(config["log_dir"], "config.txt"), "w") as text_file:
+        print(config, file=text_file)
+
+
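+# Example (illustrative only): the artefacts written by aftertrain_routine can be
+# inspected after a run with something like
+#   with open(os.path.join(log_dir, "results.pkl"), "rb") as handle:
+#       results = pickle.load(handle)
+#   results["scores"]  # best macro scores collected by the SSLOnlineEvaluator callbacks
+
+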
+def cli_main():
+    from os.path import exists
+
+    parser = ArgumentParser()
+    parser = parse_args(parser)
+    logger.info("parse arguments")
+    args = parser.parse_args()
+    config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger = pretrain_routine(args)
+
+    # data
+    ecg_datamodule = ECGDataModule(config, transformations, t_params)
+
+    callbacks = []
+    if args.run_callbacks:
+        # callbacks for online linear evaluation / fine-tuning
+        linear_evaluator = SSLOnlineEvaluator(drop_p=0, z_dim=512, num_classes=ptb_num_classes, hidden_dim=None,
+                                              lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"],
+                                              mode="linear_evaluation", verbose=False)
+
+        fine_tuner = SSLOnlineEvaluator(drop_p=0, z_dim=512, num_classes=ptb_num_classes, hidden_dim=None,
+                                        lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"],
+                                        mode="fine_tuning", verbose=False)
+
+        callbacks.append(linear_evaluator)
+        callbacks.append(fine_tuner)
+
+    # configure trainer
+    trainer = Trainer(logger=tb_logger, max_epochs=config["epochs"], gpus=args.gpus,
+                      distributed_backend=args.distributed_backend, auto_lr_find=False,
+                      num_nodes=args.num_nodes, precision=config["precision"], callbacks=callbacks)
+
+    # pytorch lightning module
+    pl_model = CustomBYOL(num_classes=5, learning_rate=config["lr"], weight_decay=float(config["weight_decay"]),
+                          warmup_epochs=config["warm_up"], max_epochs=config["epochs"],
+                          num_workers=config["dataset"]["num_workers"], batch_size=config["batch_size"],
+                          config=config, transformations=ecg_datamodule.transformations)
+
+    # load checkpoint
+    if args.checkpoint_path != "":
+        if exists(args.checkpoint_path):
+            logger.info("Retrieve checkpoint from " + args.checkpoint_path)
+            # load_from_checkpoint is a classmethod and returns a new module
+            pl_model = CustomBYOL.load_from_checkpoint(args.checkpoint_path)
+        else:
+            raise FileNotFoundError("checkpoint does not exist")
+
+    # start training
+    trainer.fit(pl_model, ecg_datamodule)
+
+    aftertrain_routine(config, args, trainer, pl_model, ecg_datamodule, callbacks)
+
+
+if __name__ == "__main__":
+    cli_main()
\ No newline at end of file