Diff of /custom_swav_bolts.py [000000] .. [134fd7]

Switch to unified view

a b/custom_swav_bolts.py
1
"""
2
Adapted from official swav implementation: https://github.com/facebookresearch/swav
3
"""
4
import math
5
import os
6
import re
7
from argparse import ArgumentParser
8
from typing import Callable, Optional
9
import pdb
10
import numpy as np
11
import pytorch_lightning as pl
12
import torch
13
import torch.distributed as dist
14
from pytorch_lightning.utilities import AMPType
15
from torch import nn
16
from pytorch_lightning.core.optimizer import LightningOptimizer
17
from torch.optim.optimizer import Optimizer
18
19
import yaml
20
import time
21
import logging
22
import pickle
23
# from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
24
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
25
from pl_bolts.transforms.dataset_normalizations import (
26
    cifar10_normalization,
27
    imagenet_normalization,
28
    stl10_normalization,
29
)
30
from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper
31
from clinical_ts.create_logger import create_logger
32
from torchvision.models.resnet import Bottleneck, BasicBlock
33
from online_evaluator import SSLOnlineEvaluator
34
from ecg_datamodule import ECGDataModule
35
from pytorch_lightning.loggers import TensorBoardLogger
36
from models.resnet_simclr import ResNetSimCLR
37
import torchvision.transforms as transforms
38
39
# Hard-coded availability flag; torchvision is imported unconditionally in this module.
_TORCHVISION_AVAILABLE = True
40
41
# import cv2
42
from typing import List
43
# Module-level logger built by the project's create_logger helper.
logger = create_logger(__name__)
44
# Name of the self-supervised method implemented here.
# NOTE(review): presumably read by other project code - confirm against callers.
method = "swav"
45
class SwAVTrainDataTransform(object):
    """Multi-crop training augmentation for SwAV.

    For every crop resolution ``size_crops[i]`` the transform list receives
    ``nmb_crops[i]`` references to one pipeline made of a random resized crop,
    a horizontal flip, the colour augmentations and the final tensor
    conversion. One extra "online" pipeline at the global-view size is
    appended at the end of the list.

    Args:
        normalize: optional transform applied after ``ToTensor``.
        size_crops: output size of each crop group.
        nmb_crops: number of crops generated per group.
        min_scale_crops: lower ``scale`` bound per group for the resized crop.
        max_scale_crops: upper ``scale`` bound per group for the resized crop.
        gaussian_blur: whether a blur step is added to the colour pipeline.
        jitter_strength: multiplier for the colour jitter parameters.
    """

    def __init__(
        self,
        normalize=None,
        size_crops: List[int] = [96, 36],
        nmb_crops: List[int] = [2, 4],
        min_scale_crops: List[float] = [0.33, 0.10],
        max_scale_crops: List[float] = [1, 0.33],
        gaussian_blur: bool = True,
        jitter_strength: float = 1.
    ):
        self.jitter_strength = jitter_strength
        self.gaussian_blur = gaussian_blur

        # all per-group lists must describe the same number of crop groups
        assert len(size_crops) == len(nmb_crops)
        assert len(min_scale_crops) == len(nmb_crops)
        assert len(max_scale_crops) == len(nmb_crops)

        self.size_crops = size_crops
        self.nmb_crops = nmb_crops
        self.min_scale_crops = min_scale_crops
        self.max_scale_crops = max_scale_crops

        strength = self.jitter_strength
        self.color_jitter = transforms.ColorJitter(
            0.8 * strength, 0.8 * strength, 0.8 * strength, 0.2 * strength
        )

        colour_steps = [
            transforms.RandomApply([self.color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
        ]
        if self.gaussian_blur:
            # blur kernel is ~10% of the global crop size, forced to be odd
            raw_kernel = int(0.1 * self.size_crops[0])
            kernel_size = raw_kernel if raw_kernel % 2 else raw_kernel + 1
            # NOTE(review): GaussianBlur is not defined in the visible part of
            # this module - confirm it is provided elsewhere.
            colour_steps.append(GaussianBlur(kernel_size=kernel_size, p=0.5))
        self.color_transform = transforms.Compose(colour_steps)

        to_tensor = transforms.ToTensor()
        self.final_transform = (
            to_tensor if normalize is None
            else transforms.Compose([to_tensor, normalize])
        )

        pipelines = []
        for group, crop_size in enumerate(self.size_crops):
            crop_scale = (self.min_scale_crops[group], self.max_scale_crops[group])
            view_pipeline = transforms.Compose([
                transforms.RandomResizedCrop(crop_size, scale=crop_scale),
                transforms.RandomHorizontalFlip(p=0.5),
                self.color_transform,
                self.final_transform,
            ])
            # the same pipeline object is referenced once per requested crop
            pipelines.extend([view_pipeline] * self.nmb_crops[group])

        # add online train transform of the size of global view
        pipelines.append(transforms.Compose([
            transforms.RandomResizedCrop(self.size_crops[0]),
            transforms.RandomHorizontalFlip(),
            self.final_transform,
        ]))
        self.transform = pipelines

    def __call__(self, sample):
        """Apply every pipeline in ``self.transform`` to ``sample``; return the list."""
        return [pipeline(sample) for pipeline in self.transform]
127
128
129
class SwAVEvalDataTransform(SwAVTrainDataTransform):
    """Evaluation variant of :class:`SwAVTrainDataTransform`.

    Builds the same list of multi-crop pipelines as the training transform
    and then swaps the last entry (the online pipeline) for a deterministic
    resize + centre-crop at the global-view size.
    """

    def __init__(
        self,
        normalize=None,
        size_crops: List[int] = [96, 36],
        nmb_crops: List[int] = [2, 4],
        min_scale_crops: List[float] = [0.33, 0.10],
        max_scale_crops: List[float] = [1, 0.33],
        gaussian_blur: bool = True,
        jitter_strength: float = 1.
    ):
        parent_options = dict(
            normalize=normalize,
            size_crops=size_crops,
            nmb_crops=nmb_crops,
            min_scale_crops=min_scale_crops,
            max_scale_crops=max_scale_crops,
            gaussian_blur=gaussian_blur,
            jitter_strength=jitter_strength,
        )
        super().__init__(**parent_options)

        # global view size; resize to 110% of it before centre-cropping
        global_size = self.size_crops[0]
        enlarged_size = int(global_size + 0.1 * global_size)

        # replace last transform to eval transform in self.transform list
        self.transform[-1] = transforms.Compose([
            transforms.Resize(enlarged_size),
            transforms.CenterCrop(global_size),
            self.final_transform,
        ])
159
160
161
class SwAVFinetuneTransform(object):
    """Single-view transform for fine-tuning (random) or evaluation (deterministic).

    Args:
        input_height: output size of the crop.
        jitter_strength: multiplier for the colour jitter parameters.
        normalize: optional transform applied after ``ToTensor``.
        eval_transform: if True, use resize + centre-crop instead of the
            random crop / flip / colour augmentations.
    """

    def __init__(
        self,
        input_height: int = 224,
        jitter_strength: float = 1.,
        normalize=None,
        eval_transform: bool = False
    ) -> None:
        self.jitter_strength = jitter_strength
        self.input_height = input_height
        self.normalize = normalize

        strength = self.jitter_strength
        self.color_jitter = transforms.ColorJitter(
            0.8 * strength, 0.8 * strength, 0.8 * strength, 0.2 * strength
        )

        height = self.input_height
        if eval_transform:
            # deterministic path: resize to 110% of the target, then centre-crop
            steps = [
                transforms.Resize(int(height + 0.1 * height)),
                transforms.CenterCrop(height),
            ]
        else:
            steps = [
                transforms.RandomResizedCrop(size=height),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply([self.color_jitter], p=0.8),
                transforms.RandomGrayscale(p=0.2),
            ]

        to_tensor = transforms.ToTensor()
        steps.append(
            to_tensor if normalize is None
            else transforms.Compose([to_tensor, normalize])
        )
        self.transform = transforms.Compose(steps)

    def __call__(self, sample):
        """Apply the composed transform to ``sample``."""
        return self.transform(sample)
206
207
208
class CustomResNet(nn.Module):
    """Encoder wrapper adding a SwAV projection head and prototype layer.

    The wrapped ``model`` must expose a ``features`` module whose first
    element is indexable and whose ``[0][0]`` entry has a ``weight`` (it is
    used to look up the tensor type and device of the encoder).

    Args:
        model: backbone providing ``model.features``.
        zero_init_residual: zero-initialise the last norm layer of each
            ``Bottleneck`` / ``BasicBlock`` found among the submodules.
        output_dim: output size of the projection head.
        hidden_mlp: hidden size of the projection head.
        nmb_prototypes: number of prototypes; a list creates one prototype
            head per entry, a value <= 0 disables the prototype layer.
        eval_mode: accepted for interface compatibility; not used here.
        first_conv: accepted for interface compatibility; not used here.
        maxpool1: accepted for interface compatibility; not used here.
        l2norm: L2-normalise the projection before the prototype layer.
    """

    def __init__(
            self,
            model,
            zero_init_residual=False,
            output_dim=16,
            hidden_mlp=512,
            nmb_prototypes=8,
            eval_mode=False,
            first_conv=True,
            maxpool1=True,
            l2norm=True
    ):
        super(CustomResNet, self).__init__()
        self.l2norm = l2norm
        self.model = model
        self.features = self.model.features
        # NOTE(review): the head input size is fixed to 512, so the encoder is
        # expected to emit 512 features per sample - confirm for new backbones.
        self.projection_head = nn.Sequential(
            nn.Linear(512, hidden_mlp),
            nn.BatchNorm1d(hidden_mlp),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_mlp, output_dim),
        )

        # prototype layer
        self.prototypes = None
        if isinstance(nmb_prototypes, list):
            self.prototypes = MultiPrototypes(output_dim, nmb_prototypes)
        elif nmb_prototypes > 0:
            self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def forward_backbone(self, x):
        """Run the encoder and drop singleton dimensions, keeping the batch axis.

        The input is first cast to the tensor type of the first encoder weight.
        """
        x = x.type(self.features[0][0].weight.type())
        h = self.features(x)
        batch_size = h.shape[0]
        h = h.squeeze()
        if batch_size == 1:
            # squeeze() removes every size-1 axis, including the batch axis of
            # a single-sample batch; restore it so downstream layers still see
            # a (batch, features) tensor.
            h = h.unsqueeze(0)
        return h

    def forward_head(self, x):
        """Project features and, if a prototype layer exists, score them.

        Returns:
            ``(projection, prototype_scores)`` when prototypes are configured,
            otherwise just ``projection``.
        """
        if self.projection_head is not None:
            x = self.projection_head(x)

        if self.l2norm:
            x = nn.functional.normalize(x, dim=1, p=2)

        if self.prototypes is not None:
            return x, self.prototypes(x)
        return x

    def forward(self, inputs):
        """Multi-resolution forward pass.

        ``inputs`` is a tensor or a list of tensors. Consecutive tensors that
        share the same last-dimension size are concatenated and sent through
        the encoder together; the encoder outputs of all groups are then
        concatenated and passed to :meth:`forward_head`.
        """
        if not isinstance(inputs, list):
            inputs = [inputs]
        # cumulative end index of each run of equally sized inputs
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in inputs]),
            return_counts=True,
        )[1], 0)

        device = self.features[0][0].weight.device
        encoded = []
        start_idx = 0
        for end_idx in idx_crops:
            group = torch.cat(inputs[start_idx: end_idx])
            if device.type == "cuda":
                # move to the encoder's own device rather than the current
                # default CUDA device
                group = group.to(device, non_blocking=True)
            encoded.append(self.forward_backbone(group))
            start_idx = end_idx
        return self.forward_head(torch.cat(encoded))
