Diff of /custom_byol_bolts.py [000000] .. [134fd7]

Switch to unified view

a b/custom_byol_bolts.py
1
import math
2
from argparse import ArgumentParser
3
from copy import deepcopy
4
from typing import Any
5
6
import pytorch_lightning as pl
7
import torch
8
import torch.nn as nn
9
import torch.nn.functional as F
10
from pytorch_lightning import seed_everything
11
from torch.optim import Adam
12
13
from pl_bolts.models.self_supervised import BYOL
14
# from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate
15
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
16
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
17
18
from models.resnet_simclr import ResNetSimCLR
19
import re
20
21
import time
22
23
import yaml
24
import logging
25
import os
26
from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper
27
from clinical_ts.create_logger import create_logger
28
import pickle
29
from pytorch_lightning import Trainer, seed_everything
30
31
from torch import nn
32
from torch.nn import functional as F
33
from online_evaluator import SSLOnlineEvaluator
34
from ecg_datamodule import ECGDataModule
35
from pytorch_lightning.loggers import TensorBoardLogger
36
import pdb
37
38
# Module-level logger obtained from the project's logger factory.
logger = create_logger(__name__)
# Method tag; pretrain_routine embeds it in the TensorBoard run name.
method = "byol"
40
def mean(res, key1, key2=None):
    """Average a tensor-valued entry over a sequence of step outputs.

    Args:
        res: iterable of per-step outputs (normally dicts).
        key1: top-level key to read from every output.
        key2: optional nested key. When given, ``x[key1][key2]`` is averaged
            and every element of ``res`` must provide it.

    Returns:
        Tensor holding the mean of the stacked values.
    """
    if key2 is not None:
        return torch.stack([x[key1][key2] for x in res]).mean()
    # Entries that are not dicts, or that lack the key (for instance the empty
    # dict validation_step returns for secondary dataloaders), are skipped.
    values = [x[key1] for x in res if isinstance(x, dict) and key1 in x]
    return torch.stack(values).mean()
44
45
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        input_dim: number of input features.
        hidden_size: width of the hidden layer.
        output_dim: number of output features.
    """

    def __init__(self, input_dim=512, hidden_size=4096, output_dim=256):
        super().__init__()
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.model = nn.Sequential(
            # The first linear layer has no bias; BatchNorm1d follows it.
            nn.Linear(input_dim, hidden_size, bias=False),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, output_dim, bias=True),
        )

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result."""
        return self.model(x)
59
60
61
class SiameseArm(nn.Module):
    """One BYOL branch: encoder followed by a projector and a predictor MLP.

    Args:
        encoder: backbone module. Its output is indexed with ``[0]`` in
            ``forward``. When ``None``, a Lightning Bolts ``resnet50`` SSL
            encoder is created.
        out_dim: output size of both the projector and the predictor.
        hidden_size: hidden width of both MLPs.
        projector_dim: input size of the projector. Used only when the encoder
            has no ``l1`` layer to read the size from.
    """

    def __init__(self, encoder=None, out_dim=128, hidden_size=512, projector_dim=512):
        super().__init__()

        if encoder is None:
            # Imported here because the module never imports this name at
            # file level and only this fallback path needs it.
            from pl_bolts.utils.self_supervised import torchvision_ssl_encoder
            encoder = torchvision_ssl_encoder('resnet50')
        # Encoder
        self.encoder = encoder
        # Pooler (not used by forward; kept so the attribute stays available)
        self.pooler = nn.AdaptiveAvgPool2d((1, 1))
        # Projector: prefer the width of the encoder's `l1` layer when the
        # encoder has one; otherwise use the `projector_dim` argument.
        first_head_layer = getattr(encoder, "l1", None)
        if first_head_layer is not None:
            projector_dim = first_head_layer.in_features
        self.projector = MLP(
            input_dim=projector_dim, hidden_size=hidden_size, output_dim=out_dim)
        # Predictor
        self.predictor = MLP(
            input_dim=out_dim, hidden_size=hidden_size, output_dim=out_dim)

    def forward(self, x):
        """Return ``(y, z, h)``: flattened encoder output, projection, prediction."""
        y = self.encoder(x)[0]
        y = y.view(y.size(0), -1)
        z = self.projector(y)
        h = self.predictor(z)
        return y, z, h
