--- a
+++ b/custom_moco_bolts.py
@@ -0,0 +1,603 @@
+import pytorch_lightning as pl
+from pl_bolts.models.self_supervised import MocoV2
+from pl_bolts.optimizers.lars_scheduling import LARSWrapper
+from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
+from torch.optim import Adam
+import torch
+import re
+import pdb
+from argparse import ArgumentParser
+from typing import Union
+from warnings import warn
+import torch.nn.functional as F
+from torch import nn
+from pl_bolts.metrics import precision_at_k  # , mean
+from clinical_ts.create_logger import create_logger
+from models.resnet_simclr import ResNetSimCLR
+
+import time
+
+import yaml
+import logging
+import pickle
+import os
+from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper
+from pytorch_lightning import Trainer, seed_everything
+
+from pytorch_lightning.loggers import TensorBoardLogger
+from pl_bolts.models.self_supervised.evaluator import Flatten
+
+logger = create_logger(__name__)
+method = "moco"
+
+
+def _accuracy(zis, zjs, batch_size):
+    with torch.no_grad():
+        representations = torch.cat([zjs, zis], dim=0)
+        similarity_matrix = torch.mm(
+            representations, representations.t().contiguous())
+        corrected_similarity_matrix = similarity_matrix - \
+            torch.eye(2*batch_size).type_as(similarity_matrix)
+        pred_similarities, pred_indices = torch.max(
+            corrected_similarity_matrix[:batch_size], dim=1)
+        correct_indices = torch.arange(batch_size)+batch_size
+        correct_preds = (
+            pred_indices == correct_indices.type_as(pred_indices)).sum()
+        return correct_preds.float()/batch_size
+
+
+def mean(res, key1, key2=None):
+    if key2 is not None:
+        return torch.stack([x[key1][key2] for x in res]).mean()
+    return torch.stack([x[key1] for x in res if type(x) == dict and key1 in x.keys()]).mean()
+
+# utils
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs all_gather operation on the provided tensors.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
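+    Gradients therefore do not flow back through the gathered copies;
+    only the tensor computed on the local rank remains differentiable.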
+ """ + tensors_gather = [torch.ones_like(tensor) + for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +class CustomMoCo(pl.LightningModule): + + def __init__(self, + base_encoder, + emb_dim: int = 128, + num_negatives: int = 65536, + encoder_momentum: float = 0.999, + softmax_temperature: float = 0.07, + learning_rate: float = 0.03, + momentum: float = 0.9, + weight_decay: float = 1e-6, + datamodule: pl.LightningDataModule = None, + data_dir: str = './', + batch_size: int = 256, + use_mlp: bool = False, + num_workers: int = 8, + config=None, + transformations=None, + warmup_epochs=10, + *args, **kwargs): + + super(CustomMoCo, self).__init__() + self.base_encoder = base_encoder + self.emb_dim = emb_dim + self.num_negatives = num_negatives + self.encoder_momentum = encoder_momentum + self.softmax_temperature = softmax_temperature + self.learning_rate = learning_rate + self.momentum = momentum + self.weight_decay = weight_decay + self.datamodule = datamodule + self.data_dir = data_dir + self.batch_size = batch_size + self.use_mlp = use_mlp + self.num_workers = num_workers + self.warmup_epochs = warmup_epochs + self.config = config + self.transformations = transformations + self.epoch = 0 + # create the encoders + # num_classes is the output fc dimension + self.encoder_q, self.encoder_k = self.init_encoders(base_encoder) + + if use_mlp: # hack: brute-force replacement + dim_mlp = self.encoder_q.fc.weight.shape[1] + self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc) + self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc) + + for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): + param_k.data.copy_(param_q.data) # initialize + param_k.requires_grad = False # not update by gradient + + # create the queue + self.register_buffer("queue", torch.randn(emb_dim, num_negatives)) + self.queue = nn.functional.normalize(self.queue, dim=0) + + self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) + # self.warmup_epochs = config["warm_up"] + + def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=['bias', 'bn']): + params = [] + excluded_params = [] + + for name, param in named_params: + if not param.requires_grad: + continue + elif any(layer_name in name for layer_name in skip_list): + excluded_params.append(param) + else: + params.append(param) + + return [ + {'params': params, 'weight_decay': weight_decay}, + {'params': excluded_params, 'weight_decay': 0.} + ] + + def setup(self, stage): + global_batch_size = self.trainer.world_size * self.batch_size + self.train_iters_per_epoch = self.datamodule.num_samples // global_batch_size + + + def configure_optimizers(self): + # global_batch_size = self.trainer.world_size * self.batch_size + # self.train_iters_per_epoch = self.datamodule.num_samples // global_batch_size + # # TRICK 1 (Use lars + filter weights) + # # exclude certain parameters + # parameters = self.exclude_from_wt_decay( + # self.named_parameters(), + # weight_decay=self.weight_decay + # ) + + # optimizer = LARSWrapper(Adam(parameters, lr=self.learning_rate)) + + # # Trick 2 (after each step) + # self.warmup_epochs = self.warmup_epochs * self.train_iters_per_epoch + # max_epochs = self.trainer.max_epochs * self.train_iters_per_epoch + + # linear_warmup_cosine_decay = LinearWarmupCosineAnnealingLR( + # optimizer, + # 
warmup_epochs=self.warmup_epochs, + # max_epochs=max_epochs, + # warmup_start_lr=0, + # eta_min=0 + # ) + + # scheduler = { + # 'scheduler': linear_warmup_cosine_decay, + # 'interval': 'step', + # 'frequency': 1 + # } + + # self.scheduler = linear_warmup_cosine_decay + + logger.debug("configure_optimizers") + optimizer = torch.optim.Adam(self.parameters( + ), self.learning_rate, weight_decay=self.weight_decay) + self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.config["epochs"], eta_min=0, + last_epoch=-1) + return [optimizer], [self.scheduler] + + def init_encoders(self, base_encoder): + """ + Override to add your own encoders + """ + + encoder_q = base_encoder() + encoder_k = base_encoder() + + return encoder_q, encoder_k + + def forward(self, img_q, img_k): + """ + Input: + im_q: a batch of query images + im_k: a batch of key images + Output: + logits, targets + """ + + # ugly fix + img_q = img_q.type_as(self.encoder_q.features[0][0].weight.data) + img_k = img_k.type_as(self.encoder_q.features[0][0].weight.data) + + # compute query features + q = self.encoder_q(img_q)[1] # queries: NxC + q = nn.functional.normalize(q, dim=1) + + # compute key features + with torch.no_grad(): # no gradient to keys + self._momentum_update_key_encoder() # update the key encoder + + # shuffle for making use of BN + if self.use_ddp or self.use_ddp2: + img_k, idx_unshuffle = self._batch_shuffle_ddp(img_k) + + k = self.encoder_k(img_k)[1] # keys: NxC + k = nn.functional.normalize(k, dim=1) + + # undo shuffle + if self.use_ddp or self.use_ddp2: + k = self._batch_unshuffle_ddp(k, idx_unshuffle) + + # compute logits + # Einstein sum is more intuitive + # positive logits: Nx1 + + l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) + # negative logits: NxK + l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) + + # logits: Nx(1+K) + logits = torch.cat([l_pos, l_neg], dim=1) + + # apply temperature + logits /= self.softmax_temperature + + # labels: positive key indicators + labels = torch.zeros(logits.shape[0], dtype=torch.long) + labels = labels.type_as(logits) + + # dequeue and enqueue + self._dequeue_and_enqueue(k) + + return logits, labels + + def training_step(self, batch, batch_idx): + (img_1, _), (img_2, _) = batch + output, target = self(img_q=img_1.float(), img_k=img_2.float()) + loss = F.cross_entropy(output.float(), target.long()) + acc = precision_at_k(output, target, top_k=(1,))[0] + + log = { + 'train_loss': loss, + 'train_acc': acc + } + return {'loss': loss, 'log': log, 'progress_bar': log} + + def validation_step(self, batch, batch_idx, dataloader_idx): + if dataloader_idx != 0: + return {} + + (img_1, _), (img_2, _) = batch + + output, target = self(img_q=img_1, img_k=img_2) + loss = F.cross_entropy(output, target.long()) + + acc = precision_at_k(output, target, top_k=(1,))[0] + results = { + 'val_loss': loss, + 'val_acc': acc + } + return results + + def training_epoch_end(self, outputs): + train_loss = mean(outputs, 'log', 'train_loss') + train_acc = mean(outputs, 'log', 'train_acc') + + log = { + 'train/train_loss': train_loss, + 'train/train_acc': train_acc + } + return {'train_loss': train_loss, 'log': log, 'progress_bar': log} + + def validation_epoch_end(self, outputs): + # outputs[0] because we are using multiple datasets! 
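+        # `outputs` contains one list of results per validation dataloader;
+        # only the first (the SSL validation set) is aggregated here, the
+        # remaining dataloaders return empty dicts in validation_step.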
+ val_loss = mean(outputs[0], 'val_loss') + val_acc = mean(outputs[0], 'val_acc') + + log = { + 'val/val_loss': val_loss, + 'val/val_acc': val_acc + } + return {'val_loss': val_loss, 'log': log, 'progress_bar': log} + + def on_train_start(self): + # log configuration + config_str = re.sub(r"[,\}\{]", "<br/>", str(self.config)) + config_str = re.sub(r"[\[\]\']", "", config_str) + transformation_str = re.sub(r"[\}]", "<br/>", str(["<br>" + str( + t) + ":<br/>" + str(t.get_params()) for t in self.transformations])) + transformation_str = re.sub(r"[,\"\{\'\[\]]", "", transformation_str) + self.logger.experiment.add_text( + "configuration", str(config_str), global_step=0) + self.logger.experiment.add_text("transformations", str( + transformation_str), global_step=0) + self.epoch = 0 + + def on_epoch_end(self): + # import pdb + # pdb.set_trace() + self.logger.experiment.add_scalar('cosine_lr_decay', self.scheduler.get_lr()[ + 0], global_step=self.epoch) + self.epoch += 1 + if self.epoch >= 10: + self.scheduler.step() + + @torch.no_grad() + def _dequeue_and_enqueue(self, keys): + # gather keys before updating queue + if self.use_ddp or self.use_ddp2: + keys = concat_all_gather(keys) + + batch_size = keys.shape[0] + # import pdb + # pdb.set_trace() + ptr = int(self.queue_ptr) + + # replace the keys at ptr (dequeue and enqueue) + remainder = self.queue[:, ptr:ptr + batch_size].shape[1] + if remainder < batch_size: + self.queue[:, -remainder:] = keys.T[:, :remainder] + self.queue[:, :batch_size-remainder] = keys.T[:, remainder:] + ptr = batch_size-remainder + else: + self.queue[:, ptr:ptr + batch_size] = keys.T + ptr = (ptr + batch_size) % self.num_negatives # move pointer + + self.queue_ptr[0] = ptr + + @torch.no_grad() + def _batch_shuffle_ddp(self, x): # pragma: no-cover + """ + Batch shuffle, for making use of BatchNorm. + *** Only support DistributedDataParallel (DDP) model. *** + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = concat_all_gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # random shuffle index + idx_shuffle = torch.randperm(batch_size_all).cuda() + + # broadcast to all gpus + torch.distributed.broadcast(idx_shuffle, src=0) + + # index for restoring + idx_unshuffle = torch.argsort(idx_shuffle) + + # shuffled index for this gpu + gpu_idx = torch.distributed.get_rank() + idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this], idx_unshuffle + + @torch.no_grad() + def _batch_unshuffle_ddp(self, x, idx_unshuffle): # pragma: no-cover + """ + Undo batch shuffle. + *** Only support DistributedDataParallel (DDP) model. *** + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = concat_all_gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # restored index for this gpu + gpu_idx = torch.distributed.get_rank() + idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this] + + @torch.no_grad() + def _momentum_update_key_encoder(self): + """ + Momentum update of the key encoder + """ + for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): + em = self.encoder_momentum + param_k.data = param_k.data * em + param_q.data * (1. 
- em)
+
+    def type(self):
+        return self.encoder_k.features[0][0].weight.type()
+
+    def get_representations(self, x):
+        return self.encoder_q.features(x)
+
+    def get_model(self):
+        return self.encoder_q
+
+    def get_device(self):
+        return self.encoder_k.features[0][0].weight.device
+
+def parse_args(parent_parser):
+    parser = ArgumentParser(parents=[parent_parser], add_help=False)
+    parser.add_argument('-t', '--trafos', nargs='+', help='add transformation to data augmentation pipeline',
+                        default=["GaussianNoise", "ChannelResize", "RandomResizedCrop"])
+    # GaussianNoise
+    parser.add_argument(
+        '--gaussian_scale', help='std param for gaussian noise transformation', default=0.005, type=float)
+    # RandomResizedCrop
+    parser.add_argument('--rr_crop_ratio_range', nargs='+',
+                        help='ratio range for random resized crop transformation', default=[0.5, 1.0], type=float)
+    parser.add_argument(
+        '--output_size', help='output size for random resized crop transformation', default=250, type=int)
+    # DynamicTimeWarp
+    parser.add_argument(
+        '--warps', help='number of warps for dynamic time warp transformation', default=3, type=int)
+    parser.add_argument(
+        '--radius', help='radius of warps of dynamic time warp transformation', default=10, type=int)
+    # TimeWarp
+    parser.add_argument(
+        '--epsilon', help='epsilon param for time warp', default=10, type=float)
+    # ChannelResize
+    parser.add_argument('--magnitude_range', nargs='+',
+                        help='range for scale param for ChannelResize transformation', default=[0.5, 2], type=float)
+    # Downsample
+    parser.add_argument(
+        '--downsample_ratio', help='downsample ratio for Downsample transformation', default=0.2, type=float)
+    # TimeOut
+    parser.add_argument('--to_crop_ratio_range', nargs='+',
+                        help='ratio range for timeout transformation', default=[0.2, 0.4], type=float)
+    # resume training
+    parser.add_argument('--resume', action='store_true')
+    parser.add_argument(
+        '--gpus', help='number of gpus to use; use cpu if gpus=0', type=int, default=1)
+    parser.add_argument(
+        '--num_nodes', default=1, help='number of cluster nodes', type=int)
+    parser.add_argument(
+        '--distributed_backend', help='sets backend type')
+    parser.add_argument('--batch_size', type=int)
+    parser.add_argument('--epochs', type=int)
+    parser.add_argument('--debug', action='store_true')
+    parser.add_argument('--warm_up', default=1, type=int)
+    parser.add_argument('--precision', type=int)
+    parser.add_argument('--datasets', dest="target_folders",
+                        nargs='+', help='used datasets for pretraining')
+    parser.add_argument('--log_dir', default="./experiment_logs")
+    parser.add_argument('--checkpoint_path', default="")
+    parser.add_argument(
+        '--percentage', help='determines how much of the dataset is used during pretraining', type=float, default=1.0)
+    parser.add_argument('--lr', type=float, help="learning rate")
+    parser.add_argument('--out_dim', type=int, help="output dimension of model")
+    parser.add_argument('--filter_cinc', default=False, action="store_true", help="only valid if cinc is selected: filter out the ptb data")
+    parser.add_argument('--base_model')
+    parser.add_argument('--widen', type=int, help="use wide xresnet1d50")
+    parser.add_argument('--run_callbacks', default=False, action="store_true", help="run callbacks which assess linear evaluation and fine-tuning metrics during pretraining")
+
+    return parser
+
+def init_logger(config):
+    level = logging.INFO
+
+    if config['debug']:
+        level = logging.DEBUG
+
+    # remove all handlers to change basic configuration
+    for handler in 
logging.root.handlers[:]: + logging.root.removeHandler(handler) + if not os.path.isdir(config['log_dir']): + os.mkdir(config['log_dir']) + logging.basicConfig(filename=os.path.join(config['log_dir'], 'info.log'), level=level, + format='%(asctime)s %(name)s:%(lineno)s %(levelname)s: %(message)s ') + return logging.getLogger(__name__) + +def pretrain_routine(args): + t_params = {"gaussian_scale": args.gaussian_scale, "rr_crop_ratio_range": args.rr_crop_ratio_range, "output_size": args.output_size, "warps": args.warps, "radius": args.radius, + "epsilon": args.epsilon, "magnitude_range": args.magnitude_range, "downsample_ratio": args.downsample_ratio, "to_crop_ratio_range": args.to_crop_ratio_range, + "bw_cmax":0.1, "em_cmax":0.5, "pl_cmax":0.2, "bs_cmax":1} + transformations = args.trafos + checkpoint_config = os.path.join("checkpoints", "bolts_config.yaml") + config_file = checkpoint_config if args.resume and os.path.isfile( + checkpoint_config) else "bolts_config.yaml" + config = yaml.load(open(config_file, "r"), Loader=yaml.FullLoader) + args_dict = vars(args) + for key in set(config.keys()).union(set(args_dict.keys())): + config[key] = config[key] if (key not in args_dict.keys() or key in args_dict.keys( + ) and key in config.keys() and args_dict[key] is None) else args_dict[key] + if args.target_folders is not None: + config["dataset"]["target_folders"] = args.target_folders + config["dataset"]["percentage"] = args.percentage if args.percentage is not None else config["dataset"]["percentage"] + config["dataset"]["filter_cinc"] = args.filter_cinc if args.filter_cinc is not None else config["dataset"]["filter_cinc"] + config["model"]["base_model"] = args.base_model if args.base_model is not None else config["model"]["base_model"] + config["model"]["widen"] = args.widen if args.widen is not None else config["model"]["widen"] + if args.out_dim is not None: + config["model"]["out_dim"] = args.out_dim + init_logger(config) + dataset = SimCLRDataSetWrapper( + config['batch_size'], **config['dataset'], transformations=transformations, t_params=t_params) + for i, t in enumerate(dataset.transformations): + logger.info(str(i) + ". 
Transformation: " +
+                    str(t) + ": " + str(t.get_params()))
+    date = time.asctime()
+    label_to_num_classes = {"label_all": 71, "label_diag": 44, "label_form": 19,
+                            "label_rhythm": 12, "label_diag_subclass": 23, "label_diag_superclass": 5}
+    ptb_num_classes = label_to_num_classes[config["eval_dataset"]["ptb_xl_label"]]
+    abr = {"Transpose": "Tr", "TimeOut": "TO", "DynamicTimeWarp": "DTW", "RandomResizedCrop": "RRC", "ChannelResize": "ChR", "GaussianNoise": "GN",
+           "TimeWarp": "TW", "ToTensor": "TT", "GaussianBlur": "GB", "BaselineWander": "BlW", "PowerlineNoise": "PlN", "EMNoise": "EM", "BaselineShift": "BlS"}
+    trs = re.sub(r"[,'\]\[]", "", str([abr[str(tr)] if abr[str(tr)] not in [
+                 "TT", "Tr"] else '' for tr in dataset.transformations]))
+    name = str(date) + "_" + method + "_" + str(
+        time.time_ns())[-3:] + "_" + trs[1:]
+    tb_logger = TensorBoardLogger(args.log_dir, name=name, version='')
+    config["log_dir"] = os.path.join(args.log_dir, name)
+    print(config)
+    return config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger
+
+def aftertrain_routine(config, args, trainer, pl_model, datamodule, callbacks):
+    from online_evaluator import SSLOnlineEvaluator  # needed for the isinstance check below
+
+    scores = {}
+    for ca in callbacks:
+        if isinstance(ca, SSLOnlineEvaluator):
+            scores[str(ca)] = {"macro": ca.best_macro}
+
+    results = {"config": config, "trafos": args.trafos, "scores": scores}
+
+    with open(os.path.join(config["log_dir"], "results.pkl"), 'wb') as handle:
+        pickle.dump(results, handle)
+
+    trainer.save_checkpoint(os.path.join(config["log_dir"], "checkpoints", "model.ckpt"))
+    with open(os.path.join(config["log_dir"], "config.txt"), "w") as text_file:
+        print(config, file=text_file)
+
+def cli_main():
+    from online_evaluator import SSLOnlineEvaluator
+    from ecg_datamodule import ECGDataModule
+    from os.path import exists
+
+    parser = ArgumentParser()
+    parser = parse_args(parser)
+    logger.info("parse arguments")
+    args = parser.parse_args()
+    config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger = pretrain_routine(args)
+
+    # data
+    ecg_datamodule = ECGDataModule(config, transformations, t_params)
+
+    callbacks = []
+    if args.run_callbacks:
+        # callbacks for online linear evaluation and fine-tuning
+        linear_evaluator = SSLOnlineEvaluator(drop_p=0, z_dim=512, num_classes=ptb_num_classes, hidden_dim=None,
+                                              lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"],
+                                              mode="linear_evaluation", verbose=False)
+
+        fine_tuner = SSLOnlineEvaluator(drop_p=0, z_dim=512, num_classes=ptb_num_classes, hidden_dim=None,
+                                        lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"],
+                                        mode="fine_tuning", verbose=False)
+
+        callbacks.append(linear_evaluator)
+        callbacks.append(fine_tuner)
+
+    # configure trainer
+    trainer = Trainer(logger=tb_logger, max_epochs=config["epochs"], gpus=args.gpus,
+                      distributed_backend=args.distributed_backend, auto_lr_find=False, num_nodes=args.num_nodes,
+                      precision=config["precision"], callbacks=callbacks)
+
+    # pytorch lightning module
+    def create_encoder(): return ResNetSimCLR(**config["model"])
+    pl_model = CustomMoCo(create_encoder, datamodule=ecg_datamodule, num_negatives=ecg_datamodule.num_samples,
+                          emb_dim=config["model"]["out_dim"], config=config, transformations=ecg_datamodule.transformations,
+                          batch_size=config["batch_size"], learning_rate=config["lr"],
+                          softmax_temperature=0.07,  # MoCo default temperature, independent of the learning rate
+                          warmup_epochs=config["warm_up"], weight_decay=float(config["weight_decay"]))
+    # load checkpoint
+    if args.checkpoint_path != "":
+        if exists(args.checkpoint_path):
+            logger.info("Retrieving checkpoint from " + args.checkpoint_path)
+            # Lightning checkpoints store the weights under "state_dict";
+            # load them into the freshly built model.
+            checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
+            pl_model.load_state_dict(checkpoint["state_dict"])
+        else:
+            raise FileNotFoundError("checkpoint does not exist: " + args.checkpoint_path)
+
+    # start training
+    trainer.fit(pl_model, ecg_datamodule)
+
+    aftertrain_routine(config, args, trainer, pl_model, ecg_datamodule, callbacks)
+
+if __name__ == "__main__":
+    cli_main()