296
297
298
class MultiPrototypes(nn.Module):
    """A set of bias-free linear prototype heads sharing the same input.

    One ``nn.Linear(output_dim, k, bias=False)`` is registered per entry
    ``k`` of ``nmb_prototypes`` under the names ``prototypes0``,
    ``prototypes1``, ...
    """

    def __init__(self, output_dim, nmb_prototypes):
        super().__init__()
        self.nmb_heads = len(nmb_prototypes)
        for head_idx, head_size in enumerate(nmb_prototypes):
            head = nn.Linear(output_dim, head_size, bias=False)
            self.add_module(f"prototypes{head_idx}", head)

    def forward(self, x):
        """Return a list with the output of every prototype head for ``x``."""
        heads = (
            getattr(self, f"prototypes{head_idx}")
            for head_idx in range(self.nmb_heads)
        )
        return [head(x) for head in heads]
311
312
313
class CustomSwAV(pl.LightningModule):
314
    def __init__(
        self,
        model,
        gpus: int,
        num_samples: int,
        batch_size: int,
        config=None,
        transformations=None,
        nodes: int = 1,
        arch: str = 'resnet50',
        hidden_mlp: int = 2048,
        feat_dim: int = 128,
        warmup_epochs: int = 10,
        max_epochs: int = 100,
        nmb_prototypes: int = 3000,
        freeze_prototypes_epochs: int = 1,
        temperature: float = 0.1,
        sinkhorn_iterations: int = 3,
        # queue_length: int = 512,  # must be divisible by total batch-size
        queue_path: str = "queue",
        epoch_queue_starts: int = 15,
        crops_for_assign: list = [0, 1],
        nmb_crops: list = [2, 6],
        first_conv: bool = True,
        maxpool1: bool = True,
        optimizer: str = 'adam',
        lars_wrapper: bool = False,
        exclude_bn_bias: bool = False,
        start_lr: float = 0.,
        learning_rate: float = 1e-3,
        final_lr: float = 0.,
        weight_decay: float = 1e-6,
        epsilon: float = 0.05,
        **kwargs
    ):
        """
        Args:
            model: backbone handed to ``init_model`` (wrapped with the
                projection head and prototypes)
            gpus: number of gpus per node used in training, used to manage
                the queue and select distributed sinkhorn
            num_samples: number of samples used for training
            batch_size: batch size per GPU in ddp
            config: optional mapping with run settings; when it contains
                ``"epochs"`` that value overrides ``max_epochs``;
                ``"log_dir"`` is read later in ``setup``
            transformations: stored on the module; not used in this method
            nodes: number of nodes to train on
            arch: encoder architecture name (stored only)
            hidden_mlp: hidden layer size of the non-linear projection head
            feat_dim: output dim of the projection head
            warmup_epochs: apply linear warmup for this many epochs
            max_epochs: epoch count for pre-training, used when ``config``
                does not provide ``"epochs"``
            nmb_prototypes: count of prototype vectors
            freeze_prototypes_epochs: epoch till which gradients of prototype
                layer are frozen
            temperature: loss temperature
            sinkhorn_iterations: iterations for sinkhorn normalization
            queue_path: folder within the logs directory
            epoch_queue_starts: start using the queue after this epoch
            crops_for_assign: list of crop ids for computing assignment
            nmb_crops: number of global and local crops, ex: [2, 6]
            first_conv: forwarded to the network wrapper
            maxpool1: forwarded to the network wrapper
            optimizer: optimizer to use ('adam' or 'sgd')
            lars_wrapper: use LARS wrapper over the optimizer
            exclude_bn_bias: exclude batchnorm and bias layers from weight
                decay in optimizers
            start_lr: starting lr for linear warmup
            learning_rate: learning rate
            final_lr: final learning rate of the cosine schedule
            weight_decay: weight decay for optimizer
            epsilon: epsilon val for swav assignments
            **kwargs: accepted and ignored

        The queue length is not configurable here; it is fixed to
        ``8 * batch_size``.
        """
        super().__init__()
        # self.save_hyperparameters()

        self.epoch = 0
        self.config = config
        self.transformations = transformations
        self.gpus = gpus
        self.nodes = nodes
        self.arch = arch
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.queue_length = 8 * batch_size

        self.hidden_mlp = hidden_mlp
        self.feat_dim = feat_dim
        self.nmb_prototypes = nmb_prototypes
        self.freeze_prototypes_epochs = freeze_prototypes_epochs
        self.sinkhorn_iterations = sinkhorn_iterations

        self.queue_path = queue_path
        self.epoch_queue_starts = epoch_queue_starts
        self.crops_for_assign = crops_for_assign
        self.nmb_crops = nmb_crops

        self.first_conv = first_conv
        self.maxpool1 = maxpool1

        self.optim = optimizer
        self.lars_wrapper = lars_wrapper
        self.exclude_bn_bias = exclude_bn_bias
        self.weight_decay = weight_decay
        self.epsilon = epsilon
        self.temperature = temperature

        self.start_lr = start_lr
        self.final_lr = final_lr
        self.learning_rate = learning_rate
        self.warmup_epochs = warmup_epochs
        # The config value wins when present; otherwise fall back to the
        # max_epochs argument instead of failing on a missing/None config.
        if config is not None and "epochs" in config:
            self.max_epochs = config["epochs"]
        else:
            self.max_epochs = max_epochs

        if self.gpus * self.nodes > 1:
            self.get_assignments = self.distributed_sinkhorn
        else:
            self.get_assignments = self.sinkhorn

        # compute iters per epoch
        global_batch_size = self.nodes * self.gpus * \
            self.batch_size if self.gpus > 0 else self.batch_size
        self.train_iters_per_epoch = (self.num_samples // global_batch_size) + 1

        # define LR schedule: linear warmup followed by cosine decay,
        # one entry per optimizer step
        warmup_iters = self.train_iters_per_epoch * self.warmup_epochs
        cosine_iters = self.train_iters_per_epoch * \
            (self.max_epochs - self.warmup_epochs)
        warmup_lr_schedule = np.linspace(
            self.start_lr, self.learning_rate, warmup_iters
        )
        iters = np.arange(cosine_iters)
        cosine_lr_schedule = np.array([
            self.final_lr + 0.5 * (self.learning_rate - self.final_lr) * (
                1 + math.cos(math.pi * t / cosine_iters)
            )
            for t in iters
        ])

        self.lr_schedule = np.concatenate(
            (warmup_lr_schedule, cosine_lr_schedule))
        self.queue = None
        self.model = self.init_model(model)
        self.softmax = nn.Softmax(dim=1)
457
        
458
459
    def setup(self, stage):
        """Resolve the per-rank queue file and load a saved queue if one exists.

        The queue folder is ``config["log_dir"]`` joined with the folder name
        given to the constructor; the file is ``queue<global_rank>.pth``.

        Args:
            stage: stage name passed by the trainer; not used here.
        """
        # self.queue_path is overwritten with the file path below, so remember
        # the original folder name; otherwise a second setup() call would
        # build the folder path from the previous file path.
        if not hasattr(self, "_queue_folder_name"):
            self._queue_folder_name = self.queue_path

        queue_folder = os.path.join(
            self.config["log_dir"], self._queue_folder_name)
        # exist_ok avoids the check-then-create race between ranks
        os.makedirs(queue_folder, exist_ok=True)

        self.queue_path = os.path.join(
            queue_folder,
            "queue" + str(self.trainer.global_rank) + ".pth"
        )

        if os.path.isfile(self.queue_path):
            self.queue = torch.load(self.queue_path)["queue"]
471
        
472
    def init_model(self, model):
        """Wrap ``model`` in a ``CustomResNet`` configured from this module's settings."""
        wrapper_options = dict(
            hidden_mlp=self.hidden_mlp,
            output_dim=self.feat_dim,
            nmb_prototypes=self.nmb_prototypes,
            first_conv=self.first_conv,
            maxpool1=self.maxpool1,
        )
        return CustomResNet(model, **wrapper_options)
478
479
    def forward(self, x):
        """Return the backbone output for a single batch ``x``.

        Only ``forward_backbone`` of the wrapped network is called; the
        projection head and prototypes are not applied.
        """
        network = self.model
        return network.forward_backbone(x)
482
    
483
    def on_train_start(self):
        """Reset the module's own epoch counter when training starts."""
        # NOTE: writing the configuration and the transformations as text via
        # self.logger.experiment.add_text(...) used to happen here and is
        # currently disabled.
        self.epoch = 0
495
496
    def on_train_epoch_start(self):
        """Create the feature queue once its start epoch is reached.

        The queue is a zero tensor of shape
        ``(len(crops_for_assign), queue_length // gpus, feat_dim)`` and is
        only created when no queue exists yet (e.g. none was loaded in
        ``setup``). ``use_the_queue`` is reset to False at every epoch start.
        """
        queue_is_due = (
            self.queue_length > 0
            and self.trainer.current_epoch >= self.epoch_queue_starts
            and self.queue is None
        )
        if queue_is_due:
            # max(..., 1) keeps CPU runs (gpus == 0) from dividing by zero
            # change to nodes * gpus once multi-node
            per_process_length = self.queue_length // max(self.gpus, 1)
            self.queue = torch.zeros(
                len(self.crops_for_assign),
                per_process_length,
                self.feat_dim,
            )

            if self.gpus > 0:
                self.queue = self.queue.cuda()

        self.use_the_queue = False
509
510
    def on_train_epoch_end(self, outputs) -> None:
        """Save the feature queue to ``self.queue_path`` if a queue exists.

        Args:
            outputs: epoch outputs passed by the trainer; not used here.
        """
        if self.queue is None:
            return
        torch.save({"queue": self.queue}, self.queue_path)
513
514
    def on_epoch_end(self):
        """Advance the module's own epoch counter by one."""
        self.epoch = self.epoch + 1
516
517
    def on_after_backward(self):
        """Drop prototype gradients while the prototypes are still frozen.

        During the first ``freeze_prototypes_epochs`` epochs the gradient of
        every model parameter whose name contains "prototypes" is set to None.
        """
        if self.current_epoch >= self.freeze_prototypes_epochs:
            return
        prototype_params = (
            param
            for name, param in self.model.named_parameters()
            if "prototypes" in name
        )
        for param in prototype_params:
            param.grad = None
522
523
    def shared_step(self, batch):
        """Compute the SwAV swapped-prediction loss for one batch.

        Args:
            batch: pair ``(inputs, y)``; ``inputs`` holds one entry per view
                plus a trailing online train/eval view that is discarded.
                ``y`` is not used.

        Returns:
            The loss averaged over the crops in ``crops_for_assign``.
        """
        inputs, _ = batch
        # remove online train/eval transforms at this point
        inputs = inputs[:-1]
        # Number of views really present in the batch. Iterating over these
        # (rather than over a count derived from nmb_crops) works for list
        # and array valued nmb_crops alike and never indexes a missing view.
        num_views = len(inputs)

        # 1. normalize the prototypes
        with torch.no_grad():
            w = self.model.prototypes.weight.data.clone()
            w = nn.functional.normalize(w, dim=1, p=2)
            self.model.prototypes.weight.copy_(w)

        # 2. multi-res forward passes
        embedding, output = self.model(inputs)
        embedding = embedding.detach()
        bs = inputs[0].size(0)

        # 3. swav loss computation
        loss = 0
        for i, crop_id in enumerate(self.crops_for_assign):
            with torch.no_grad():
                out = output[bs * crop_id: bs * (crop_id + 1)]

                # 4. time to use the queue
                if self.queue is not None:
                    # the queue counts as filled once its last row is non-zero
                    if self.use_the_queue or not torch.all(self.queue[i, -1, :] == 0):
                        self.use_the_queue = True
                        out = torch.cat((torch.mm(
                            self.queue[i],
                            self.model.prototypes.weight.t()
                        ), out))
                    # fill the queue: shift old entries back by one batch and
                    # put the current embeddings at the front
                    self.queue[i, bs:] = self.queue[i, :-bs].clone()
                    self.queue[i, :bs] = embedding[crop_id *
                                                   bs: (crop_id + 1) * bs]

                # 5. get assignments (only the rows of the current batch)
                q = torch.exp(out / self.epsilon).t()
                q = self.get_assignments(q, self.sinkhorn_iterations)[-bs:]

            # cluster assignment prediction from every other view
            subloss = 0
            for v in np.delete(np.arange(num_views), crop_id):
                p = self.softmax(
                    output[bs * v: bs * (v + 1)] / self.temperature)
                loss_value = q * torch.log(p)
                subloss -= torch.mean(torch.sum(loss_value, dim=1))
            # NOTE(review): the normaliser still comes from nmb_crops; it
            # equals the number of other views only when sum(nmb_crops)
            # matches the number of views in the batch - confirm.
            loss += subloss / (np.sum(self.nmb_crops) - 1)
        loss /= len(self.crops_for_assign)

        return loss
578
579
    def training_step(self, batch, batch_idx):
        """Return the SwAV loss of ``batch``; ``batch_idx`` is not used."""
        # self.log('train_loss', loss, on_step=True, on_epoch=False)
        return self.shared_step(batch)
585
586
    def validation_step(self, batch, batch_idx, dataloader_idx):
        """Compute the validation loss for the first dataloader only.

        Returns:
            ``{'val_loss': loss}`` for ``dataloader_idx == 0`` and an empty
            dict for every other dataloader.
        """
        if dataloader_idx == 0:
            # self.log('val_loss', loss, on_step=False, on_epoch=True)
            return {'val_loss': self.shared_step(batch)}
        return {}
597
    
598
    def validation_epoch_end(self, outputs):
        """Average ``val_loss`` over the outputs of the first dataloader.

        Returns:
            Dict with the mean loss under ``'val_loss'`` and the same log
            dict under both ``'log'`` and ``'progress_bar'``.
        """
        # outputs[0] because we are using multiple datasets!
        first_loader_outputs = outputs[0]
        val_loss = mean(first_loader_outputs, 'val_loss')
        log = {'val/val_loss': val_loss}
        return {'val_loss': val_loss, 'log': log, 'progress_bar': log}
606
607
    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=['bias', 'bn']):
        """Split trainable parameters into decayed and non-decayed groups.

        Args:
            named_params: iterable of ``(name, parameter)`` pairs.
            weight_decay: weight decay assigned to the regular group.
            skip_list: substrings; a parameter whose name contains any of
                them is placed in the group with zero weight decay.

        Returns:
            Two optimizer parameter groups: regular parameters first, excluded
            parameters second. Parameters with ``requires_grad == False`` are
            left out of both.
        """
        decayed, undecayed = [], []
        for name, param in named_params:
            if not param.requires_grad:
                continue
            is_excluded = any(layer_name in name for layer_name in skip_list)
            (undecayed if is_excluded else decayed).append(param)

        return [
            {'params': decayed, 'weight_decay': weight_decay},
            {'params': undecayed, 'weight_decay': 0.},
        ]