85
86
87
class BYOLMAWeightUpdate(pl.Callback):
    """Exponential-moving-average update of the BYOL target network.

    After every training batch the target parameters are moved towards the
    online parameters with coefficient ``current_tau``, and ``current_tau`` is
    then recomputed from a cosine schedule that starts at ``initial_tau``.
    """

    def __init__(self, initial_tau=0.996):
        """
        Args:
            initial_tau: starting tau. Auto-updates with every training step
        """
        super().__init__()
        self.initial_tau = initial_tau
        self.current_tau = initial_tau

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Update the target network, then advance tau.

        ``pl_module`` must expose ``online_network`` and ``target_network``.
        ``outputs``, ``batch``, ``batch_idx`` and ``dataloader_idx`` are unused.
        """
        # get networks
        online_net = pl_module.online_network
        target_net = pl_module.target_network

        # update weights
        self.update_weights(online_net, target_net)

        # update tau after
        self.current_tau = self.update_tau(pl_module, trainer)

    def update_tau(self, pl_module, trainer):
        """Return tau for the current global step on the cosine schedule.

        The schedule length is ``len(trainer.train_dataloader) * trainer.max_epochs``.
        If that is not positive, the current tau is returned unchanged instead
        of dividing by zero.
        """
        max_steps = len(trainer.train_dataloader) * trainer.max_epochs
        if max_steps <= 0:
            return self.current_tau
        progress = math.cos(math.pi * pl_module.global_step / max_steps)
        return 1 - (1 - self.initial_tau) * (progress + 1) / 2

    def update_weights(self, online_net, target_net):
        """Blend every target parameter with its online counterpart.

        All parameters take part in the moving average. Restricting it to
        names containing ``weight`` would leave the target biases at their
        initial values.
        """
        tau = self.current_tau
        for online_p, target_p in zip(online_net.parameters(), target_net.parameters()):
            target_p.data = tau * target_p.data + (1 - tau) * online_p.data
120
121
122
class CustomBYOL(pl.LightningModule):
    """BYOL LightningModule built on a ``ResNetSimCLR`` encoder.

    An online ``SiameseArm`` is trained to predict the projections produced by
    a target arm. The target arm is a deep copy of the online arm that is
    updated after every training batch by ``BYOLMAWeightUpdate``.
    """

    def __init__(self,
                 num_classes=5,
                 learning_rate: float = 0.2,
                 weight_decay: float = 1.5e-6,
                 input_height: int = 32,
                 batch_size: int = 32,
                 num_workers: int = 0,
                 warmup_epochs: int = 10,
                 max_epochs: int = 1000,
                 config=None,
                 transformations=None,
                 **kwargs):
        """
        Args:
            num_classes: stored as a hyperparameter; not read by this class
            learning_rate: the learning rate
            weight_decay: optimizer weight decay
            input_height: stored as a hyperparameter; not read by this class
            batch_size: the batch size
            num_workers: number of workers
            warmup_epochs: num of epochs for scheduler warm up
            max_epochs: max epochs for scheduler
            config: configuration dict; ``config["model"]`` is forwarded to
                ``ResNetSimCLR`` and must contain ``out_dim``
            transformations: augmentation objects, logged in ``on_train_start``
            **kwargs: extra values, stored as hyperparameters only

        Raises:
            ValueError: if ``config`` is ``None``.
        """
        super().__init__()
        self.save_hyperparameters()

        if config is None:
            # Fail with a clear message rather than an opaque
            # "'NoneType' object is not subscriptable" a few lines below.
            raise ValueError("CustomBYOL requires a `config` dict with a 'model' section")

        self.config = config
        self.transformations = transformations
        self.online_network = SiameseArm(
            encoder=self.init_model(), out_dim=config["model"]["out_dim"])
        self.target_network = deepcopy(self.online_network)
        self.weight_callback = BYOLMAWeightUpdate()
        # NOTE(review): newer Lightning releases define `log_dict` as a method
        # of LightningModule; this attribute would shadow it. Confirm against
        # the pinned Lightning version before relying on `self.log_dict(...)`.
        self.log_dict = {}
        self.epoch = 0

    def init_model(self):
        """Build the encoder from the ``model`` section of the config."""
        return ResNetSimCLR(**self.config["model"])

    def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """Run the moving-average target update after every training batch."""
        # Add callback for user automatically since it's key to BYOL weight update
        self.weight_callback.on_train_batch_end(
            self.trainer, self, outputs, batch, batch_idx, 0)

    def forward(self, x):
        """Return the online network's encoder representation of ``x``."""
        y, _, _ = self.online_network(x)
        return y

    def cosine_similarity(self, a, b):
        """Mean cosine similarity between ``a`` and ``b`` along the last dim."""
        a = F.normalize(a, dim=-1)
        b = F.normalize(b, dim=-1)
        return (a * b).sum(-1).mean()

    def shared_step(self, batch, batch_idx):
        """Compute the symmetric BYOL loss for one batch.

        ``batch`` is unpacked as ``((img_1, label_1), (img_2, label_2))``; the
        labels are not used.

        Returns:
            Tuple ``(loss_a, loss_b, total_loss)``.
        """
        (img_1, _), (img_2, _) = batch

        img_1 = self.to_device(img_1)
        img_2 = self.to_device(img_2)

        # Image 1 to image 2 loss
        _, _, h1 = self.online_network(img_1)
        with torch.no_grad():
            _, z2, _ = self.target_network(img_2)
        loss_a = - 2 * self.cosine_similarity(h1, z2)

        # Image 2 to image 1 loss
        _, _, h1 = self.online_network(img_2)
        with torch.no_grad():
            _, z2, _ = self.target_network(img_1)
        loss_b = - 2 * self.cosine_similarity(h1, z2)

        # Final loss
        total_loss = loss_a + loss_b

        return loss_a, loss_b, total_loss

    def training_step(self, batch, batch_idx):
        """Return the total BYOL loss for the batch."""
        _, _, total_loss = self.shared_step(batch, batch_idx)
        return total_loss

    def validation_step(self, batch, batch_idx, dataloader_idx):
        """Compute validation losses for the first dataloader only.

        Batches from any other dataloader produce an empty dict.
        """
        if dataloader_idx != 0:
            return {}

        loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)

        return {
            'val_loss': total_loss,
            'val_1_2_loss': loss_a,
            'val_2_1_loss': loss_b
        }

    def validation_epoch_end(self, outputs):
        """Average the validation losses collected for the first dataloader."""
        # outputs[0] because we are using multiple datasets!
        first_loader_outputs = outputs[0]
        val_loss = mean(first_loader_outputs, 'val_loss')
        loss_a = mean(first_loader_outputs, 'val_1_2_loss')
        loss_b = mean(first_loader_outputs, 'val_2_1_loss')

        log = {
            'val_loss': val_loss,
            'val_1_2_loss': loss_a,
            'val_2_1_loss': loss_b
        }

        return {'val_loss': val_loss, 'log': log, 'progress_bar': log}

    def configure_optimizers(self):
        """Create an Adam optimizer with a linear-warmup cosine-annealing schedule."""
        optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate,
                         weight_decay=self.hparams.weight_decay)
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.hparams.warmup_epochs,
            max_epochs=self.hparams.max_epochs
        )
        return [optimizer], [scheduler]

    def on_train_start(self):
        """Write the configuration and transformations to the logger as text."""
        # log configuration
        config_str = re.sub(r"[,\}\{]", "<br/>", str(self.config))
        config_str = re.sub(r"[\[\]\']", "", config_str)
        # `transformations` defaults to None; treat that as "nothing to log".
        transformations = self.transformations or []
        transformation_str = re.sub(r"[\}]", "<br/>", str(["<br>" + str(
            t) + ":<br/>" + str(t.get_params()) for t in transformations]))
        transformation_str = re.sub(r"[,\"\{\'\[\]]", "", transformation_str)
        self.logger.experiment.add_text(
            "configuration", str(config_str), global_step=0)
        self.logger.experiment.add_text("transformations", str(
            transformation_str), global_step=0)
        self.epoch = 0

    def on_epoch_end(self):
        """Advance the module's own epoch counter."""
        self.epoch += 1

    def get_representations(self, x):
        """Return the online network's flattened encoder output for ``x``."""
        return self.online_network(x)[0]

    def get_model(self):
        """Return the online encoder."""
        return self.online_network.encoder

    def get_device(self):
        """Return the device of ``encoder.features[0][0].weight``."""
        return self.online_network.encoder.features[0][0].weight.device

    def to_device(self, x):
        """Cast ``x`` to the encoder's tensor type and move it to its device."""
        return x.type(self.type()).to(self.get_device())

    def type(self):
        """Return the tensor type string of ``encoder.features[0][0].weight``.

        NOTE(review): this overrides ``nn.Module.type(dst_type)``; the name is
        kept because ``to_device`` calls it.
        """
        return self.online_network.encoder.features[0][0].weight.type()
