Diff of /custom_simclr_bolts.py [000000] .. [134fd7]

Switch to side-by-side view

--- a
+++ b/custom_simclr_bolts.py
@@ -0,0 +1,507 @@
+import pytorch_lightning as pl
+# from pl_bolts.models.self_supervised import SimCLR
+from pl_bolts.optimizers.lars_scheduling import LARSWrapper
+from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
+from torch.optim import Adam
+import torch
+import re
+import pdb 
+
+import math
+from argparse import ArgumentParser
+from typing import Callable, Optional
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from pytorch_lightning.utilities import AMPType
+from torch import nn
+from torch.optim.optimizer import Optimizer
+
+
+from models.resnet_simclr import ResNetSimCLR
+import re
+
+import time
+import pickle
+import yaml
+import logging
+import os
+from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper
+from clinical_ts.create_logger import create_logger
+import pickle
+from pytorch_lightning import Trainer, seed_everything
+
+from torch import nn
+from torch.nn import functional as F
+from online_evaluator import SSLOnlineEvaluator
+from ecg_datamodule import ECGDataModule
+from pytorch_lightning.loggers import TensorBoardLogger
+from pl_bolts.models.self_supervised.evaluator import Flatten
+import pdb
+method="simclr"
+logger = create_logger(__name__)
+def _accuracy(zis, zjs, batch_size):
+    with torch.no_grad():
+        representations = torch.cat([zjs, zis], dim=0)
+        similarity_matrix = torch.mm(
+            representations, representations.t().contiguous())
+        corrected_similarity_matrix = similarity_matrix - \
+            torch.eye(2*batch_size).type_as(similarity_matrix)
+        pred_similarities, pred_indices = torch.max(
+            corrected_similarity_matrix[:batch_size], dim=1)
+        correct_indices = torch.arange(batch_size)+batch_size
+        correct_preds = (
+            pred_indices == correct_indices.type_as(pred_indices)).sum()
+    return correct_preds.float()/batch_size
+
+def mean(res, key1, key2=None):
+    if key2 is not None:
+        return torch.stack([x[key1][key2] for x in res]).mean()
+    return torch.stack([x[key1] for x in res if type(x) == dict and key1 in x.keys()]).mean()
+
+class Projection(nn.Module):
+    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):
+        super().__init__()
+        self.output_dim = output_dim
+        self.input_dim = input_dim
+        self.hidden_dim = hidden_dim
+        self.model = nn.Sequential(
+            # nn.AdaptiveAvgPool2d((1, 1)),
+            Flatten(),
+            nn.Linear(self.input_dim, self.hidden_dim, bias=True),
+            # nn.BatchNorm1d(self.hidden_dim),
+            nn.ReLU(),
+            nn.Linear(self.hidden_dim, self.output_dim, bias=True))
+
+    def forward(self, x):
+        x = self.model(x)
+        return F.normalize(x, dim=1)
+
+
+class SyncFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, tensor):
+        ctx.batch_size = tensor.shape[0]
+
+        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
+
+        torch.distributed.all_gather(gathered_tensor, tensor)
+        gathered_tensor = torch.cat(gathered_tensor, 0)
+
+        return gathered_tensor
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_input = grad_output.clone()
+        torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
+
+        return grad_input[torch.distributed.get_rank() * ctx.batch_size:(torch.distributed.get_rank() + 1) *
+                          ctx.batch_size]
+
+
+class CustomSimCLR(pl.LightningModule):
+
+    def __init__(self,
+                 batch_size,
+                 num_samples,
+                 warmup_epochs=10,
+                 lr=1e-4,
+                 opt_weight_decay=1e-6,
+                 loss_temperature=0.5,
+                 config=None,
+                 transformations=None,
+                 **kwargs):
+        """
+        Args:
+            batch_size: the batch size
+            num_samples: num samples in the dataset
+            warmup_epochs: epochs to warmup the lr for
+            lr: the optimizer learning rate
+            opt_weight_decay: the optimizer weight decay
+            loss_temperature: the loss temperature
+        """
+
+        super(CustomSimCLR, self).__init__()
+        self.config = config
+        self.transformations = transformations
+        self.epoch = 0
+        self.batch_size = batch_size
+        self.num_samples = num_samples
+        self.save_hyperparameters()
+        # pdb.set_trace()
+
+    def configure_optimizers(self):
+        global_batch_size = self.trainer.world_size * self.hparams.batch_size
+        self.train_iters_per_epoch = self.hparams.num_samples // global_batch_size
+        # TRICK 1 (Use lars + filter weights)
+        # exclude certain parameters
+        parameters = self.exclude_from_wt_decay(
+            self.named_parameters(),
+            weight_decay=self.hparams.opt_weight_decay
+        )
+
+
+        # optimizer = LARSWrapper(Adam(parameters, lr=self.hparams.lr))
+        optimizer = Adam(parameters, lr=self.hparams.lr)
+        
+        # Trick 2 (after each step)
+        self.hparams.warmup_epochs = self.hparams.warmup_epochs * self.train_iters_per_epoch
+        max_epochs = self.trainer.max_epochs * self.train_iters_per_epoch
+
+        linear_warmup_cosine_decay = LinearWarmupCosineAnnealingLR(
+            optimizer,
+            warmup_epochs=self.hparams.warmup_epochs,
+            max_epochs=max_epochs,
+            warmup_start_lr=0,
+            eta_min=0
+        )
+
+        scheduler = {
+            'scheduler': linear_warmup_cosine_decay,
+            'interval': 'step',
+            'frequency': 1
+        }
+
+        return [optimizer], [scheduler]
+
+    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=['bias', 'bn']):
+        params = []
+        excluded_params = []
+
+        for name, param in named_params:
+            if not param.requires_grad:
+                continue
+            elif any(layer_name in name for layer_name in skip_list):
+                excluded_params.append(param)
+            else:
+                params.append(param)
+
+        return [
+            {'params': params, 'weight_decay': weight_decay},
+            {'params': excluded_params, 'weight_decay': 0.}
+        ]
+    
+    def shared_forward(self, batch, batch_idx):
+        (x1, y1), (x2, y2) = batch
+        # ENCODE
+        # encode -> representations
+        # (b, 3, 32, 32) -> (b, 2048, 2, 2)
+        x1 = self.to_device(x1)
+        x2 = self.to_device(x2)
+
+        h1 = self.encoder(x1)[0]
+        h2 = self.encoder(x2)[0]
+
+        # the bolts resnets return a list of feature maps
+        if isinstance(h1, list):
+            h1 = h1[-1]
+            h2 = h2[-1]
+
+        # PROJECT
+        # img -> E -> h -> || -> z
+        # (b, 2048, 2, 2) -> (b, 128)
+        z1 = self.projection(h1.squeeze())
+        z2 = self.projection(h2.squeeze())
+
+        return z1, z2
+
+    def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6):
+        """
+            assume out_1 and out_2 are normalized
+            out_1: [batch_size, dim]
+            out_2: [batch_size, dim]
+        """
+        # gather representations in case of distributed training
+        # out_1_dist: [batch_size * world_size, dim]
+        # out_2_dist: [batch_size * world_size, dim]
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            out_1_dist = SyncFunction.apply(out_1)
+            out_2_dist = SyncFunction.apply(out_2)
+            print("out dist shape: ", out_1_dist.shape)
+        else:
+            out_1_dist = out_1
+            out_2_dist = out_2
+        
+        # out: [2 * batch_size, dim]
+        # out_dist: [2 * batch_size * world_size, dim]
+        out = torch.cat([out_1, out_2], dim=0)
+        out_dist = torch.cat([out_1_dist, out_2_dist], dim=0)
+
+        # cov and sim: [2 * batch_size, 2 * batch_size * world_size]
+        # neg: [2 * batch_size]
+        cov = torch.mm(out, out_dist.t().contiguous())
+        sim = torch.exp(cov / temperature)
+        neg = sim.sum(dim=-1)
+
+        # from each row, subtract e^1 to remove similarity measure for x1.x1
+        row_sub = torch.Tensor(neg.shape).fill_(math.e).to(neg.device)
+        neg = torch.clamp(neg - row_sub, min=eps)  # clamp for numerical stability
+
+        # Positive similarity, pos becomes [2 * batch_size]
+        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
+        pos = torch.cat([pos, pos], dim=0)
+
+        loss = -torch.log(pos / (neg + eps)).mean()
+
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        z1, z2 = self.shared_forward(batch, batch_idx)
+        loss = self.nt_xent_loss(z1, z2, self.hparams.loss_temperature)
+        # result = pl.TrainResult(minimize=loss)
+        # result.log('train/train_loss', loss, on_epoch=True)
+
+        acc = _accuracy(z1, z2, z1.shape[0])
+        # result.log('train/train_acc', acc, on_epoch=True)
+        result = {
+            "train/train_loss": loss, 
+            "minimize":loss,
+            "train/train_acc" : acc,
+        }
+        return loss
+
+    def validation_step(self, batch, batch_idx, dataloader_idx):
+        if dataloader_idx != 0:
+            return {}
+        z1, z2 = self.shared_forward(batch, batch_idx)
+        loss = self.nt_xent_loss(z1, z2, self.hparams.loss_temperature)
+
+        acc = _accuracy(z1, z2, z1.shape[0])
+        results = {
+            'val_loss': loss,
+            'val_acc': torch.tensor(acc)
+        }
+        return results
+
+    def validation_epoch_end(self, outputs):
+        # outputs[0] because we are using multiple datasets!
+        val_loss = mean(outputs[0], 'val_loss')
+        val_acc = mean(outputs[0], 'val_acc')
+
+        log = {
+            'val/val_loss': val_loss,
+            'val/val_acc': val_acc
+        }
+        return {'val_loss': val_loss, 'log': log, 'progress_bar': log}
+
+    def on_train_start(self):
+        # log configuration
+        config_str = re.sub(r"[,\}\{]", "<br/>", str(self.config))
+        config_str = re.sub(r"[\[\]\']", "", config_str)
+        transformation_str = re.sub(r"[\}]", "<br/>", str(["<br>" + str(
+            t) + ":<br/>" + str(t.get_params()) for t in self.transformations]))
+        transformation_str = re.sub(r"[,\"\{\'\[\]]", "", transformation_str)
+        self.logger.experiment.add_text(
+            "configuration", str(config_str), global_step=0)
+        self.logger.experiment.add_text("transformations", str(
+            transformation_str), global_step=0)
+        self.epoch = 0
+
+    def on_epoch_end(self):
+        self.epoch += 1
+
+    def type(self):
+        return self.encoder.features[0][0].weight.type()
+
+    def get_representations(self, x):
+        return self.encoder(x)[0]
+    
+    def get_model(self):
+        return self.encoder
+
+    def get_device(self):
+        return self.encoder.features[0][0].weight.device
+
+    def to_device(self, x):
+        return x.type(self.type()).to(self.get_device())
+
+
+def parse_args(parent_parser):
+    parser = ArgumentParser(parents=[parent_parser], add_help=False)
+    parser.add_argument('-t', '--trafos', nargs='+', help='add transformation to data augmentation pipeline',
+                        default=["GaussianNoise", "ChannelResize", "RandomResizedCrop"])
+    # GaussianNoise
+    parser.add_argument(
+            '--gaussian_scale', help='std param for gaussian noise transformation', default=0.005, type=float)
+    # RandomResizedCrop
+    parser.add_argument('--rr_crop_ratio_range',
+                            help='ratio range for random resized crop transformation', default=[0.5, 1.0], type=float)
+    parser.add_argument(
+            '--output_size', help='output size for random resized crop transformation', default=250, type=int)
+    # DynamicTimeWarp
+    parser.add_argument(
+            '--warps', help='number of warps for dynamic time warp transformation', default=3, type=int)
+    parser.add_argument(
+            '--radius', help='radius of warps of dynamic time warp transformation', default=10, type=int)
+    # TimeWarp
+    parser.add_argument(
+            '--epsilon', help='epsilon param for time warp', default=10, type=float)
+    # ChannelResize
+    parser.add_argument('--magnitude_range', nargs='+',
+                            help='range for scale param for ChannelResize transformation', default=[0.5, 2], type=float)
+    # Downsample
+    parser.add_argument(
+            '--downsample_ratio', help='downsample ratio for Downsample transformation', default=0.2, type=float)
+    # TimeOut
+    parser.add_argument('--to_crop_ratio_range', nargs='+',
+                            help='ratio range for timeout transformation', default=[0.2, 0.4], type=float)
+    # resume training
+    parser.add_argument('--resume', action='store_true')
+    parser.add_argument(
+            '--gpus', help='number of gpus to use; use cpu if gpu=0', type=int, default=1)
+    parser.add_argument(
+            '--num_nodes', default=1, help='number of cluster nodes', type=int)
+    parser.add_argument(
+            '--distributed_backend', help='sets backend type')
+    parser.add_argument('--batch_size', type=int)
+    parser.add_argument('--epochs', type=int)
+    parser.add_argument('--debug', action='store_true')
+    parser.add_argument('--warm_up', default=1, type=int, help="number of warm up epochs")
+    parser.add_argument('--precision', type=int)
+    parser.add_argument('--datasets', dest="target_folders",
+                            nargs='+', help='used datasets for pretraining')
+    parser.add_argument('--log_dir', default="./experiment_logs")
+    parser.add_argument(
+            '--percentage', help='determines how much of the dataset shall be used during the pretraining', type=float, default=1.0)
+    parser.add_argument('--lr', type=float, help="learning rate")
+    parser.add_argument('--out_dim', type=int, help="output dimension of model")
+    parser.add_argument('--filter_cinc', default=False, action="store_true", help="only valid if cinc is selected: filter out the ptb data")
+    parser.add_argument('--base_model')
+    parser.add_argument('--widen',type=int, help="use wide xresnet1d50")
+    parser.add_argument('--run_callbacks', default=False, action="store_true", help="run callbacks which asses linear evaluaton and finetuning metrics during pretraining")
+    parser.add_argument('--checkpoint_path', default="")
+    return parser
+
+def init_logger(config):
+    level = logging.INFO
+
+    if config['debug']:
+        level = logging.DEBUG
+
+    # remove all handlers to change basic configuration
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+    if not os.path.isdir(config['log_dir']):
+        os.mkdir(config['log_dir'])
+    logging.basicConfig(filename=os.path.join(config['log_dir'], 'info.log'), level=level,
+                        format='%(asctime)s %(name)s:%(lineno)s %(levelname)s:  %(message)s  ')
+    return logging.getLogger(__name__)
+
+def pretrain_routine(args):
+    t_params = {"gaussian_scale": args.gaussian_scale, "rr_crop_ratio_range": args.rr_crop_ratio_range, "output_size": args.output_size, "warps": args.warps, "radius": args.radius,
+                "epsilon": args.epsilon, "magnitude_range": args.magnitude_range, "downsample_ratio": args.downsample_ratio, "to_crop_ratio_range": args.to_crop_ratio_range,
+                "bw_cmax":0.1, "em_cmax":0.5, "pl_cmax":0.2, "bs_cmax":1}
+    transformations = args.trafos
+    checkpoint_config = os.path.join("checkpoints", "bolts_config.yaml")
+    config_file = checkpoint_config if args.resume and os.path.isfile(
+        checkpoint_config) else "bolts_config.yaml"
+    config = yaml.load(open(config_file, "r"), Loader=yaml.FullLoader)
+    args_dict = vars(args)
+    for key in set(config.keys()).union(set(args_dict.keys())):
+        config[key] = config[key] if (key not in args_dict.keys() or key in args_dict.keys(
+        ) and key in config.keys() and args_dict[key] is None) else args_dict[key]
+    if args.target_folders is not None:
+        config["dataset"]["target_folders"] = args.target_folders
+    config["dataset"]["percentage"] = args.percentage if args.percentage is not None else config["dataset"]["percentage"]
+    config["dataset"]["filter_cinc"] = args.filter_cinc if args.filter_cinc is not None else config["dataset"]["filter_cinc"]
+    config["model"]["base_model"] = args.base_model if args.base_model is not None else config["model"]["base_model"]
+    config["model"]["widen"] = args.widen if args.widen is not None else config["model"]["widen"]
+    if args.out_dim is not None:
+        config["model"]["out_dim"] = args.out_dim
+    init_logger(config)
+    dataset = SimCLRDataSetWrapper(
+        config['batch_size'], **config['dataset'], transformations=transformations, t_params=t_params)
+    for i, t in enumerate(dataset.transformations):
+        logger.info(str(i) + ". Transformation: " +
+                    str(t) + ": " + str(t.get_params()))
+    date = time.asctime()
+    label_to_num_classes = {"label_all": 71, "label_diag": 44, "label_form": 19,
+                            "label_rhythm": 12, "label_diag_subclass": 23, "label_diag_superclass": 5}
+    ptb_num_classes = label_to_num_classes[config["eval_dataset"]
+                                           ["ptb_xl_label"]]
+    abr = {"Transpose": "Tr", "TimeOut": "TO", "DynamicTimeWarp": "DTW", "RandomResizedCrop": "RRC", "ChannelResize": "ChR", "GaussianNoise": "GN",
+           "TimeWarp": "TW", "ToTensor": "TT", "GaussianBlur": "GB", "BaselineWander": "BlW", "PowerlineNoise": "PlN", "EMNoise": "EM", "BaselineShift": "BlS"}
+    trs = re.sub(r"[,'\]\[]", "", str([abr[str(tr)] if abr[str(tr)] not in [
+                 "TT", "Tr"] else '' for tr in dataset.transformations]))
+    name = str(date) + "_" + method + "_" + str(
+        time.time_ns())[-3:] + "_" + trs[1:]
+    tb_logger = TensorBoardLogger(args.log_dir, name=name, version='')
+    config["log_dir"] = os.path.join(args.log_dir, name)
+    print(config)
+    return config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger
+
+def aftertrain_routine(config, args, trainer, pl_model, datamodule, callbacks):
+    scores = {}
+    for ca in callbacks:
+        if isinstance(ca, SSLOnlineEvaluator):
+            scores[str(ca)] = {"macro": ca.best_macro}
+
+    results = {"config": config, "trafos": args.trafos, "scores": scores}
+
+    with open(os.path.join(config["log_dir"], "results.pkl"), 'wb') as handle:
+        pickle.dump(results, handle)
+
+    trainer.save_checkpoint(os.path.join(config["log_dir"], "checkpoints", "model.ckpt"))
+    with open(os.path.join(config["log_dir"], "config.txt"), "w") as text_file:
+        print(config, file=text_file)
+
+def cli_main():
+    from pytorch_lightning import Trainer
+    from online_evaluator import SSLOnlineEvaluator
+    from ecg_datamodule import ECGDataModule
+    from clinical_ts.create_logger import create_logger
+    from os.path import exists
+    
+    parser = ArgumentParser()
+    parser = parse_args(parser)
+    logger.info("parse arguments")
+    args = parser.parse_args()
+    config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger = pretrain_routine(args)
+
+    # data
+    ecg_datamodule = ECGDataModule(config, transformations, t_params)
+
+    callbacks = []
+    if args.run_callbacks:
+            # callback for online linear evaluation/fine-tuning
+        linear_evaluator = SSLOnlineEvaluator(drop_p=0,
+                                          z_dim=512, num_classes=ptb_num_classes, hidden_dim=None, lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"], mode="linear_evaluation", verbose=False)
+
+        fine_tuner = SSLOnlineEvaluator(drop_p=0,
+                                          z_dim=512, num_classes=ptb_num_classes, hidden_dim=None, lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"], mode="fine_tuning", verbose=False)
+   
+        callbacks.append(linear_evaluator)
+        callbacks.append(fine_tuner)
+
+    # configure trainer
+    trainer = Trainer(logger=tb_logger, max_epochs=config["epochs"], gpus=args.gpus,
+                      distributed_backend=args.distributed_backend, auto_lr_find=False, num_nodes=args.num_nodes, precision=config["precision"], callbacks=callbacks)
+
+    # pytorch lightning module
+    model = ResNetSimCLR(**config["model"])
+    pl_model = CustomSimCLR(
+            config["batch_size"], ecg_datamodule.num_samples, warmup_epochs=config["warm_up"], lr=config["lr"],
+            out_dim=config["model"]["out_dim"], config=config,
+            transformations=ecg_datamodule.transformations, loss_temperature=config["loss"]["temperature"], weight_decay=eval(config["weight_decay"]))
+    pl_model.encoder = model
+    pl_model.projection = Projection(
+            input_dim=model.l1.in_features, hidden_dim=512, output_dim=config["model"]["out_dim"])
+
+    # load checkpoint
+    if args.checkpoint_path != "":
+        if exists(args.checkpoint_path):
+            logger.info("Retrieve checkpoint from " + args.checkpoint_path)
+            pl_model.load_from_checkpoint(args.checkpoint_path)
+        else:
+            raise("checkpoint does not exist")
+
+    # start training
+    trainer.fit(pl_model, ecg_datamodule)
+
+    aftertrain_routine(config, args, trainer, pl_model, ecg_datamodule, callbacks)
+
+if __name__ == "__main__":  
+    cli_main()
\ No newline at end of file