623
624
    def configure_optimizers(self):
        """Build the optimizer selected by ``self.optim``.

        Supports ``'sgd'`` (momentum 0.9) and ``'adam'``. When
        ``exclude_bn_bias`` is set, parameters are split with
        ``exclude_from_wt_decay``; when ``lars_wrapper`` is set the optimizer
        is wrapped in ``LARSWrapper``.

        Raises:
            ValueError: if ``self.optim`` is neither ``'sgd'`` nor ``'adam'``.
        """
        if self.exclude_bn_bias:
            params = self.exclude_from_wt_decay(
                self.named_parameters(),
                weight_decay=self.weight_decay
            )
        else:
            params = self.parameters()

        if self.optim == 'sgd':
            optimizer = torch.optim.SGD(
                params,
                lr=self.learning_rate,
                momentum=0.9,
                weight_decay=self.weight_decay
            )
        elif self.optim == 'adam':
            optimizer = torch.optim.Adam(
                params,
                lr=self.learning_rate,
                weight_decay=self.weight_decay
            )
        else:
            # fail with a clear message instead of an unbound local below
            raise ValueError(
                "Unsupported optimizer %r; expected 'sgd' or 'adam'." % (self.optim,)
            )

        if self.lars_wrapper:
            optimizer = LARSWrapper(
                optimizer,
                eta=0.001,  # trust coefficient
                clip=False
            )

        return optimizer