297
298
def parse_args(parent_parser):
    """Build the command-line parser for BYOL pre-training.

    Args:
        parent_parser: parser whose arguments are inherited.

    Returns:
        An ``ArgumentParser`` with the transformation, trainer, dataset and
        model options added. Options without an explicit default yield
        ``None`` when they are not given on the command line.
    """
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    parser.add_argument('-t', '--trafos', nargs='+', help='add transformation to data augmentation pipeline',
                        default=["GaussianNoise", "ChannelResize", "RandomResizedCrop"])
    # GaussianNoise
    parser.add_argument(
        '--gaussian_scale', help='std param for gaussian noise transformation', default=0.005, type=float)
    # RandomResizedCrop
    # nargs='+' so a value given on the command line is a list of floats,
    # like the default and like the other *_range options below.
    parser.add_argument('--rr_crop_ratio_range', nargs='+',
                        help='ratio range for random resized crop transformation', default=[0.5, 1.0], type=float)
    parser.add_argument(
        '--output_size', help='output size for random resized crop transformation', default=250, type=int)
    # DynamicTimeWarp
    parser.add_argument(
        '--warps', help='number of warps for dynamic time warp transformation', default=3, type=int)
    parser.add_argument(
        '--radius', help='radius of warps of dynamic time warp transformation', default=10, type=int)
    # TimeWarp
    parser.add_argument(
        '--epsilon', help='epsilon param for time warp', default=10, type=float)
    # ChannelResize
    parser.add_argument('--magnitude_range', nargs='+',
                        help='range for scale param for ChannelResize transformation', default=[0.5, 2], type=float)
    # Downsample
    parser.add_argument(
        '--downsample_ratio', help='downsample ratio for Downsample transformation', default=0.2, type=float)
    # TimeOut
    parser.add_argument('--to_crop_ratio_range', nargs='+',
                        help='ratio range for timeout transformation', default=[0.2, 0.4], type=float)
    # resume training
    parser.add_argument('--resume', action='store_true')
    parser.add_argument(
        '--gpus', help='number of gpus to use; use cpu if gpu=0', type=int, default=1)
    parser.add_argument(
        '--num_nodes', default=1, help='number of cluster nodes', type=int)
    parser.add_argument(
        '--distributed_backend', help='sets backend type')
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--epochs', type=int)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--warm_up', default=1, type=int)
    parser.add_argument('--precision', type=int)
    parser.add_argument('--datasets', dest="target_folders",
                        nargs='+', help='used datasets for pretraining')
    parser.add_argument('--log_dir', default="./experiment_logs")
    parser.add_argument(
        '--percentage', help='determines how much of the dataset shall be used during the pretraining', type=float, default=1.0)
    parser.add_argument('--lr', type=float, help="learning rate")
    parser.add_argument('--out_dim', type=int, help="output dimension of model")
    parser.add_argument('--filter_cinc', default=False, action="store_true", help="only valid if cinc is selected: filter out the ptb data")
    parser.add_argument('--base_model')
    parser.add_argument('--widen', type=int, help="use wide xresnet1d50")
    parser.add_argument('--run_callbacks', default=False, action="store_true", help="run callbacks which asses linear evaluaton and finetuning metrics during pretraining")

    parser.add_argument('--checkpoint_path', default="")
    return parser
