--- a
+++ b/custom_swav_bolts.py
@@ -0,0 +1,1021 @@
+"""
+Adapted from the official SwAV implementation: https://github.com/facebookresearch/swav
+"""
+import math
+import os
+import re
+from argparse import ArgumentParser
+from typing import Callable, Optional
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torch.distributed as dist
+from pytorch_lightning.utilities import AMPType
+from torch import nn
+from pytorch_lightning.core.optimizer import LightningOptimizer
+from torch.optim.optimizer import Optimizer
+
+import yaml
+import time
+import logging
+import pickle
+# from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
+# NOTE: GaussianBlur is used by SwAVTrainDataTransform below but was not
+# imported in the original file; it is assumed to come from pl_bolts' SimCLR
+# transforms (the exact module path may differ across pl_bolts versions).
+from pl_bolts.models.self_supervised.simclr.transforms import GaussianBlur
+from pl_bolts.optimizers.lars_scheduling import LARSWrapper
+from pl_bolts.transforms.dataset_normalizations import (
+    cifar10_normalization,
+    imagenet_normalization,
+    stl10_normalization,
+)
+from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper
+from clinical_ts.create_logger import create_logger
+from torchvision.models.resnet import Bottleneck, BasicBlock
+from online_evaluator import SSLOnlineEvaluator
+from ecg_datamodule import ECGDataModule
+from pytorch_lightning.loggers import TensorBoardLogger
+from models.resnet_simclr import ResNetSimCLR
+import torchvision.transforms as transforms
+
+
+from typing import List
+logger = create_logger(__name__)
+method = "swav"
+
+
+class SwAVTrainDataTransform(object):
+    def __init__(
+        self,
+        normalize=None,
+        size_crops: List[int] = [96, 36],
+        nmb_crops: List[int] = [2, 4],
+        min_scale_crops: List[float] = [0.33, 0.10],
+        max_scale_crops: List[float] = [1, 0.33],
+        gaussian_blur: bool = True,
+        jitter_strength: float = 1.
+    ):
+        self.jitter_strength = jitter_strength
+        self.gaussian_blur = gaussian_blur
+
+        assert len(size_crops) == len(nmb_crops)
+        assert len(min_scale_crops) == len(nmb_crops)
+        assert len(max_scale_crops) == len(nmb_crops)
+
+        self.size_crops = size_crops
+        self.nmb_crops = nmb_crops
+        self.min_scale_crops = min_scale_crops
+        self.max_scale_crops = max_scale_crops
+
+        self.color_jitter = transforms.ColorJitter(
+            0.8 * self.jitter_strength,
+            0.8 * self.jitter_strength,
+            0.8 * self.jitter_strength,
+            0.2 * self.jitter_strength
+        )
+
+        transform = []
+        color_transform = [
+            transforms.RandomApply([self.color_jitter], p=0.8),
+            transforms.RandomGrayscale(p=0.2)
+        ]
+
+        if self.gaussian_blur:
+            kernel_size = int(0.1 * self.size_crops[0])
+            if kernel_size % 2 == 0:
+                kernel_size += 1
+
+            color_transform.append(
+                GaussianBlur(kernel_size=kernel_size, p=0.5)
+            )
+
+        self.color_transform = transforms.Compose(color_transform)
+
+        if normalize is None:
+            self.final_transform = transforms.ToTensor()
+        else:
+            self.final_transform = transforms.Compose(
+                [transforms.ToTensor(), normalize])
+
+        for i in range(len(self.size_crops)):
+            random_resized_crop = transforms.RandomResizedCrop(
+                self.size_crops[i],
+                scale=(self.min_scale_crops[i], self.max_scale_crops[i]),
+            )
+
+            transform.extend([transforms.Compose([
+                random_resized_crop,
+                transforms.RandomHorizontalFlip(p=0.5),
+                self.color_transform,
+                self.final_transform])
+            ] * self.nmb_crops[i])
+
+        self.transform = transform
+
+        # add online train transform of the size of global view
+        online_train_transform = transforms.Compose([
+            transforms.RandomResizedCrop(self.size_crops[0]),
+            transforms.RandomHorizontalFlip(),
+            self.final_transform
+        ])
+
+        self.transform.append(online_train_transform)
+        
+    def __call__(self, sample):
+        multi_crops = list(
+            map(lambda transform: transform(sample), self.transform)
+        )
+        return multi_crops
+
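+# Usage sketch (hypothetical input image): with the defaults above, the train
+# transform returns sum(nmb_crops) + 1 = 7 views per sample, e.g.
+#   transform = SwAVTrainDataTransform(size_crops=[96, 36], nmb_crops=[2, 4])
+#   views = transform(img)   # two 96px global crops, four 36px local crops,
+#                             # plus one view for the online evaluator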
+
+class SwAVEvalDataTransform(SwAVTrainDataTransform):
+    def __init__(
+        self,
+        normalize=None,
+        size_crops: List[int] = [96, 36],
+        nmb_crops: List[int] = [2, 4],
+        min_scale_crops: List[float] = [0.33, 0.10],
+        max_scale_crops: List[float] = [1, 0.33],
+        gaussian_blur: bool = True,
+        jitter_strength: float = 1.
+    ):
+        super().__init__(
+            normalize=normalize,
+            size_crops=size_crops,
+            nmb_crops=nmb_crops,
+            min_scale_crops=min_scale_crops,
+            max_scale_crops=max_scale_crops,
+            gaussian_blur=gaussian_blur,
+            jitter_strength=jitter_strength
+        )
+
+        input_height = self.size_crops[0]  # get global view crop
+        test_transform = transforms.Compose([
+            transforms.Resize(int(input_height + 0.1 * input_height)),
+            transforms.CenterCrop(input_height),
+            self.final_transform,
+        ])
+
+        # replace last transform to eval transform in self.transform list
+        self.transform[-1] = test_transform
+
+
+class SwAVFinetuneTransform(object):
+    def __init__(
+        self,
+        input_height: int = 224,
+        jitter_strength: float = 1.,
+        normalize=None,
+        eval_transform: bool = False
+    ) -> None:
+
+        self.jitter_strength = jitter_strength
+        self.input_height = input_height
+        self.normalize = normalize
+
+        self.color_jitter = transforms.ColorJitter(
+            0.8 * self.jitter_strength,
+            0.8 * self.jitter_strength,
+            0.8 * self.jitter_strength,
+            0.2 * self.jitter_strength
+        )
+
+        if not eval_transform:
+            data_transforms = [
+                transforms.RandomResizedCrop(size=self.input_height),
+                transforms.RandomHorizontalFlip(p=0.5),
+                transforms.RandomApply([self.color_jitter], p=0.8),
+                transforms.RandomGrayscale(p=0.2)
+            ]
+        else:
+            data_transforms = [
+                transforms.Resize(
+                    int(self.input_height + 0.1 * self.input_height)),
+                transforms.CenterCrop(self.input_height)
+            ]
+
+        if normalize is None:
+            final_transform = transforms.ToTensor()
+        else:
+            final_transform = transforms.Compose(
+                [transforms.ToTensor(), normalize])
+
+        data_transforms.append(final_transform)
+        self.transform = transforms.Compose(data_transforms)
+
+    def __call__(self, sample):
+        return self.transform(sample)
+
+
+class CustomResNet(nn.Module):
+    def __init__(
+            self,
+            model,
+            zero_init_residual=False,
+            output_dim=16,
+            hidden_mlp=512,
+            nmb_prototypes=8,
+            eval_mode=False,
+            first_conv=True,
+            maxpool1=True, 
+            l2norm=True
+    ):
+        super(CustomResNet, self).__init__()
+        self.l2norm = l2norm
+        self.model = model
+        self.features = self.model.features
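+        # the projection head maps the (assumed) 512-dim backbone features
+        # down to output_dim, which is what the prototype layer scores against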
+        self.projection_head = nn.Sequential(
+                nn.Linear(512, hidden_mlp),
+                nn.BatchNorm1d(hidden_mlp),
+                nn.ReLU(inplace=True),
+                nn.Linear(hidden_mlp, output_dim),
+            )
+
+        # prototype layer
+        self.prototypes = None
+        if isinstance(nmb_prototypes, list):
+            self.prototypes = MultiPrototypes(output_dim, nmb_prototypes)
+        elif nmb_prototypes > 0:
+            self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(
+                    m.weight, mode="fan_out", nonlinearity="relu")
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def forward_backbone(self, x):
+        x = x.type(self.features[0][0].weight.type())
+        h = self.features(x)
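+        # the backbone is assumed to end in global pooling; squeeze() drops the
+        # leftover singleton spatial dims, giving (batch, feature_dim)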
+        h = h.squeeze()
+        return h
+
+    def forward_head(self, x):
+        if self.projection_head is not None:
+            x = self.projection_head(x)
+
+        if self.l2norm:
+            x = nn.functional.normalize(x, dim=1, p=2)
+
+        if self.prototypes is not None:
+            return x, self.prototypes(x)
+        return x
+
+    def forward(self, inputs):
+        if not isinstance(inputs, list):
+            inputs = [inputs]
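+        # group the multi-resolution crops by spatial size so that each
+        # backbone pass receives a uniformly-sized batch; idx_crops marks the
+        # end index of each same-resolution group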
+        idx_crops = torch.cumsum(torch.unique_consecutive(
+            torch.tensor([inp.shape[-1] for inp in inputs]),
+            return_counts=True,
+        )[1], 0)
+        start_idx = 0
+        for end_idx in idx_crops:
+            _out = torch.cat(inputs[start_idx: end_idx])
+
+            if 'cuda' in str(self.features[0][0].weight.device):
+                _out = self.forward_backbone(_out.cuda(non_blocking=True))
+            else:
+                _out = self.forward_backbone(_out)
+
+            if start_idx == 0:
+                output = _out
+            else:
+                output = torch.cat((output, _out))
+            start_idx = end_idx
+        return self.forward_head(output)
+
+
+class MultiPrototypes(nn.Module):
+    def __init__(self, output_dim, nmb_prototypes):
+        super(MultiPrototypes, self).__init__()
+        self.nmb_heads = len(nmb_prototypes)
+        for i, k in enumerate(nmb_prototypes):
+            self.add_module("prototypes" + str(i),
+                            nn.Linear(output_dim, k, bias=False))
+
+    def forward(self, x):
+        out = []
+        for i in range(self.nmb_heads):
+            out.append(getattr(self, "prototypes" + str(i))(x))
+        return out
+
+
+class CustomSwAV(pl.LightningModule):
+    def __init__(
+        self,
+        model,
+        gpus: int,
+        num_samples: int,
+        batch_size: int,
+        config=None,
+        transformations=None,
+        nodes: int = 1,
+        arch: str = 'resnet50',
+        hidden_mlp: int = 2048,
+        feat_dim: int = 128,
+        warmup_epochs: int = 10,
+        max_epochs: int = 100,
+        nmb_prototypes: int = 3000,
+        freeze_prototypes_epochs: int = 1,
+        temperature: float = 0.1,
+        sinkhorn_iterations: int = 3,
+        # queue_length: int = 512,  # must be divisible by total batch-size
+        queue_path: str = "queue",
+        epoch_queue_starts: int = 15,
+        crops_for_assign: list = [0, 1],
+        nmb_crops: list = [2, 6],
+        first_conv: bool = True,
+        maxpool1: bool = True,
+        optimizer: str = 'adam',
+        lars_wrapper: bool = False,
+        exclude_bn_bias: bool = False,
+        start_lr: float = 0.,
+        learning_rate: float = 1e-3,
+        final_lr: float = 0.,
+        weight_decay: float = 1e-6,
+        epsilon: float = 0.05,
+        **kwargs
+    ):
+        """
+        Args:
+            gpus: number of gpus per node used in training, passed to SwAV module
+                to manage the queue and select distributed sinkhorn
+            nodes: number of nodes to train on
+            num_samples: number of image samples used for training
+            batch_size: batch size per GPU in ddp
+            arch: encoder architecture used for pre-training
+            hidden_mlp: hidden layer of non-linear projection head, set to 0
+                to use a linear projection head
+            feat_dim: output dim of the projection head
+            warmup_epochs: apply linear warmup for this many epochs
+            max_epochs: epoch count for pre-training
+            nmb_prototypes: count of prototype vectors
+            freeze_prototypes_epochs: epoch till which gradients of prototype layer
+                are frozen
+            temperature: loss temperature
+            sinkhorn_iterations: iterations for sinkhorn normalization
+            queue_length: set queue when batch size is small,
+                must be divisible by total batch-size (i.e. total_gpus * batch_size),
+                set to 0 to remove the queue; hard-coded to 8 * batch_size below
+            queue_path: folder within the logs directory
+            epoch_queue_starts: start using the queue after this epoch
+            crops_for_assign: list of crop ids for computing assignment
+            nmb_crops: number of global and local crops, ex: [2, 6]; note that
+                this repo passes a single int counting all views including the
+                online-evaluator view (see pretrain_routine)
+            first_conv: keep first conv same as the original resnet architecture,
+                if set to false it is replaced by a kernel 3, stride 1 conv (cifar-10)
+            maxpool1: keep first maxpool layer same as the original resnet architecture,
+                if set to false, first maxpool is turned off (cifar10, maybe stl10)
+            optimizer: optimizer to use
+            lars_wrapper: use LARS wrapper over the optimizer
+            exclude_bn_bias: exclude batchnorm and bias layers from weight decay in optimizers
+            start_lr: starting lr for linear warmup
+            learning_rate: learning rate
+            final_lr: final learning rate for cosine decay
+            weight_decay: weight decay for optimizer
+            epsilon: epsilon val for swav assignments
+        """
+        super().__init__()
+        # self.save_hyperparameters()
+
+        self.epoch = 0
+        self.config = config
+        self.transformations = transformations
+        self.gpus = gpus
+        self.nodes = nodes
+        self.arch = arch
+        self.num_samples = num_samples
+        self.batch_size = batch_size
+        # the queue length is fixed at 8x the per-GPU batch size; the
+        # queue_length constructor argument is commented out above
+        self.queue_length = 8 * batch_size
+
+        self.hidden_mlp = hidden_mlp
+        self.feat_dim = feat_dim
+        self.nmb_prototypes = nmb_prototypes
+        self.freeze_prototypes_epochs = freeze_prototypes_epochs
+        self.sinkhorn_iterations = sinkhorn_iterations
+
+        self.queue_path = queue_path
+        self.epoch_queue_starts = epoch_queue_starts
+        self.crops_for_assign = crops_for_assign
+        self.nmb_crops = nmb_crops
+
+        self.first_conv = first_conv
+        self.maxpool1 = maxpool1
+
+        self.optim = optimizer
+        self.lars_wrapper = lars_wrapper
+        self.exclude_bn_bias = exclude_bn_bias
+        self.weight_decay = weight_decay
+        self.epsilon = epsilon
+        self.temperature = temperature
+
+        self.start_lr = start_lr
+        self.final_lr = final_lr
+        self.learning_rate = learning_rate
+        self.warmup_epochs = warmup_epochs
+        self.max_epochs = config["epochs"]
+
+        if self.gpus * self.nodes > 1:
+            self.get_assignments = self.distributed_sinkhorn
+        else:
+            self.get_assignments = self.sinkhorn
+
+
+        # compute iters per epoch
+        global_batch_size = self.nodes * self.gpus * \
+            self.batch_size if self.gpus > 0 else self.batch_size
+        self.train_iters_per_epoch = (self.num_samples // global_batch_size) + 1
+
+        # define LR schedule
+        warmup_lr_schedule = np.linspace(
+            self.start_lr, self.learning_rate, self.train_iters_per_epoch * self.warmup_epochs
+        )
+        iters = np.arange(self.train_iters_per_epoch *
+                          (self.max_epochs - self.warmup_epochs))
+        cosine_lr_schedule = np.array([self.final_lr + 0.5 * (self.learning_rate - self.final_lr) * (
+            1 + math.cos(math.pi * t / (self.train_iters_per_epoch *
+                                        (self.max_epochs - self.warmup_epochs)))
+        ) for t in iters])
+
+        self.lr_schedule = np.concatenate(
+            (warmup_lr_schedule, cosine_lr_schedule))
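+        # lr_schedule is indexed by trainer.global_step in optimizer_step below:
+        # linear warmup for warmup_epochs epochs, then cosine decay to final_lr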
+        self.queue = None
+        self.model = self.init_model(model)
+        self.softmax = nn.Softmax(dim=1)
+
+    def setup(self, stage):
+        queue_folder = os.path.join(self.config["log_dir"], self.queue_path)
+        if not os.path.exists(queue_folder):
+            os.makedirs(queue_folder)
+
+        self.queue_path = os.path.join(
+            queue_folder,
+            "queue" + str(self.trainer.global_rank) + ".pth"
+        )
+
+        if os.path.isfile(self.queue_path):
+            self.queue = torch.load(self.queue_path)["queue"]
+        
+    def init_model(self, model):
+        return CustomResNet(model, hidden_mlp=self.hidden_mlp,
+            output_dim=self.feat_dim,
+            nmb_prototypes=self.nmb_prototypes,
+            first_conv=self.first_conv,
+            maxpool1=self.maxpool1)
+
+    def forward(self, x):
+        # pass single batch from the resnet backbone
+        return self.model.forward_backbone(x)
+    
+    def on_train_start(self):
+        # # log configuration
+        # config_str = re.sub(r"[,\}\{]", "<br/>", str(self.config))
+        # config_str = re.sub(r"[\[\]\']", "", config_str)
+        # transformation_str = re.sub(r"[\}]", "<br/>", str(["<br>" + str(
+        #     t) + ":<br/>" + str(t.get_params()) for t in self.transformations]))
+        # transformation_str = re.sub(r"[,\"\{\'\[\]]", "", transformation_str)
+        # self.logger.experiment.add_text(
+        #     "configuration", str(config_str), global_step=0)
+        # self.logger.experiment.add_text("transformations", str(
+        #     transformation_str), global_step=0)
+        self.epoch = 0
+
+    def on_train_epoch_start(self):
+        if self.queue_length > 0:
+            if self.trainer.current_epoch >= self.epoch_queue_starts and self.queue is None:
+                self.queue = torch.zeros(
+                    len(self.crops_for_assign),
+                    self.queue_length // self.gpus,  # change to nodes * gpus once multi-node
+                    self.feat_dim,
+                )
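+                # queue shape: (n_assign_crops, per-gpu queue length, feat_dim)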
+
+                if self.gpus > 0:
+                    self.queue = self.queue.cuda()
+
+        self.use_the_queue = False
+
+    def on_train_epoch_end(self, outputs) -> None:
+        if self.queue is not None:
+            torch.save({"queue": self.queue}, self.queue_path)
+
+    def on_epoch_end(self):
+        self.epoch += 1
+
+    def on_after_backward(self):
+        if self.current_epoch < self.freeze_prototypes_epochs:
+            for name, p in self.model.named_parameters():
+                if "prototypes" in name:
+                    p.grad = None
+
+    def shared_step(self, batch):
+        inputs, y = batch
+        # remove online train/eval transforms at this point
+        inputs = inputs[:-1]
+
+        # 1. normalize the prototypes
+        with torch.no_grad():
+            w = self.model.prototypes.weight.data.clone()
+            w = nn.functional.normalize(w, dim=1, p=2)
+            self.model.prototypes.weight.copy_(w)
+
+        # 2. multi-res forward passes
+        embedding, output = self.model(inputs)
+        embedding = embedding.detach()
+        bs = inputs[0].size(0)
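+        # embedding: detached projections of shape (n_crops * bs, feat_dim);
+        # output: prototype scores of shape (n_crops * bs, nmb_prototypes)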
+
+        # 3. swav loss computation
+        loss = 0
+        for i, crop_id in enumerate(self.crops_for_assign):
+            with torch.no_grad():
+                out = output[bs * crop_id: bs * (crop_id + 1)]
+
+                # 4. time to use the queue
+                if self.queue is not None:
+                    if self.use_the_queue or not torch.all(self.queue[i, -1, :] == 0):
+                        self.use_the_queue = True
+                        out = torch.cat((torch.mm(
+                            self.queue[i],
+                            self.model.prototypes.weight.t()
+                        ), out))
+                    # fill the queue
+                    self.queue[i, bs:] = self.queue[i, :-bs].clone()
+                    self.queue[i, :bs] = embedding[crop_id *
+                                                   bs: (crop_id + 1) * bs]
+
+                # 5. get assignments
+                q = torch.exp(out / self.epsilon).t()
+                q = self.get_assignments(q, self.sinkhorn_iterations)[-bs:]
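+                # [-bs:] keeps the assignments of the current batch only; the
+                # queue rows were prepended to `out` above and are discarded here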
+
+            # cluster assignment prediction
+            subloss = 0
+            # nmb_crops is passed as a single int in this repo (see
+            # pretrain_routine, nmb_crops = 7, counting the online view), so
+            # nmb_crops - 1 is the number of SwAV crops actually fed to the
+            # model; iterate over every crop except the assignment crop itself
+            for v in np.delete(np.arange(np.sum(self.nmb_crops - 1)), crop_id):
+                p = self.softmax(
+                    output[bs * v: bs * (v + 1)] / self.temperature)
+                loss_value = q * torch.log(p)
+                subloss -= torch.mean(torch.sum(loss_value, dim=1))
+            loss += subloss / (np.sum(self.nmb_crops) - 1)
+        loss /= len(self.crops_for_assign)
+
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        
+        loss = self.shared_step(batch)
+
+        # self.log('train_loss', loss, on_step=True, on_epoch=False)
+        return loss
+
+    def validation_step(self, batch, batch_idx, dataloader_idx):
+        
+        if dataloader_idx != 0:
+            return {}
+        loss = self.shared_step(batch)
+
+        # self.log('val_loss', loss, on_step=False, on_epoch=True)
+        results = {
+            'val_loss': loss,
+        }
+        return results
+    
+    def validation_epoch_end(self, outputs):
+        # outputs[0] because we are using multiple datasets!
+        val_loss = mean(outputs[0], 'val_loss')
+
+        log = {
+            'val/val_loss': val_loss,
+        }
+        return {'val_loss': val_loss, 'log': log, 'progress_bar': log}
+
+    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=['bias', 'bn']):
+        params = []
+        excluded_params = []
+
+        for name, param in named_params:
+            if not param.requires_grad:
+                continue
+            elif any(layer_name in name for layer_name in skip_list):
+                excluded_params.append(param)
+            else:
+                params.append(param)
+
+        return [
+            {'params': params, 'weight_decay': weight_decay},
+            {'params': excluded_params, 'weight_decay': 0.}
+        ]
+
+    def configure_optimizers(self):
+        if self.exclude_bn_bias:
+            params = self.exclude_from_wt_decay(
+                self.named_parameters(),
+                weight_decay=self.weight_decay
+            )
+        else:
+            params = self.parameters()
+
+        if self.optim == 'sgd':
+            optimizer = torch.optim.SGD(
+                params,
+                lr=self.learning_rate,
+                momentum=0.9,
+                weight_decay=self.weight_decay
+            )
+        elif self.optim == 'adam':
+            optimizer = torch.optim.Adam(
+                params,
+                lr=self.learning_rate,
+                weight_decay=self.weight_decay
+            )
+
+        if self.lars_wrapper:
+            optimizer = LARSWrapper(
+                optimizer,
+                eta=0.001,  # trust coefficient
+                clip=False
+            )
+
+        return optimizer
+    
+    def optimizer_step(
+        self,
+        epoch: int = None,
+        batch_idx: int = None,
+        optimizer: Optimizer = None,
+        optimizer_idx: int = None,
+        optimizer_closure: Optional[Callable] = None,
+        on_tpu: bool = None,
+        using_native_amp: bool = None,
+        using_lbfgs: bool = None,
+    ) -> None:
+        # warm-up + decay schedule placed here since LARSWrapper is not an
+        # Optimizer subclass; adjust the LR of the optimizer wrapped inside it
+        for param_group in optimizer.param_groups:
+            param_group["lr"] = self.lr_schedule[self.trainer.global_step]
+
+        # from lightning
+        if not isinstance(optimizer, LightningOptimizer):
+            # wraps into LightningOptimizer only for running step
+            optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
+        optimizer.step(closure=optimizer_closure)
+        
+    def sinkhorn(self, Q, nmb_iters):
+        with torch.no_grad():
+            sum_Q = torch.sum(Q)
+            Q /= sum_Q
+
+            K, B = Q.shape
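+            # K: number of prototypes, B: number of scored samples (the batch,
+            # plus queue entries when the queue is in use)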
+
+            if self.gpus > 0:
+                u = torch.zeros(K).cuda()
+                r = torch.ones(K).cuda() / K
+                c = torch.ones(B).cuda() / B
+            else:
+                u = torch.zeros(K)
+                r = torch.ones(K) / K
+                c = torch.ones(B) / B
+
+            for _ in range(nmb_iters):
+                u = torch.sum(Q, dim=1)
+
+                Q *= (r / u).unsqueeze(1)
+                Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
+
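+            # normalize the columns of Q so that each sample's assignment is a
+            # distribution over prototypes (each row of the returned B x K
+            # matrix sums to 1); the iterations above push prototype usage
+            # towards uniform, i.e. roughly B / K samples per prototype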
+            return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
+
+    def distributed_sinkhorn(self, Q, nmb_iters):
+        with torch.no_grad():
+            sum_Q = torch.sum(Q)
+            dist.all_reduce(sum_Q)
+            Q /= sum_Q
+
+            if self.gpus > 0:
+                u = torch.zeros(Q.shape[0]).cuda(non_blocking=True)
+                r = torch.ones(Q.shape[0]).cuda(non_blocking=True) / Q.shape[0]
+                c = torch.ones(Q.shape[1]).cuda(
+                    non_blocking=True) / (self.gpus * Q.shape[1])
+            else:
+                u = torch.zeros(Q.shape[0])
+                r = torch.ones(Q.shape[0]) / Q.shape[0]
+                c = torch.ones(Q.shape[1]) / (self.gpus * Q.shape[1])
+
+            curr_sum = torch.sum(Q, dim=1)
+            dist.all_reduce(curr_sum)
+
+            for it in range(nmb_iters):
+                u = curr_sum
+                Q *= (r / u).unsqueeze(1)
+                Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
+                curr_sum = torch.sum(Q, dim=1)
+                dist.all_reduce(curr_sum)
+            return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
+
+    def type(self):
+        return self.model.features[0][0].weight.type()
+
+    def get_representations(self, x):
+        return self.model.features(x)
+
+    def get_model(self):
+        return self.model.model
+        
+    def get_device(self):
+        return self.model.features[0][0].weight.device
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+
+        # model params
+        parser.add_argument("--arch", default="resnet50",
+                            type=str, help="convnet architecture")
+        # specify flags to store false
+        parser.add_argument("--first_conv", action='store_false')
+        parser.add_argument("--maxpool1", action='store_false')
+        parser.add_argument("--hidden_mlp", default=2048, type=int,
+                            help="hidden layer dimension in projection head")
+        parser.add_argument("--feat_dim", default=128,
+                            type=int, help="feature dimension")
+        parser.add_argument("--online_ft", action='store_true')
+        parser.add_argument("--fp32", action='store_true')
+
+        # transform params
+        parser.add_argument("--gaussian_blur",
+                            action="store_true", help="add gaussian blur")
+        parser.add_argument("--jitter_strength", type=float,
+                            default=1.0, help="jitter strength")
+        parser.add_argument("--dataset", type=str,
+                            default="stl10", help="stl10, cifar10")
+        parser.add_argument("--data_dir", type=str,
+                            default=".", help="path to download data")
+        parser.add_argument("--queue_path", type=str,
+                            default="queue", help="path for queue")
+
+        parser.add_argument("--nmb_crops", type=int, default=[2, 4], nargs="+",
+                            help="list of number of crops (example: [2, 6])")
+        parser.add_argument("--size_crops", type=int, default=[96, 36], nargs="+",
+                            help="crops resolutions (example: [224, 96])")
+        parser.add_argument("--min_scale_crops", type=float, default=[0.33, 0.10], nargs="+",
+                            help="argument in RandomResizedCrop (example: [0.14, 0.05])")
+        parser.add_argument("--max_scale_crops", type=float, default=[1, 0.33], nargs="+",
+                            help="argument in RandomResizedCrop (example: [1., 0.14])")
+
+        # training params
+        parser.add_argument("--fast_dev_run", action='store_true')
+        parser.add_argument("--nodes", default=1, type=int,
+                            help="number of nodes for training")
+        parser.add_argument("--gpus", default=1, type=int,
+                            help="number of gpus to train on")
+        parser.add_argument("--num_workers", default=8,
+                            type=int, help="num of workers per GPU")
+        parser.add_argument("--optimizer", default="adam",
+                            type=str, help="choose between adam/sgd")
+        parser.add_argument("--lars_wrapper", action='store_true',
+                            help="apple lars wrapper over optimizer used")
+        parser.add_argument('--exclude_bn_bias', action='store_true',
+                            help="exclude bn/bias from weight decay")
+        parser.add_argument("--max_epochs", default=100,
+                            type=int, help="number of total epochs to run")
+        parser.add_argument("--max_steps", default=-1,
+                            type=int, help="max steps")
+        parser.add_argument("--warmup_epochs", default=10,
+                            type=int, help="number of warmup epochs")
+        parser.add_argument("--batch_size", default=128,
+                            type=int, help="batch size per gpu")
+
+        parser.add_argument("--weight_decay", default=1e-6,
+                            type=float, help="weight decay")
+        parser.add_argument("--learning_rate", default=1e-3,
+                            type=float, help="base learning rate")
+        parser.add_argument("--start_lr", default=0, type=float,
+                            help="initial warmup learning rate")
+        parser.add_argument("--final_lr", type=float,
+                            default=1e-6, help="final learning rate")
+
+        # swav params
+        parser.add_argument("--crops_for_assign", type=int, nargs="+", default=[0, 1],
+                            help="list of crops id used for computing assignments")
+        parser.add_argument("--temperature", default=0.1, type=float,
+                            help="temperature parameter in training loss")
+        parser.add_argument("--epsilon", default=0.05, type=float,
+                            help="regularization parameter for Sinkhorn-Knopp algorithm")
+        parser.add_argument("--sinkhorn_iterations", default=3, type=int,
+                            help="number of iterations in Sinkhorn-Knopp algorithm")
+        parser.add_argument("--nmb_prototypes", default=512,
+                            type=int, help="number of prototypes")
+        parser.add_argument("--queue_length", type=int, default=0,
+                            help="length of the queue (0 for no queue); must be divisible by total batch size")
+        parser.add_argument("--epoch_queue_starts", type=int, default=15,
+                            help="from this epoch, we start using a queue")
+        parser.add_argument("--freeze_prototypes_epochs", default=1, type=int,
+                            help="freeze the prototypes during this many epochs from the start")
+
+        return parser
+
+
+def mean(res, key1, key2=None):
+    # average a metric across step outputs; entries missing the key are skipped
+    if key2 is not None:
+        return torch.stack([x[key1][key2] for x in res]).mean()
+    return torch.stack([x[key1] for x in res if isinstance(x, dict) and key1 in x]).mean()
+
+def parse_args(parent_parser):
+    parser = ArgumentParser(parents=[parent_parser], add_help=False)
+    parser.add_argument('-t', '--trafos', nargs='+', help='add transformation to data augmentation pipeline',
+                        default=["GaussianNoise", "ChannelResize", "RandomResizedCrop"])
+    # GaussianNoise
+    parser.add_argument(
+            '--gaussian_scale', help='std param for gaussian noise transformation', default=0.005, type=float)
+    # RandomResizedCrop
+    parser.add_argument('--rr_crop_ratio_range', nargs='+',
+                            help='ratio range for random resized crop transformation', default=[0.5, 1.0], type=float)
+    parser.add_argument(
+            '--output_size', help='output size for random resized crop transformation', default=250, type=int)
+    # DynamicTimeWarp
+    parser.add_argument(
+            '--warps', help='number of warps for dynamic time warp transformation', default=3, type=int)
+    parser.add_argument(
+            '--radius', help='radius of warps of dynamic time warp transformation', default=10, type=int)
+    # TimeWarp
+    parser.add_argument(
+            '--epsilon', help='epsilon param for time warp', default=10, type=float)
+    # ChannelResize
+    parser.add_argument('--magnitude_range', nargs='+',
+                            help='range for scale param for ChannelResize transformation', default=[0.5, 2], type=float)
+    # Downsample
+    parser.add_argument(
+            '--downsample_ratio', help='downsample ratio for Downsample transformation', default=0.2, type=float)
+    # TimeOut
+    parser.add_argument('--to_crop_ratio_range', nargs='+',
+                            help='ratio range for timeout transformation', default=[0.2, 0.4], type=float)
+    # resume training
+    parser.add_argument('--resume', action='store_true')
+    parser.add_argument(
+            '--gpus', help='number of gpus to use; use cpu if gpu=0', type=int, default=1)
+    parser.add_argument(
+            '--num_nodes', default=1, help='number of cluster nodes', type=int)
+    parser.add_argument(
+            '--distributed_backend', help='sets backend type')
+    parser.add_argument('--batch_size', type=int)
+    parser.add_argument('--epochs', type=int)
+    parser.add_argument('--debug', action='store_true')
+    parser.add_argument('--warm_up', default=1, type=int)
+    parser.add_argument('--precision', type=int)
+    parser.add_argument('--datasets', dest="target_folders",
+                            nargs='+', help='used datasets for pretraining')
+    parser.add_argument('--log_dir', default="./experiment_logs")
+    parser.add_argument(
+            '--percentage', help='determines how much of the dataset shall be used during the pretraining', type=float, default=1.0)
+    parser.add_argument('--lr', type=float, help="learning rate")
+    parser.add_argument('--out_dim', type=int, help="output dimension of model")
+    parser.add_argument('--filter_cinc', default=False, action="store_true", help="only valid if cinc is selected: filter out the ptb data")
+    parser.add_argument('--base_model')
+    parser.add_argument('--widen',type=int, help="use wide xresnet1d50")
+    parser.add_argument('--run_callbacks', default=False, action="store_true", help="run callbacks which assess linear evaluation and fine-tuning metrics during pretraining")
+
+    parser.add_argument('--checkpoint_path', default="")
+    return parser
+
+def init_logger(config):
+    level = logging.INFO
+
+    if config['debug']:
+        level = logging.DEBUG
+
+    # remove all handlers to change basic configuration
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+    if not os.path.isdir(config['log_dir']):
+        os.mkdir(config['log_dir'])
+    logging.basicConfig(filename=os.path.join(config['log_dir'], 'info.log'), level=level,
+                        format='%(asctime)s %(name)s:%(lineno)s %(levelname)s:  %(message)s  ')
+    return logging.getLogger(__name__)
+
+def pretrain_routine(args):
+    t_params = {"gaussian_scale": args.gaussian_scale, "rr_crop_ratio_range": args.rr_crop_ratio_range, "output_size": args.output_size, "warps": args.warps, "radius": args.radius,
+                "epsilon": args.epsilon, "magnitude_range": args.magnitude_range, "downsample_ratio": args.downsample_ratio, "to_crop_ratio_range": args.to_crop_ratio_range,
+                "bw_cmax":0.1, "em_cmax":0.5, "pl_cmax":0.2, "bs_cmax":1}
+    transformations = args.trafos
+    checkpoint_config = os.path.join("checkpoints", "bolts_config.yaml")
+    config_file = checkpoint_config if args.resume and os.path.isfile(
+        checkpoint_config) else "bolts_config.yaml"
+    config = yaml.load(open(config_file, "r"), Loader=yaml.FullLoader)
+    args_dict = vars(args)
+    # merge CLI args into the YAML config: a CLI value overrides the config
+    # entry unless it is None, in which case the existing config value is kept
+    for key in set(config.keys()).union(set(args_dict.keys())):
+        config[key] = config[key] if (key not in args_dict.keys() or key in args_dict.keys(
+        ) and key in config.keys() and args_dict[key] is None) else args_dict[key]
+    if args.target_folders is not None:
+        config["dataset"]["target_folders"] = args.target_folders
+    config["dataset"]["percentage"] = args.percentage if args.percentage is not None else config["dataset"]["percentage"]
+    config["dataset"]["filter_cinc"] = args.filter_cinc if args.filter_cinc is not None else config["dataset"]["filter_cinc"]
+    config["model"]["base_model"] = args.base_model if args.base_model is not None else config["model"]["base_model"]
+    config["model"]["widen"] = args.widen if args.widen is not None else config["model"]["widen"]
+    config["dataset"]["swav"] = True
+    config["dataset"]["nmb_crops"] = 7
+    config["eval_dataset"]["swav"] = True
+    config["eval_dataset"]["nmb_crops"] = 7
+    if args.out_dim is not None:
+        config["model"]["out_dim"] = args.out_dim
+    init_logger(config)
+    dataset = SimCLRDataSetWrapper(
+        config['batch_size'], **config['dataset'], transformations=transformations, t_params=t_params)
+    for i, t in enumerate(dataset.transformations):
+        logger.info(str(i) + ". Transformation: " +
+                    str(t) + ": " + str(t.get_params()))
+    date = time.asctime()
+    label_to_num_classes = {"label_all": 71, "label_diag": 44, "label_form": 19,
+                            "label_rhythm": 12, "label_diag_subclass": 23, "label_diag_superclass": 5}
+    ptb_num_classes = label_to_num_classes[config["eval_dataset"]
+                                           ["ptb_xl_label"]]
+    abr = {"Transpose": "Tr", "TimeOut": "TO", "DynamicTimeWarp": "DTW", "RandomResizedCrop": "RRC", "ChannelResize": "ChR", "GaussianNoise": "GN",
+           "TimeWarp": "TW", "ToTensor": "TT", "GaussianBlur": "GB", "BaselineWander": "BlW", "PowerlineNoise": "PlN", "EMNoise": "EM", "BaselineShift": "BlS"}
+    trs = re.sub(r"[,'\]\[]", "", str([abr[str(tr)] if abr[str(tr)] not in [
+                 "TT", "Tr"] else '' for tr in dataset.transformations]))
+    name = str(date) + "_" + method + "_" + str(
+        time.time_ns())[-3:] + "_" + trs[1:]
+    tb_logger = TensorBoardLogger(args.log_dir, name=name, version='')
+    config["log_dir"] = os.path.join(args.log_dir, name)
+    print(config)
+    return config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger
+
+def aftertrain_routine(config, args, trainer, pl_model, datamodule, callbacks):
+    scores = {}
+    for ca in callbacks:
+        if isinstance(ca, SSLOnlineEvaluator):
+            scores[str(ca)] = {"macro": ca.best_macro}
+
+    results = {"config": config, "trafos": args.trafos, "scores": scores}
+
+    with open(os.path.join(config["log_dir"], "results.pkl"), 'wb') as handle:
+        pickle.dump(results, handle)
+
+    trainer.save_checkpoint(os.path.join(config["log_dir"], "checkpoints", "model.ckpt"))
+    with open(os.path.join(config["log_dir"], "config.txt"), "w") as text_file:
+        print(config, file=text_file)
+
+def cli_main():
+    from pytorch_lightning import Trainer
+    from online_evaluator import SSLOnlineEvaluator
+    from ecg_datamodule import ECGDataModule
+    from clinical_ts.create_logger import create_logger
+    from os.path import exists
+    
+    parser = ArgumentParser()
+    parser = parse_args(parser)
+    logger.info("parse arguments")
+    args = parser.parse_args()
+    config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger = pretrain_routine(args)
+
+    # data
+    ecg_datamodule = ECGDataModule(config, transformations, t_params)
+
+    callbacks = []
+    if args.run_callbacks:
+        # callbacks for online linear evaluation/fine-tuning
+        linear_evaluator = SSLOnlineEvaluator(
+            drop_p=0, z_dim=512, num_classes=ptb_num_classes, hidden_dim=None,
+            lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"],
+            mode="linear_evaluation", verbose=False)
+
+        fine_tuner = SSLOnlineEvaluator(
+            drop_p=0, z_dim=512, num_classes=ptb_num_classes, hidden_dim=None,
+            lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"],
+            mode="fine_tuning", verbose=False)
+
+        callbacks.append(linear_evaluator)
+        callbacks.append(fine_tuner)
+
+    # configure trainer
+    trainer = Trainer(logger=tb_logger, max_epochs=config["epochs"], gpus=args.gpus,
+                      distributed_backend=args.distributed_backend, auto_lr_find=False, num_nodes=args.num_nodes, precision=config["precision"], callbacks=callbacks)
+
+    # pytorch lightning module
+    model = ResNetSimCLR(**config["model"])
+    pl_model = CustomSwAV(model, config["gpus"], ecg_datamodule.num_samples, config["batch_size"], config=config,
+                          transformations=ecg_datamodule.transformations, nmb_crops=config["dataset"]["nmb_crops"])
+    # load checkpoint
+    if args.checkpoint_path != "":
+        if exists(args.checkpoint_path):
+            logger.info("Retrieve checkpoint from " + args.checkpoint_path)
+            # load_from_checkpoint is a classmethod that returns a new instance
+            # (the original call discarded its result), so restore the weights
+            # into the existing module via the state dict instead
+            checkpoint = torch.load(args.checkpoint_path)
+            pl_model.load_state_dict(checkpoint["state_dict"])
+        else:
+            raise FileNotFoundError("checkpoint does not exist")
+
+    # start training
+    trainer.fit(pl_model, ecg_datamodule)
+
+    aftertrain_routine(config, args, trainer, pl_model, ecg_datamodule, callbacks)
+
+if __name__ == "__main__":
+    cli_main()
\ No newline at end of file