655
    
656
    def optimizer_step(
        self,
        epoch: int = None,
        batch_idx: int = None,
        optimizer: Optimizer = None,
        optimizer_idx: int = None,
        optimizer_closure: Optional[Callable] = None,
        on_tpu: bool = None,
        using_native_amp: bool = None,
        using_lbfgs: bool = None,
    ) -> None:
        """Set the scheduled learning rate, then run the optimizer step.

        The learning rate of every parameter group is taken from
        ``self.lr_schedule`` at the trainer's global step. Steps beyond the
        end of the schedule reuse its last entry.
        """
        # warm-up + decay schedule placed here since LARSWrapper is not optimizer class
        # adjust LR of optim contained within LARSWrapper
        # clamp so training past the precomputed schedule keeps the final lr
        # instead of raising an IndexError
        schedule_idx = min(self.trainer.global_step, len(self.lr_schedule) - 1)
        scheduled_lr = self.lr_schedule[schedule_idx]
        for param_group in optimizer.param_groups:
            param_group["lr"] = scheduled_lr

        # from lightning
        if not isinstance(optimizer, LightningOptimizer):
            # wraps into LightingOptimizer only for running step
            optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
        optimizer.step(closure=optimizer_closure)
677
        
678
    def sinkhorn(self, Q, nmb_iters):
        """Single-process Sinkhorn-Knopp normalisation of ``Q``.

        Args:
            Q: 2-D tensor; it is modified in place.
            nmb_iters: number of row/column normalisation rounds.

        Returns:
            The transposed, column-normalised matrix as a float tensor.
        """
        with torch.no_grad():
            sum_Q = torch.sum(Q)
            Q /= sum_Q

            K, B = Q.shape

            # target marginals live on Q's own device, so the method works
            # wherever Q is, independent of self.gpus
            r = torch.ones(K, device=Q.device) / K
            c = torch.ones(B, device=Q.device) / B

            for _ in range(nmb_iters):
                u = torch.sum(Q, dim=1)

                Q *= (r / u).unsqueeze(1)
                Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)

            return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