354
355
def init_logger(config):
    """Reconfigure root logging to write to ``<log_dir>/info.log``.

    Args:
        config: mapping with ``debug`` (truthy selects DEBUG level, otherwise
            INFO) and ``log_dir`` (directory for the log file; created,
            including missing parents, if it does not exist).

    Returns:
        The logger named after this module.
    """
    level = logging.DEBUG if config['debug'] else logging.INFO

    # remove all handlers to change basic configuration
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    # makedirs handles nested paths and an already existing directory.
    os.makedirs(config['log_dir'], exist_ok=True)

    logging.basicConfig(filename=os.path.join(config['log_dir'], 'info.log'), level=level,
                        format='%(asctime)s %(name)s:%(lineno)s %(levelname)s:  %(message)s  ')
    return logging.getLogger(__name__)
369
370
def pretrain_routine(args):
    """Prepare everything needed before training starts.

    Loads the YAML configuration, overlays the command-line arguments,
    initialises file logging, builds the dataset wrapper and creates the
    TensorBoard logger.

    Args:
        args: parsed command-line namespace produced by ``parse_args``.

    Returns:
        Tuple ``(config, dataset, date, transformations, t_params,
        ptb_num_classes, tb_logger)``.
    """
    t_params = {"gaussian_scale": args.gaussian_scale, "rr_crop_ratio_range": args.rr_crop_ratio_range, "output_size": args.output_size, "warps": args.warps, "radius": args.radius,
                "epsilon": args.epsilon, "magnitude_range": args.magnitude_range, "downsample_ratio": args.downsample_ratio, "to_crop_ratio_range": args.to_crop_ratio_range,
                "bw_cmax": 0.1, "em_cmax": 0.5, "pl_cmax": 0.2, "bs_cmax": 1}
    transformations = args.trafos

    # When resuming, prefer the config stored next to the checkpoints.
    checkpoint_config = os.path.join("checkpoints", "bolts_config.yaml")
    if args.resume and os.path.isfile(checkpoint_config):
        config_file = checkpoint_config
    else:
        config_file = "bolts_config.yaml"
    # NOTE(review): FullLoader is kept from the original; only load trusted
    # configuration files with it.
    with open(config_file, "r") as config_handle:
        config = yaml.load(config_handle, Loader=yaml.FullLoader)

    # Command-line values override the file unless they are None; keys that
    # only exist on the command line are always copied (even when None).
    for key, value in vars(args).items():
        if value is not None or key not in config:
            config[key] = value

    if args.target_folders is not None:
        config["dataset"]["target_folders"] = args.target_folders
    config["dataset"]["percentage"] = args.percentage if args.percentage is not None else config["dataset"]["percentage"]
    config["dataset"]["filter_cinc"] = args.filter_cinc if args.filter_cinc is not None else config["dataset"]["filter_cinc"]
    config["model"]["base_model"] = args.base_model if args.base_model is not None else config["model"]["base_model"]
    config["model"]["widen"] = args.widen if args.widen is not None else config["model"]["widen"]
    if args.out_dim is not None:
        config["model"]["out_dim"] = args.out_dim

    init_logger(config)
    dataset = SimCLRDataSetWrapper(
        config['batch_size'], **config['dataset'], transformations=transformations, t_params=t_params)
    for i, t in enumerate(dataset.transformations):
        logger.info(str(i) + ". Transformation: " +
                    str(t) + ": " + str(t.get_params()))

    date = time.asctime()
    label_to_num_classes = {"label_all": 71, "label_diag": 44, "label_form": 19,
                            "label_rhythm": 12, "label_diag_subclass": 23, "label_diag_superclass": 5}
    ptb_num_classes = label_to_num_classes[config["eval_dataset"]["ptb_xl_label"]]

    # Abbreviations used to build a compact run name from the transformations.
    abr = {"Transpose": "Tr", "TimeOut": "TO", "DynamicTimeWarp": "DTW", "RandomResizedCrop": "RRC", "ChannelResize": "ChR", "GaussianNoise": "GN",
           "TimeWarp": "TW", "ToTensor": "TT", "GaussianBlur": "GB", "BaselineWander": "BlW", "PowerlineNoise": "PlN", "EMNoise": "EM", "BaselineShift": "BlS"}
    short_names = [abr[str(tr)] for tr in dataset.transformations]
    # "TT" and "Tr" are replaced by empty strings in the run name.
    trs = re.sub(r"[,'\]\[]", "", str(
        [short if short not in ["TT", "Tr"] else '' for short in short_names]))
    name = str(date) + "_" + method + "_" + str(
        time.time_ns())[-3:] + "_" + trs[1:]

    tb_logger = TensorBoardLogger(args.log_dir, name=name, version='')
    config["log_dir"] = os.path.join(args.log_dir, name)
    print(config)
    return config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger
