# custom_simclr_bolts.py
import pytorch_lightning as pl
# from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torch.optim import Adam
import torch
import re
import pdb

import math
from argparse import ArgumentParser
from typing import Callable, Optional

import numpy as np
import torch.distributed as dist
import torch.nn.functional as F
from pytorch_lightning.utilities import AMPType
from torch import nn
from torch.optim.optimizer import Optimizer


from models.resnet_simclr import ResNetSimCLR

import time
import pickle
import yaml
import logging
import os
from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper
from clinical_ts.create_logger import create_logger
from pytorch_lightning import Trainer, seed_everything

from online_evaluator import SSLOnlineEvaluator
from ecg_datamodule import ECGDataModule
from pytorch_lightning.loggers import TensorBoardLogger
from pl_bolts.models.self_supervised.evaluator import Flatten
method = "simclr"
logger = create_logger(__name__)


def _accuracy(zis, zjs, batch_size):
    with torch.no_grad():
        representations = torch.cat([zjs, zis], dim=0)
        similarity_matrix = torch.mm(
            representations, representations.t().contiguous())
        corrected_similarity_matrix = similarity_matrix - \
            torch.eye(2 * batch_size).type_as(similarity_matrix)
        pred_similarities, pred_indices = torch.max(
            corrected_similarity_matrix[:batch_size], dim=1)
        correct_indices = torch.arange(batch_size) + batch_size
        correct_preds = (
            pred_indices == correct_indices.type_as(pred_indices)).sum()
    return correct_preds.float() / batch_size


def mean(res, key1, key2=None):
    if key2 is not None:
        return torch.stack([x[key1][key2] for x in res]).mean()
    return torch.stack([x[key1] for x in res if isinstance(x, dict) and key1 in x.keys()]).mean()

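# Projection head g(.): a small MLP (Flatten -> Linear -> ReLU -> Linear) that maps
# encoder features to the low-dimensional, L2-normalized space in which the
# contrastive (NT-Xent) loss is computed.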
class Projection(nn.Module):
    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):
        super().__init__()
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.model = nn.Sequential(
            # nn.AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            nn.Linear(self.input_dim, self.hidden_dim, bias=True),
            # nn.BatchNorm1d(self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.output_dim, bias=True))

    def forward(self, x):
        x = self.model(x)
        return F.normalize(x, dim=1)


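# SyncFunction gathers the projected embeddings from all distributed processes
# (all_gather in forward) so the contrastive loss sees negatives from the global
# batch; backward all-reduces the gradients and returns this rank's slice.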
class SyncFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor):
        ctx.batch_size = tensor.shape[0]

        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]

        torch.distributed.all_gather(gathered_tensor, tensor)
        gathered_tensor = torch.cat(gathered_tensor, 0)

        return gathered_tensor

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)

        return grad_input[torch.distributed.get_rank() * ctx.batch_size:(torch.distributed.get_rank() + 1) *
                          ctx.batch_size]


class CustomSimCLR(pl.LightningModule):

    def __init__(self,
                 batch_size,
                 num_samples,
                 warmup_epochs=10,
                 lr=1e-4,
                 opt_weight_decay=1e-6,
                 loss_temperature=0.5,
                 config=None,
                 transformations=None,
                 **kwargs):
        """
        Args:
            batch_size: the batch size
            num_samples: number of samples in the dataset
            warmup_epochs: number of epochs to warm up the learning rate for
            lr: the optimizer learning rate
            opt_weight_decay: the optimizer weight decay
            loss_temperature: the temperature of the NT-Xent loss
        """

        super(CustomSimCLR, self).__init__()
        self.config = config
        self.transformations = transformations
        self.epoch = 0
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.save_hyperparameters()
        # pdb.set_trace()

    def configure_optimizers(self):
        global_batch_size = self.trainer.world_size * self.hparams.batch_size
        self.train_iters_per_epoch = self.hparams.num_samples // global_batch_size

        # TRICK 1 (use LARS + filter weights):
        # exclude certain parameters from weight decay
        parameters = self.exclude_from_wt_decay(
            self.named_parameters(),
            weight_decay=self.hparams.opt_weight_decay
        )

        # optimizer = LARSWrapper(Adam(parameters, lr=self.hparams.lr))
        optimizer = Adam(parameters, lr=self.hparams.lr)

        # TRICK 2: step the scheduler after each optimizer step,
        # so warmup/decay are expressed in iterations rather than epochs
        self.hparams.warmup_epochs = self.hparams.warmup_epochs * self.train_iters_per_epoch
        max_epochs = self.trainer.max_epochs * self.train_iters_per_epoch

        linear_warmup_cosine_decay = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.hparams.warmup_epochs,
            max_epochs=max_epochs,
            warmup_start_lr=0,
            eta_min=0
        )

        scheduler = {
            'scheduler': linear_warmup_cosine_decay,
            'interval': 'step',
            'frequency': 1
        }

        return [optimizer], [scheduler]

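    # Bias and batch-norm parameters (skip_list) are excluded from weight decay,
    # following the SimCLR recipe; all other parameters keep the configured decay.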
    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=['bias', 'bn']):
        params = []
        excluded_params = []

        for name, param in named_params:
            if not param.requires_grad:
                continue
            elif any(layer_name in name for layer_name in skip_list):
                excluded_params.append(param)
            else:
                params.append(param)

        return [
            {'params': params, 'weight_decay': weight_decay},
            {'params': excluded_params, 'weight_decay': 0.}
        ]

    def shared_forward(self, batch, batch_idx):
        (x1, y1), (x2, y2) = batch
        # ENCODE
        # encode -> representations
        # (b, 3, 32, 32) -> (b, 2048, 2, 2)
        x1 = self.to_device(x1)
        x2 = self.to_device(x2)

        h1 = self.encoder(x1)[0]
        h2 = self.encoder(x2)[0]

        # the bolts resnets return a list of feature maps
        if isinstance(h1, list):
            h1 = h1[-1]
            h2 = h2[-1]

        # PROJECT
        # img -> E -> h -> || -> z
        # (b, 2048, 2, 2) -> (b, 128)
        z1 = self.projection(h1.squeeze())
        z2 = self.projection(h2.squeeze())

        return z1, z2

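    # NT-Xent: for a positive pair (z_i, z_j),
    #   loss = -log( exp(sim(z_i, z_j)/T) / sum_{k != i} exp(sim(z_i, z_k)/T) ),
    # computed below with the self-similarity term removed from the denominator.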
    def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6):
        """
        assume out_1 and out_2 are normalized
        out_1: [batch_size, dim]
        out_2: [batch_size, dim]
        """
        # gather representations in case of distributed training
        # out_1_dist: [batch_size * world_size, dim]
        # out_2_dist: [batch_size * world_size, dim]
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            out_1_dist = SyncFunction.apply(out_1)
            out_2_dist = SyncFunction.apply(out_2)
            logger.debug("out dist shape: %s", out_1_dist.shape)
        else:
            out_1_dist = out_1
            out_2_dist = out_2

        # out: [2 * batch_size, dim]
        # out_dist: [2 * batch_size * world_size, dim]
        out = torch.cat([out_1, out_2], dim=0)
        out_dist = torch.cat([out_1_dist, out_2_dist], dim=0)

        # cov and sim: [2 * batch_size, 2 * batch_size * world_size]
        # neg: [2 * batch_size]
        cov = torch.mm(out, out_dist.t().contiguous())
        sim = torch.exp(cov / temperature)
        neg = sim.sum(dim=-1)

        # from each row, subtract exp(1/temperature) to remove the self-similarity
        # term x_i.x_i (the embeddings are L2-normalized, so sim(x_i, x_i) = 1)
        row_sub = torch.Tensor(neg.shape).fill_(math.e ** (1 / temperature)).to(neg.device)
        neg = torch.clamp(neg - row_sub, min=eps)  # clamp for numerical stability

        # positive similarity, pos becomes [2 * batch_size]
        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
        pos = torch.cat([pos, pos], dim=0)

        loss = -torch.log(pos / (neg + eps)).mean()

        return loss

    def training_step(self, batch, batch_idx):
        z1, z2 = self.shared_forward(batch, batch_idx)
        loss = self.nt_xent_loss(z1, z2, self.hparams.loss_temperature)
        acc = _accuracy(z1, z2, z1.shape[0])

        # log via the dict interface (same convention as validation_epoch_end)
        log = {
            "train/train_loss": loss,
            "train/train_acc": acc,
        }
        return {"loss": loss, "log": log}

    def validation_step(self, batch, batch_idx, dataloader_idx):
        if dataloader_idx != 0:
            return {}
        z1, z2 = self.shared_forward(batch, batch_idx)
        loss = self.nt_xent_loss(z1, z2, self.hparams.loss_temperature)

        acc = _accuracy(z1, z2, z1.shape[0])
        results = {
            'val_loss': loss,
            'val_acc': acc
        }
        return results

    def validation_epoch_end(self, outputs):
        # outputs[0] because we are using multiple datasets!
        val_loss = mean(outputs[0], 'val_loss')
        val_acc = mean(outputs[0], 'val_acc')

        log = {
            'val/val_loss': val_loss,
            'val/val_acc': val_acc
        }
        return {'val_loss': val_loss, 'log': log, 'progress_bar': log}

    def on_train_start(self):
        # log configuration
        config_str = re.sub(r"[,\}\{]", "<br/>", str(self.config))
        config_str = re.sub(r"[\[\]\']", "", config_str)
        transformation_str = re.sub(r"[\}]", "<br/>", str(["<br>" + str(
            t) + ":<br/>" + str(t.get_params()) for t in self.transformations]))
        transformation_str = re.sub(r"[,\"\{\'\[\]]", "", transformation_str)
        self.logger.experiment.add_text(
            "configuration", str(config_str), global_step=0)
        self.logger.experiment.add_text("transformations", str(
            transformation_str), global_step=0)
        self.epoch = 0

    def on_epoch_end(self):
        self.epoch += 1

    def type(self):
        return self.encoder.features[0][0].weight.type()

    def get_representations(self, x):
        return self.encoder(x)[0]

    def get_model(self):
        return self.encoder

    def get_device(self):
        return self.encoder.features[0][0].weight.device

    def to_device(self, x):
        return x.type(self.type()).to(self.get_device())


def parse_args(parent_parser):
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    parser.add_argument('-t', '--trafos', nargs='+', help='add transformation to data augmentation pipeline',
                        default=["GaussianNoise", "ChannelResize", "RandomResizedCrop"])
    # GaussianNoise
    parser.add_argument(
        '--gaussian_scale', help='std param for gaussian noise transformation', default=0.005, type=float)
    # RandomResizedCrop
    parser.add_argument('--rr_crop_ratio_range', nargs='+',
                        help='ratio range for random resized crop transformation', default=[0.5, 1.0], type=float)
    parser.add_argument(
        '--output_size', help='output size for random resized crop transformation', default=250, type=int)
    # DynamicTimeWarp
    parser.add_argument(
        '--warps', help='number of warps for dynamic time warp transformation', default=3, type=int)
    parser.add_argument(
        '--radius', help='radius of warps of dynamic time warp transformation', default=10, type=int)
    # TimeWarp
    parser.add_argument(
        '--epsilon', help='epsilon param for time warp', default=10, type=float)
    # ChannelResize
    parser.add_argument('--magnitude_range', nargs='+',
                        help='range for scale param for ChannelResize transformation', default=[0.5, 2], type=float)
    # Downsample
    parser.add_argument(
        '--downsample_ratio', help='downsample ratio for Downsample transformation', default=0.2, type=float)
    # TimeOut
    parser.add_argument('--to_crop_ratio_range', nargs='+',
                        help='ratio range for timeout transformation', default=[0.2, 0.4], type=float)
    # resume training
    parser.add_argument('--resume', action='store_true')
    parser.add_argument(
        '--gpus', help='number of gpus to use; use cpu if gpus=0', type=int, default=1)
    parser.add_argument(
        '--num_nodes', default=1, help='number of cluster nodes', type=int)
    parser.add_argument(
        '--distributed_backend', help='sets backend type')
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--epochs', type=int)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--warm_up', default=1, type=int, help="number of warm-up epochs")
    parser.add_argument('--precision', type=int)
    parser.add_argument('--datasets', dest="target_folders",
                        nargs='+', help='datasets used for pretraining')
    parser.add_argument('--log_dir', default="./experiment_logs")
    parser.add_argument(
        '--percentage', help='determines how much of the dataset shall be used during pretraining', type=float, default=1.0)
    parser.add_argument('--lr', type=float, help="learning rate")
    parser.add_argument('--out_dim', type=int, help="output dimension of model")
    parser.add_argument('--filter_cinc', default=False, action="store_true", help="only valid if cinc is selected: filter out the ptb data")
    parser.add_argument('--base_model')
    parser.add_argument('--widen', type=int, help="use wide xresnet1d50")
    parser.add_argument('--run_callbacks', default=False, action="store_true", help="run callbacks which assess linear evaluation and fine-tuning metrics during pretraining")
    parser.add_argument('--checkpoint_path', default="")
    return parser

def init_logger(config):
    level = logging.INFO

    if config['debug']:
        level = logging.DEBUG

    # remove all handlers to change basic configuration
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    if not os.path.isdir(config['log_dir']):
        os.mkdir(config['log_dir'])
    logging.basicConfig(filename=os.path.join(config['log_dir'], 'info.log'), level=level,
                        format='%(asctime)s %(name)s:%(lineno)s %(levelname)s:  %(message)s  ')
    return logging.getLogger(__name__)

def pretrain_routine(args):
    t_params = {"gaussian_scale": args.gaussian_scale, "rr_crop_ratio_range": args.rr_crop_ratio_range, "output_size": args.output_size, "warps": args.warps, "radius": args.radius,
                "epsilon": args.epsilon, "magnitude_range": args.magnitude_range, "downsample_ratio": args.downsample_ratio, "to_crop_ratio_range": args.to_crop_ratio_range,
                "bw_cmax": 0.1, "em_cmax": 0.5, "pl_cmax": 0.2, "bs_cmax": 1}
    transformations = args.trafos
    checkpoint_config = os.path.join("checkpoints", "bolts_config.yaml")
    config_file = checkpoint_config if args.resume and os.path.isfile(
        checkpoint_config) else "bolts_config.yaml"
    config = yaml.load(open(config_file, "r"), Loader=yaml.FullLoader)
    args_dict = vars(args)
    # command-line arguments override the yaml config unless they were left unset (None)
    for key in set(config.keys()).union(set(args_dict.keys())):
        if key in args_dict and not (key in config and args_dict[key] is None):
            config[key] = args_dict[key]
    if args.target_folders is not None:
        config["dataset"]["target_folders"] = args.target_folders
    config["dataset"]["percentage"] = args.percentage if args.percentage is not None else config["dataset"]["percentage"]
    config["dataset"]["filter_cinc"] = args.filter_cinc if args.filter_cinc is not None else config["dataset"]["filter_cinc"]
    config["model"]["base_model"] = args.base_model if args.base_model is not None else config["model"]["base_model"]
    config["model"]["widen"] = args.widen if args.widen is not None else config["model"]["widen"]
    if args.out_dim is not None:
        config["model"]["out_dim"] = args.out_dim
    init_logger(config)
    dataset = SimCLRDataSetWrapper(
        config['batch_size'], **config['dataset'], transformations=transformations, t_params=t_params)
    for i, t in enumerate(dataset.transformations):
        logger.info(str(i) + ". Transformation: " +
                    str(t) + ": " + str(t.get_params()))
    date = time.asctime()
    label_to_num_classes = {"label_all": 71, "label_diag": 44, "label_form": 19,
                            "label_rhythm": 12, "label_diag_subclass": 23, "label_diag_superclass": 5}
    ptb_num_classes = label_to_num_classes[config["eval_dataset"]
                                           ["ptb_xl_label"]]
    abr = {"Transpose": "Tr", "TimeOut": "TO", "DynamicTimeWarp": "DTW", "RandomResizedCrop": "RRC", "ChannelResize": "ChR", "GaussianNoise": "GN",
           "TimeWarp": "TW", "ToTensor": "TT", "GaussianBlur": "GB", "BaselineWander": "BlW", "PowerlineNoise": "PlN", "EMNoise": "EM", "BaselineShift": "BlS"}
    trs = re.sub(r"[,'\]\[]", "", str([abr[str(tr)] if abr[str(tr)] not in [
                 "TT", "Tr"] else '' for tr in dataset.transformations]))
    name = str(date) + "_" + method + "_" + str(
        time.time_ns())[-3:] + "_" + trs[1:]
    tb_logger = TensorBoardLogger(args.log_dir, name=name, version='')
    config["log_dir"] = os.path.join(args.log_dir, name)
    print(config)
    return config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger

def aftertrain_routine(config, args, trainer, pl_model, datamodule, callbacks):
    scores = {}
    for ca in callbacks:
        if isinstance(ca, SSLOnlineEvaluator):
            scores[str(ca)] = {"macro": ca.best_macro}

    results = {"config": config, "trafos": args.trafos, "scores": scores}

    with open(os.path.join(config["log_dir"], "results.pkl"), 'wb') as handle:
        pickle.dump(results, handle)

    trainer.save_checkpoint(os.path.join(config["log_dir"], "checkpoints", "model.ckpt"))
    with open(os.path.join(config["log_dir"], "config.txt"), "w") as text_file:
        print(config, file=text_file)

def cli_main():
    from pytorch_lightning import Trainer
    from online_evaluator import SSLOnlineEvaluator
    from ecg_datamodule import ECGDataModule
    from clinical_ts.create_logger import create_logger
    from os.path import exists

    parser = ArgumentParser()
    parser = parse_args(parser)
    logger.info("parse arguments")
    args = parser.parse_args()
    config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger = pretrain_routine(args)

    # data
    ecg_datamodule = ECGDataModule(config, transformations, t_params)

    callbacks = []
    if args.run_callbacks:
        # callbacks for online linear evaluation/fine-tuning
        linear_evaluator = SSLOnlineEvaluator(drop_p=0,
                                              z_dim=512, num_classes=ptb_num_classes, hidden_dim=None, lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"], mode="linear_evaluation", verbose=False)

        fine_tuner = SSLOnlineEvaluator(drop_p=0,
                                        z_dim=512, num_classes=ptb_num_classes, hidden_dim=None, lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"], mode="fine_tuning", verbose=False)

        callbacks.append(linear_evaluator)
        callbacks.append(fine_tuner)

    # configure trainer
    trainer = Trainer(logger=tb_logger, max_epochs=config["epochs"], gpus=args.gpus,
                      distributed_backend=args.distributed_backend, auto_lr_find=False, num_nodes=args.num_nodes, precision=config["precision"], callbacks=callbacks)

    # pytorch lightning module
    model = ResNetSimCLR(**config["model"])
    pl_model = CustomSimCLR(
        config["batch_size"], ecg_datamodule.num_samples, warmup_epochs=config["warm_up"], lr=config["lr"],
        out_dim=config["model"]["out_dim"], config=config,
        transformations=ecg_datamodule.transformations, loss_temperature=config["loss"]["temperature"], weight_decay=eval(config["weight_decay"]))
    pl_model.encoder = model
    pl_model.projection = Projection(
        input_dim=model.l1.in_features, hidden_dim=512, output_dim=config["model"]["out_dim"])

    # load checkpoint
    if args.checkpoint_path != "":
        if exists(args.checkpoint_path):
            logger.info("Retrieve checkpoint from " + args.checkpoint_path)
            # load_from_checkpoint is a classmethod that returns a new instance,
            # so load the weights directly into the assembled model instead
            checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
            pl_model.load_state_dict(checkpoint["state_dict"])
        else:
            raise FileNotFoundError("checkpoint does not exist")

    # start training
    trainer.fit(pl_model, ecg_datamodule)

    aftertrain_routine(config, args, trainer, pl_model, ecg_datamodule, callbacks)


if __name__ == "__main__":
    cli_main()
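
# Example invocation (values are illustrative only; the dataset folders and
# bolts_config.yaml must match your local setup):
#   python custom_simclr_bolts.py --batch_size 32 --epochs 100 --lr 1e-3 \
#       --trafos GaussianNoise ChannelResize RandomResizedCrop \
#       --datasets <path/to/dataset> --gpus 1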