701
702
    def distributed_sinkhorn(self, Q, nmb_iters):
        """Sinkhorn-Knopp normalisation with sums all-reduced across processes.

        Args:
            Q: 2-D tensor; it is modified in place.
            nmb_iters: number of row/column normalisation rounds.

        Returns:
            The transposed, column-normalised matrix as a float tensor.
        """
        with torch.no_grad():
            sum_Q = torch.sum(Q)
            dist.all_reduce(sum_Q)
            Q /= sum_Q

            # target marginals live on Q's own device
            r = torch.ones(Q.shape[0], device=Q.device) / Q.shape[0]
            # NOTE(review): the column marginal is scaled by gpus only, not by
            # nodes * gpus - confirm before multi-node use.
            c = torch.ones(Q.shape[1], device=Q.device) / (self.gpus * Q.shape[1])

            curr_sum = torch.sum(Q, dim=1)
            dist.all_reduce(curr_sum)

            for it in range(nmb_iters):
                u = curr_sum
                Q *= (r / u).unsqueeze(1)
                Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
                curr_sum = torch.sum(Q, dim=1)
                dist.all_reduce(curr_sum)
            return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
728
729
    def type(self):
        """Return the tensor type of the first encoder layer's weight.

        NOTE(review): this no-argument method appears to replace the
        ``type(dst_type)`` method of the base module - confirm nothing relies
        on the inherited behaviour.
        """
        first_layer = self.model.features[0][0]
        return first_layer.weight.type()