412
413
def aftertrain_routine(config, args, trainer, pl_model, datamodule, callbacks):
    """Persist results, the final checkpoint and the config after training.

    Writes ``results.pkl`` and ``config.txt`` into ``config["log_dir"]`` and
    saves the model to ``<log_dir>/checkpoints/model.ckpt``.

    Args:
        config: configuration dict; ``config["log_dir"]`` is the output folder.
        args: parsed command-line namespace; ``args.trafos`` is stored.
        trainer: object whose ``save_checkpoint(path)`` writes the checkpoint.
        pl_model: unused.
        datamodule: unused.
        callbacks: callbacks used in training; the ``best_macro`` value of
            every ``SSLOnlineEvaluator`` among them is stored.
    """
    scores = {}
    for ca in callbacks:
        if isinstance(ca, SSLOnlineEvaluator):
            scores[str(ca)] = {"macro": ca.best_macro}

    results = {"config": config, "trafos": args.trafos, "scores": scores}

    log_dir = config["log_dir"]
    checkpoint_dir = os.path.join(log_dir, "checkpoints")
    # Make sure the output folders exist before anything is written to them.
    os.makedirs(checkpoint_dir, exist_ok=True)

    with open(os.path.join(log_dir, "results.pkl"), 'wb') as handle:
        pickle.dump(results, handle)

    trainer.save_checkpoint(os.path.join(checkpoint_dir, "model.ckpt"))
    with open(os.path.join(log_dir, "config.txt"), "w") as text_file:
        print(config, file=text_file)