731
732
    def get_representations(self, x):
        """Apply the wrapped network's ``features`` module to ``x`` and return the result."""
        feature_extractor = self.model.features
        return feature_extractor(x)
734
735
    def get_model(self):
        """Return the backbone stored inside the network wrapper (``self.model.model``)."""
        wrapper = self.model
        return wrapper.model
737
        
738
    def get_device(self):
        """Return the device of the first encoder layer's weight."""
        first_layer = self.model.features[0][0]
        return first_layer.weight.device
740
741
    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the SwAV options.

        The options are declared as ``(flag, keyword-arguments)`` pairs and
        registered in order on a new ``ArgumentParser`` that has
        ``parent_parser`` as parent and no help flag of its own.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        option_table = [
            # model params
            ("--arch", dict(default="resnet50", type=str,
                            help="convnet architecture")),
            # specify flags to store false
            ("--first_conv", dict(action='store_false')),
            ("--maxpool1", dict(action='store_false')),
            ("--hidden_mlp", dict(default=2048, type=int,
                                  help="hidden layer dimension in projection head")),
            ("--feat_dim", dict(default=128, type=int,
                                help="feature dimension")),
            ("--online_ft", dict(action='store_true')),
            ("--fp32", dict(action='store_true')),

            # transform params
            ("--gaussian_blur", dict(action="store_true",
                                     help="add gaussian blur")),
            ("--jitter_strength", dict(type=float, default=1.0,
                                       help="jitter strength")),
            ("--dataset", dict(type=str, default="stl10",
                               help="stl10, cifar10")),
            ("--data_dir", dict(type=str, default=".",
                                help="path to download data")),
            ("--queue_path", dict(type=str, default="queue",
                                  help="path for queue")),

            ("--nmb_crops", dict(type=int, default=[2, 4], nargs="+",
                                 help="list of number of crops (example: [2, 6])")),
            ("--size_crops", dict(type=int, default=[96, 36], nargs="+",
                                  help="crops resolutions (example: [224, 96])")),
            ("--min_scale_crops", dict(type=float, default=[0.33, 0.10], nargs="+",
                                       help="argument in RandomResizedCrop (example: [0.14, 0.05])")),
            ("--max_scale_crops", dict(type=float, default=[1, 0.33], nargs="+",
                                       help="argument in RandomResizedCrop (example: [1., 0.14])")),

            # training params
            ("--fast_dev_run", dict(action='store_true')),
            ("--nodes", dict(default=1, type=int,
                             help="number of nodes for training")),
            ("--gpus", dict(default=1, type=int,
                            help="number of gpus to train on")),
            ("--num_workers", dict(default=8, type=int,
                                   help="num of workers per GPU")),
            ("--optimizer", dict(default="adam", type=str,
                                 help="choose between adam/sgd")),
            ("--lars_wrapper", dict(action='store_true',
                                    help="apple lars wrapper over optimizer used")),
            ('--exclude_bn_bias', dict(action='store_true',
                                       help="exclude bn/bias from weight decay")),
            ("--max_epochs", dict(default=100, type=int,
                                  help="number of total epochs to run")),
            ("--max_steps", dict(default=-1, type=int,
                                 help="max steps")),
            ("--warmup_epochs", dict(default=10, type=int,
                                     help="number of warmup epochs")),
            ("--batch_size", dict(default=128, type=int,
                                  help="batch size per gpu")),

            ("--weight_decay", dict(default=1e-6, type=float,
                                    help="weight decay")),
            ("--learning_rate", dict(default=1e-3, type=float,
                                     help="base learning rate")),
            ("--start_lr", dict(default=0, type=float,
                                help="initial warmup learning rate")),
            ("--final_lr", dict(type=float, default=1e-6,
                                help="final learning rate")),

            # swav params
            ("--crops_for_assign", dict(type=int, nargs="+", default=[0, 1],
                                        help="list of crops id used for computing assignments")),
            ("--temperature", dict(default=0.1, type=float,
                                   help="temperature parameter in training loss")),
            ("--epsilon", dict(default=0.05, type=float,
                               help="regularization parameter for Sinkhorn-Knopp algorithm")),
            ("--sinkhorn_iterations", dict(default=3, type=int,
                                           help="number of iterations in Sinkhorn-Knopp algorithm")),
            ("--nmb_prototypes", dict(default=512, type=int,
                                      help="number of prototypes")),
            ("--queue_length", dict(type=int, default=0,
                                    help="length of the queue (0 for no queue); must be divisible by total batch size")),
            ("--epoch_queue_starts", dict(type=int, default=15,
                                          help="from this epoch, we start using a queue")),
            ("--freeze_prototypes_epochs", dict(default=1, type=int,
                                                help="freeze the prototypes during this many epochs from the start")),
        ]

        for flag, options in option_table:
            parser.add_argument(flag, **options)

        return parser
830
831
832
def mean(res, key1, key2=None):
    """Average a tensor-valued entry over a collection of result dicts.

    Args:
        res: iterable of per-step results.
        key1: top-level key to read from each result.
        key2: optional nested key. When given, ``x[key1][key2]`` is read from
            *every* element of ``res`` without any filtering, so each element
            must provide it.

    Returns:
        The mean of the stacked values (result of ``torch.stack(...).mean()``).

    Note:
        Without ``key2``, elements that are not dicts or that lack ``key1``
        are skipped. If nothing is left to stack, ``torch.stack`` raises.
    """
    if key2 is not None:
        return torch.stack([x[key1][key2] for x in res]).mean()
    # isinstance instead of an exact type comparison so that dict subclasses
    # (e.g. OrderedDict) are not silently dropped from the average.
    values = [x[key1] for x in res if isinstance(x, dict) and key1 in x]
    return torch.stack(values).mean()