427
428
def cli_main():
    """Command-line entry point: parse arguments, train BYOL, store results.

    Raises:
        FileNotFoundError: if ``--checkpoint_path`` is given but does not exist.
    """
    parser = parse_args(ArgumentParser())
    logger.info("parse arguments")
    args = parser.parse_args()
    config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger = pretrain_routine(args)

    # data
    ecg_datamodule = ECGDataModule(config, transformations, t_params)

    callbacks = []
    if args.run_callbacks:
        # callbacks for online linear evaluation and fine-tuning
        for mode in ("linear_evaluation", "fine_tuning"):
            callbacks.append(SSLOnlineEvaluator(
                drop_p=0, z_dim=512, num_classes=ptb_num_classes, hidden_dim=None,
                lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"],
                mode=mode, verbose=False))

    # configure trainer
    trainer = Trainer(logger=tb_logger, max_epochs=config["epochs"], gpus=args.gpus,
                      distributed_backend=args.distributed_backend, auto_lr_find=False, num_nodes=args.num_nodes, precision=config["precision"], callbacks=callbacks)

    # pytorch lightning module
    model_kwargs = dict(
        num_classes=5,
        learning_rate=config["lr"],
        # float() accepts both a number and a string such as "1e-6",
        # without evaluating arbitrary code from the config file.
        weight_decay=float(config["weight_decay"]),
        # The constructor's parameter is `warmup_epochs`; the former spelling
        # `warm_up_epochs` fell into **kwargs and was never applied.
        warmup_epochs=config["warm_up"],
        max_epochs=config["epochs"],
        num_workers=config["dataset"]["num_workers"],
        batch_size=config["batch_size"],
        config=config,
        transformations=ecg_datamodule.transformations,
    )

    if args.checkpoint_path != "":
        if not os.path.exists(args.checkpoint_path):
            raise FileNotFoundError(
                "checkpoint does not exist: " + args.checkpoint_path)
        logger.info("Retrieve checkpoint from " + args.checkpoint_path)
        # load_from_checkpoint is a classmethod that returns a new model, so
        # its result has to be kept; the keyword arguments override the
        # hyperparameters stored in the checkpoint.
        pl_model = CustomBYOL.load_from_checkpoint(
            args.checkpoint_path, **model_kwargs)
    else:
        pl_model = CustomBYOL(**model_kwargs)

    # start training
    trainer.fit(pl_model, ecg_datamodule)

    aftertrain_routine(config, args, trainer, pl_model, ecg_datamodule, callbacks)
479
480
# Run the training pipeline only when the file is executed as a script.
if __name__ == "__main__":
    cli_main()