836
837
def parse_args(parent_parser):
    """Build the command line parser for the pre-training script.

    Args:
        parent_parser: an ``ArgumentParser`` whose arguments are inherited
            (passed via ``parents=[...]``).

    Returns:
        A new ``ArgumentParser`` (created with ``add_help=False``) holding the
        augmentation, trainer, dataset and model options defined below.

    Note:
        Options declared without an explicit default parse to ``None`` when
        omitted; ``pretrain_routine`` relies on that to fall back to the
        values of the YAML config.
    """
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    add = parser.add_argument

    add('-t', '--trafos', nargs='+', help='add transformation to data augmentation pipeline',
        default=["GaussianNoise", "ChannelResize", "RandomResizedCrop"])

    # GaussianNoise
    add('--gaussian_scale', help='std param for gaussian noise transformation',
        default=0.005, type=float)

    # RandomResizedCrop
    # nargs='+' is required here: the value is a (min, max) range like the
    # other *_range options, and without it only a single float could be
    # supplied on the command line.
    add('--rr_crop_ratio_range', nargs='+',
        help='ratio range for random resized crop transformation', default=[0.5, 1.0], type=float)
    add('--output_size', help='output size for random resized crop transformation',
        default=250, type=int)

    # DynamicTimeWarp
    add('--warps', help='number of warps for dynamic time warp transformation',
        default=3, type=int)
    add('--radius', help='radius of warps of dynamic time warp transformation',
        default=10, type=int)

    # TimeWarp
    add('--epsilon', help='epsilon param for time warp', default=10, type=float)

    # ChannelResize
    add('--magnitude_range', nargs='+',
        help='range for scale param for ChannelResize transformation', default=[0.5, 2], type=float)

    # Downsample
    add('--downsample_ratio', help='downsample ratio for Downsample transformation',
        default=0.2, type=float)

    # TimeOut
    add('--to_crop_ratio_range', nargs='+',
        help='ratio range for timeout transformation', default=[0.2, 0.4], type=float)

    # resume training
    add('--resume', action='store_true')

    # trainer / hardware options
    add('--gpus', help='number of gpus to use; use cpu if gpu=0', type=int, default=1)
    add('--num_nodes', default=1, help='number of cluster nodes', type=int)
    add('--distributed_backend', help='sets backend type')
    add('--batch_size', type=int)
    add('--epochs', type=int)
    add('--debug', action='store_true')
    add('--warm_up', default=1, type=int)
    add('--precision', type=int)

    # dataset options
    add('--datasets', dest="target_folders",
        nargs='+', help='used datasets for pretraining')
    add('--log_dir', default="./experiment_logs")
    add('--percentage', help='determines how much of the dataset shall be used during the pretraining',
        type=float, default=1.0)

    # optimisation / model options
    add('--lr', type=float, help="learning rate")
    add('--out_dim', type=int, help="output dimension of model")
    add('--filter_cinc', default=False, action="store_true",
        help="only valid if cinc is selected: filter out the ptb data")
    add('--base_model')
    add('--widen', type=int, help="use wide xresnet1d50")
    add('--run_callbacks', default=False, action="store_true",
        help="run callbacks which asses linear evaluaton and finetuning metrics during pretraining")

    add('--checkpoint_path', default="")
    return parser
893
894
def init_logger(config):
    """Reconfigure root logging to write into ``<log_dir>/info.log``.

    Args:
        config: mapping providing ``'debug'`` (truthy selects DEBUG level,
            otherwise INFO) and ``'log_dir'`` (directory for the log file).

    Returns:
        The logger named after this module.

    Side effects:
        Removes every handler currently attached to the root logger and
        creates ``log_dir`` (including missing parents) if necessary.
    """
    level = logging.DEBUG if config['debug'] else logging.INFO

    # remove all handlers to change basic configuration
    # (basicConfig is a no-op while the root logger still has handlers)
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    # makedirs handles nested paths and avoids the check-then-create race
    # of an isdir()/mkdir() pair.
    os.makedirs(config['log_dir'], exist_ok=True)

    logging.basicConfig(filename=os.path.join(config['log_dir'], 'info.log'), level=level,
                        format='%(asctime)s %(name)s:%(lineno)s %(levelname)s:  %(message)s  ')
    return logging.getLogger(__name__)
908
909
def pretrain_routine(args):
    """Prepare configuration, dataset wrapper and logger for a pre-training run.

    The YAML config is read from ``checkpoints/bolts_config.yaml`` when
    ``args.resume`` is set and that file exists, otherwise from
    ``bolts_config.yaml``. Command line values then override the config.

    Args:
        args: parsed command line namespace as produced by ``parse_args``.

    Returns:
        Tuple ``(config, dataset, date, transformations, t_params,
        ptb_num_classes, tb_logger)``.
    """
    t_params = {"gaussian_scale": args.gaussian_scale, "rr_crop_ratio_range": args.rr_crop_ratio_range,
                "output_size": args.output_size, "warps": args.warps, "radius": args.radius,
                "epsilon": args.epsilon, "magnitude_range": args.magnitude_range,
                "downsample_ratio": args.downsample_ratio, "to_crop_ratio_range": args.to_crop_ratio_range,
                "bw_cmax": 0.1, "em_cmax": 0.5, "pl_cmax": 0.2, "bs_cmax": 1}
    transformations = args.trafos

    checkpoint_config = os.path.join("checkpoints", "bolts_config.yaml")
    config_file = checkpoint_config if args.resume and os.path.isfile(
        checkpoint_config) else "bolts_config.yaml"
    # Read inside a context manager so the file handle is closed again.
    # NOTE: FullLoader is kept from the original; only use trusted config files.
    with open(config_file, "r") as config_handle:
        config = yaml.load(config_handle, Loader=yaml.FullLoader)

    # Merge CLI values into the config: an explicit CLI value always wins; a
    # CLI value of None only fills keys the config does not define at all.
    for key, value in vars(args).items():
        if value is not None or key not in config:
            config[key] = value

    # Nested settings that can be overridden from the command line.
    nested_overrides = (
        ("dataset", "target_folders", args.target_folders),
        ("dataset", "percentage", args.percentage),
        ("dataset", "filter_cinc", args.filter_cinc),
        ("model", "base_model", args.base_model),
        ("model", "widen", args.widen),
        ("model", "out_dim", args.out_dim),
    )
    for section, key, value in nested_overrides:
        if value is not None:
            config[section][key] = value

    # SwAV-specific dataset settings are always forced.
    config["dataset"]["swav"] = True
    config["dataset"]["nmb_crops"] = 7
    config["eval_dataset"]["swav"] = True
    config["eval_dataset"]["nmb_crops"] = 7

    init_logger(config)
    dataset = SimCLRDataSetWrapper(
        config['batch_size'], **config['dataset'], transformations=transformations, t_params=t_params)
    for i, t in enumerate(dataset.transformations):
        logger.info(str(i) + ". Transformation: " +
                    str(t) + ": " + str(t.get_params()))

    date = time.asctime()
    label_to_num_classes = {"label_all": 71, "label_diag": 44, "label_form": 19,
                            "label_rhythm": 12, "label_diag_subclass": 23, "label_diag_superclass": 5}
    ptb_num_classes = label_to_num_classes[config["eval_dataset"]["ptb_xl_label"]]

    # Abbreviations used to encode the augmentation pipeline in the run name.
    abr = {"Transpose": "Tr", "TimeOut": "TO", "DynamicTimeWarp": "DTW", "RandomResizedCrop": "RRC",
           "ChannelResize": "ChR", "GaussianNoise": "GN", "TimeWarp": "TW", "ToTensor": "TT",
           "GaussianBlur": "GB", "BaselineWander": "BlW", "PowerlineNoise": "PlN", "EMNoise": "EM",
           "BaselineShift": "BlS"}
    # ToTensor / Transpose are mapped to '' so they do not show up in the name.
    short_names = [abr[str(tr)] if abr[str(tr)] not in ["TT", "Tr"] else ''
                   for tr in dataset.transformations]
    trs = re.sub(r"[,'\]\[]", "", str(short_names))
    # NOTE(review): trs[1:] drops the first character; this presumably strips a
    # leading space left by an initial ''-mapped transformation - confirm.
    name = str(date) + "_" + method + "_" + str(
        time.time_ns())[-3:] + "_" + trs[1:]

    tb_logger = TensorBoardLogger(args.log_dir, name=name, version='')
    config["log_dir"] = os.path.join(args.log_dir, name)
    print(config)
    return config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger
955
956
def aftertrain_routine(config, args, trainer, pl_model, datamodule, callbacks):
    """Persist the artefacts of a finished run below ``config["log_dir"]``.

    Writes, in this order:
      * ``results.pkl``  - pickled dict with the config, the used
        transformations and the best macro score of every
        ``SSLOnlineEvaluator`` callback,
      * ``checkpoints/model.ckpt`` - via ``trainer.save_checkpoint``,
      * ``config.txt``   - the printed representation of the config.

    ``pl_model`` and ``datamodule`` are accepted but not used.
    """
    # Only online evaluators contribute a score entry, keyed by str(callback).
    scores = {
        str(cb): {"macro": cb.best_macro}
        for cb in callbacks
        if isinstance(cb, SSLOnlineEvaluator)
    }
    summary = {"config": config, "trafos": args.trafos, "scores": scores}

    results_path = os.path.join(config["log_dir"], "results.pkl")
    with open(results_path, 'wb') as sink:
        pickle.dump(summary, sink)

    checkpoint_path = os.path.join(config["log_dir"], "checkpoints", "model.ckpt")
    trainer.save_checkpoint(checkpoint_path)

    config_path = os.path.join(config["log_dir"], "config.txt")
    with open(config_path, "w") as sink:
        sink.write(str(config) + "\n")
970
971
def cli_main():
    """Command line entry point: parse arguments, build everything and train.

    Steps: parse CLI arguments, run ``pretrain_routine`` to obtain the merged
    config and logger, create the ECG datamodule, optionally attach the
    online evaluation callbacks, build the trainer and the SwAV module,
    optionally restore weights from ``--checkpoint_path``, fit, and finally
    store results via ``aftertrain_routine``.

    Raises:
        FileNotFoundError: if ``--checkpoint_path`` is given but does not exist.
    """
    from pytorch_lightning import Trainer

    parser = parse_args(ArgumentParser())
    logger.info("parse arguments")
    args = parser.parse_args()
    config, _dataset, _date, transformations, t_params, ptb_num_classes, tb_logger = pretrain_routine(args)

    # data
    ecg_datamodule = ECGDataModule(config, transformations, t_params)

    callbacks = []
    if args.run_callbacks:
        # callbacks for online linear evaluation / fine-tuning; both share the
        # same settings and differ only in their mode.
        for mode in ("linear_evaluation", "fine_tuning"):
            callbacks.append(SSLOnlineEvaluator(
                drop_p=0, z_dim=512, num_classes=ptb_num_classes, hidden_dim=None,
                lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"],
                mode=mode, verbose=False))

    # configure trainer
    trainer = Trainer(logger=tb_logger, max_epochs=config["epochs"], gpus=args.gpus,
                      distributed_backend=args.distributed_backend, auto_lr_find=False,
                      num_nodes=args.num_nodes, precision=config["precision"], callbacks=callbacks)

    # pytorch lightning module
    model = ResNetSimCLR(**config["model"])
    pl_model = CustomSwAV(model, config["gpus"], ecg_datamodule.num_samples, config["batch_size"], config=config,
                          transformations=ecg_datamodule.transformations,
                          nmb_crops=config["dataset"]["nmb_crops"])

    # load checkpoint
    if args.checkpoint_path != "":
        if not os.path.exists(args.checkpoint_path):
            # raising a plain string is itself a TypeError; raise a real exception
            raise FileNotFoundError("checkpoint does not exist: " + args.checkpoint_path)
        logger.info("Retrieve checkpoint from " + args.checkpoint_path)
        # load_from_checkpoint is a classmethod that returns a NEW module, so
        # calling it on the instance and dropping the result restores nothing.
        # Load the stored weights into the module that is actually trained.
        checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
        pl_model.load_state_dict(checkpoint["state_dict"])

    # start training
    trainer.fit(pl_model, ecg_datamodule)

    aftertrain_routine(config, args, trainer, pl_model, ecg_datamodule, callbacks)
1019
1020
# Start pre-training only when this file is executed as a script.
if __name__ == "__main__":
    cli_main()