"""
Adapted from the official SwAV implementation: https://github.com/facebookresearch/swav
"""
import math
import os
import re
from argparse import ArgumentParser
from typing import Callable, Optional
import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed as dist
from pytorch_lightning.utilities import AMPType
from torch import nn
from pytorch_lightning.core.optimizer import LightningOptimizer
from torch.optim.optimizer import Optimizer
import yaml
import time
import logging
import pickle
# from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.transforms.dataset_normalizations import (
cifar10_normalization,
imagenet_normalization,
stl10_normalization,
)
from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper
from clinical_ts.create_logger import create_logger
from torchvision.models.resnet import Bottleneck, BasicBlock
from online_evaluator import SSLOnlineEvaluator
from ecg_datamodule import ECGDataModule
from pytorch_lightning.loggers import TensorBoardLogger
from models.resnet_simclr import ResNetSimCLR
import torchvision.transforms as transforms
_TORCHVISION_AVAILABLE = True
# import cv2
from typing import List
logger = create_logger(__name__)
method = "swav"
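# The multi-crop transforms below are carried over from the pl_bolts / SwAV image
# pipeline: each call returns nmb_crops[i] random crops for every resolution in
# size_crops, plus one final "online" view consumed by the online evaluator.
# The ECG pipeline supplies its own transformations through ECGDataModule, so these
# image transforms may be unused in this setting and are kept mainly for reference.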
class SwAVTrainDataTransform(object):
def __init__(
self,
normalize=None,
size_crops: List[int] = [96, 36],
nmb_crops: List[int] = [2, 4],
min_scale_crops: List[float] = [0.33, 0.10],
max_scale_crops: List[float] = [1, 0.33],
gaussian_blur: bool = True,
jitter_strength: float = 1.
):
self.jitter_strength = jitter_strength
self.gaussian_blur = gaussian_blur
assert len(size_crops) == len(nmb_crops)
assert len(min_scale_crops) == len(nmb_crops)
assert len(max_scale_crops) == len(nmb_crops)
self.size_crops = size_crops
self.nmb_crops = nmb_crops
self.min_scale_crops = min_scale_crops
self.max_scale_crops = max_scale_crops
self.color_jitter = transforms.ColorJitter(
0.8 * self.jitter_strength,
0.8 * self.jitter_strength,
0.8 * self.jitter_strength,
0.2 * self.jitter_strength
)
transform = []
color_transform = [
transforms.RandomApply([self.color_jitter], p=0.8),
transforms.RandomGrayscale(p=0.2)
]
if self.gaussian_blur:
kernel_size = int(0.1 * self.size_crops[0])
if kernel_size % 2 == 0:
kernel_size += 1
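            # NOTE: GaussianBlur is neither defined nor imported in this file; if
            # gaussian_blur=True is used, a cv2-based implementation (e.g. the one
            # shipped with the pl_bolts SwAV transforms) has to be provided.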
color_transform.append(
GaussianBlur(kernel_size=kernel_size, p=0.5)
)
self.color_transform = transforms.Compose(color_transform)
if normalize is None:
self.final_transform = transforms.ToTensor()
else:
self.final_transform = transforms.Compose(
[transforms.ToTensor(), normalize])
for i in range(len(self.size_crops)):
random_resized_crop = transforms.RandomResizedCrop(
self.size_crops[i],
scale=(self.min_scale_crops[i], self.max_scale_crops[i]),
)
transform.extend([transforms.Compose([
random_resized_crop,
transforms.RandomHorizontalFlip(p=0.5),
self.color_transform,
self.final_transform])
] * self.nmb_crops[i])
self.transform = transform
# add online train transform of the size of global view
online_train_transform = transforms.Compose([
transforms.RandomResizedCrop(self.size_crops[0]),
transforms.RandomHorizontalFlip(),
self.final_transform
])
self.transform.append(online_train_transform)
def __call__(self, sample):
multi_crops = list(
map(lambda transform: transform(sample), self.transform)
)
return multi_crops
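# Same multi-crop transform as above, but the final online view is replaced by a
# deterministic resize + center crop for evaluation.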
class SwAVEvalDataTransform(SwAVTrainDataTransform):
def __init__(
self,
normalize=None,
size_crops: List[int] = [96, 36],
nmb_crops: List[int] = [2, 4],
min_scale_crops: List[float] = [0.33, 0.10],
max_scale_crops: List[float] = [1, 0.33],
gaussian_blur: bool = True,
jitter_strength: float = 1.
):
super().__init__(
normalize=normalize,
size_crops=size_crops,
nmb_crops=nmb_crops,
min_scale_crops=min_scale_crops,
max_scale_crops=max_scale_crops,
gaussian_blur=gaussian_blur,
jitter_strength=jitter_strength
)
input_height = self.size_crops[0] # get global view crop
test_transform = transforms.Compose([
transforms.Resize(int(input_height + 0.1 * input_height)),
transforms.CenterCrop(input_height),
self.final_transform,
])
# replace last transform to eval transform in self.transform list
self.transform[-1] = test_transform
class SwAVFinetuneTransform(object):
def __init__(
self,
input_height: int = 224,
jitter_strength: float = 1.,
normalize=None,
eval_transform: bool = False
) -> None:
self.jitter_strength = jitter_strength
self.input_height = input_height
self.normalize = normalize
self.color_jitter = transforms.ColorJitter(
0.8 * self.jitter_strength,
0.8 * self.jitter_strength,
0.8 * self.jitter_strength,
0.2 * self.jitter_strength
)
if not eval_transform:
data_transforms = [
transforms.RandomResizedCrop(size=self.input_height),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomApply([self.color_jitter], p=0.8),
transforms.RandomGrayscale(p=0.2)
]
else:
data_transforms = [
transforms.Resize(
int(self.input_height + 0.1 * self.input_height)),
transforms.CenterCrop(self.input_height)
]
if normalize is None:
final_transform = transforms.ToTensor()
else:
final_transform = transforms.Compose(
[transforms.ToTensor(), normalize])
data_transforms.append(final_transform)
self.transform = transforms.Compose(data_transforms)
def __call__(self, sample):
return self.transform(sample)
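# Wraps an arbitrary backbone (anything exposing a .features attribute) with the SwAV
# projection head and prototype layer, mirroring the resnet wrapper of the official
# implementation.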
class CustomResNet(nn.Module):
def __init__(
self,
model,
zero_init_residual=False,
output_dim=16,
hidden_mlp=512,
nmb_prototypes=8,
eval_mode=False,
first_conv=True,
maxpool1=True,
l2norm=True
):
super(CustomResNet, self).__init__()
self.l2norm = l2norm
self.model = model
self.features = self.model.features
self.projection_head = nn.Sequential(
nn.Linear(512, hidden_mlp),
nn.BatchNorm1d(hidden_mlp),
nn.ReLU(inplace=True),
nn.Linear(hidden_mlp, output_dim),
)
# prototype layer
self.prototypes = None
if isinstance(nmb_prototypes, list):
self.prototypes = MultiPrototypes(output_dim, nmb_prototypes)
elif nmb_prototypes > 0:
self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def forward_backbone(self, x):
x = x.type(self.features[0][0].weight.type())
h = self.features(x)
h = h.squeeze()
return h
def forward_head(self, x):
if self.projection_head is not None:
x = self.projection_head(x)
if self.l2norm:
x = nn.functional.normalize(x, dim=1, p=2)
if self.prototypes is not None:
return x, self.prototypes(x)
return x
def forward(self, inputs):
if not isinstance(inputs, list):
inputs = [inputs]
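        # group consecutive crops that share the same spatial/temporal size so that each
        # backbone pass runs on a single-resolution batch; idx_crops holds the end index
        # of every size group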
idx_crops = torch.cumsum(torch.unique_consecutive(
torch.tensor([inp.shape[-1] for inp in inputs]),
return_counts=True,
)[1], 0)
start_idx = 0
for end_idx in idx_crops:
_out = torch.cat(inputs[start_idx: end_idx])
if 'cuda' in str(self.features[0][0].weight.device):
_out = self.forward_backbone(_out.cuda(non_blocking=True))
else:
_out = self.forward_backbone(_out)
if start_idx == 0:
output = _out
else:
output = torch.cat((output, _out))
start_idx = end_idx
return self.forward_head(output)
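# One linear prototype head per entry of nmb_prototypes (used when it is a list);
# forward returns one score matrix per head.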
class MultiPrototypes(nn.Module):
def __init__(self, output_dim, nmb_prototypes):
super(MultiPrototypes, self).__init__()
self.nmb_heads = len(nmb_prototypes)
for i, k in enumerate(nmb_prototypes):
self.add_module("prototypes" + str(i),
nn.Linear(output_dim, k, bias=False))
def forward(self, x):
out = []
for i in range(self.nmb_heads):
out.append(getattr(self, "prototypes" + str(i))(x))
return out
class CustomSwAV(pl.LightningModule):
def __init__(
self,
model,
gpus: int,
num_samples: int,
batch_size: int,
config=None,
transformations=None,
nodes: int = 1,
arch: str = 'resnet50',
hidden_mlp: int = 2048,
feat_dim: int = 128,
warmup_epochs: int = 10,
max_epochs: int = 100,
nmb_prototypes: int = 3000,
freeze_prototypes_epochs: int = 1,
temperature: float = 0.1,
sinkhorn_iterations: int = 3,
# queue_length: int = 512, # must be divisible by total batch-size
queue_path: str = "queue",
epoch_queue_starts: int = 15,
crops_for_assign: list = [0, 1],
nmb_crops: list = [2, 6],
first_conv: bool = True,
maxpool1: bool = True,
optimizer: str = 'adam',
lars_wrapper: bool = False,
exclude_bn_bias: bool = False,
start_lr: float = 0.,
learning_rate: float = 1e-3,
final_lr: float = 0.,
weight_decay: float = 1e-6,
epsilon: float = 0.05,
**kwargs
):
"""
Args:
gpus: number of gpus per node used in training, passed to SwAV module
to manage the queue and select distributed sinkhorn
nodes: number of nodes to train on
num_samples: number of image samples used for training
batch_size: batch size per GPU in ddp
            model: backbone network exposing a .features attribute; it is wrapped by
                CustomResNet with a projection head and prototype layer
            config: experiment configuration dict (epochs, log_dir, ... are read from it)
            transformations: data transformations used for pre-training (kept for logging)
arch: encoder architecture used for pre-training
hidden_mlp: hidden layer of non-linear projection head, set to 0
to use a linear projection head
feat_dim: output dim of the projection head
warmup_epochs: apply linear warmup for this many epochs
max_epochs: epoch count for pre-training
nmb_prototypes: count of prototype vectors
freeze_prototypes_epochs: epoch till which gradients of prototype layer
are frozen
temperature: loss temperature
sinkhorn_iterations: iterations for sinkhorn normalization
            queue_length: set the queue when the batch size is small;
                must be divisible by the total batch size (i.e. total_gpus * batch_size),
                set to 0 to disable the queue (currently hard-coded to 8 * batch_size in __init__)
queue_path: folder within the logs directory
            epoch_queue_starts: start using the queue after this epoch
crops_for_assign: list of crop ids for computing assignment
nmb_crops: number of global and local crops, ex: [2, 6]
            first_conv: keep the first conv the same as in the original resnet architecture;
                if set to false it is replaced by a kernel-3, stride-1 conv (cifar-10)
            maxpool1: keep the first maxpool layer the same as in the original resnet architecture;
                if set to false, the first maxpool is turned off (cifar10, maybe stl10)
optimizer: optimizer to use
lars_wrapper: use LARS wrapper over the optimizer
exclude_bn_bias: exclude batchnorm and bias layers from weight decay in optimizers
start_lr: starting lr for linear warmup
learning_rate: learning rate
            final_lr: final learning rate for the cosine decay
weight_decay: weight decay for optimizer
epsilon: epsilon val for swav assignments
"""
super().__init__()
# self.save_hyperparameters()
self.epoch = 0
self.config = config
self.transformations = transformations
self.gpus = gpus
self.nodes = nodes
self.arch = arch
self.num_samples = num_samples
self.batch_size = batch_size
self.queue_length = 8*batch_size
self.hidden_mlp = hidden_mlp
self.feat_dim = feat_dim
self.nmb_prototypes = nmb_prototypes
self.freeze_prototypes_epochs = freeze_prototypes_epochs
self.sinkhorn_iterations = sinkhorn_iterations
#self.queue_length = queue_length
self.queue_path = queue_path
self.epoch_queue_starts = epoch_queue_starts
self.crops_for_assign = crops_for_assign
self.nmb_crops = nmb_crops
self.first_conv = first_conv
self.maxpool1 = maxpool1
self.optim = optimizer
self.lars_wrapper = lars_wrapper
self.exclude_bn_bias = exclude_bn_bias
self.weight_decay = weight_decay
self.epsilon = epsilon
self.temperature = temperature
self.start_lr = start_lr
self.final_lr = final_lr
self.learning_rate = learning_rate
self.warmup_epochs = warmup_epochs
self.max_epochs = config["epochs"]
if self.gpus * self.nodes > 1:
self.get_assignments = self.distributed_sinkhorn
else:
self.get_assignments = self.sinkhorn
# compute iters per epoch
global_batch_size = self.nodes * self.gpus * \
self.batch_size if self.gpus > 0 else self.batch_size
self.train_iters_per_epoch = (self.num_samples // global_batch_size)+1
# define LR schedule
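        # linear warmup from start_lr to learning_rate over warmup_epochs, followed by a
        # cosine decay to final_lr; optimizer_step indexes this array with trainer.global_step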
warmup_lr_schedule = np.linspace(
self.start_lr, self.learning_rate, self.train_iters_per_epoch * self.warmup_epochs
)
iters = np.arange(self.train_iters_per_epoch *
(self.max_epochs - self.warmup_epochs))
cosine_lr_schedule = np.array([self.final_lr + 0.5 * (self.learning_rate - self.final_lr) * (
1 + math.cos(math.pi * t / (self.train_iters_per_epoch *
(self.max_epochs - self.warmup_epochs)))
) for t in iters])
self.lr_schedule = np.concatenate(
(warmup_lr_schedule, cosine_lr_schedule))
self.queue = None
self.model = self.init_model(model)
self.softmax = nn.Softmax(dim=1)
def setup(self, stage):
queue_folder = os.path.join(self.config["log_dir"], self.queue_path)
if not os.path.exists(queue_folder):
os.makedirs(queue_folder)
self.queue_path = os.path.join(
queue_folder,
"queue" + str(self.trainer.global_rank) + ".pth"
)
if os.path.isfile(self.queue_path):
self.queue = torch.load(self.queue_path)["queue"]
def init_model(self, model):
return CustomResNet(model, hidden_mlp=self.hidden_mlp,
output_dim=self.feat_dim,
nmb_prototypes=self.nmb_prototypes,
first_conv=self.first_conv,
maxpool1=self.maxpool1)
def forward(self, x):
# pass single batch from the resnet backbone
return self.model.forward_backbone(x)
def on_train_start(self):
# # log configuration
# config_str = re.sub(r"[,\}\{]", "<br/>", str(self.config))
# config_str = re.sub(r"[\[\]\']", "", config_str)
# transformation_str = re.sub(r"[\}]", "<br/>", str(["<br>" + str(
# t) + ":<br/>" + str(t.get_params()) for t in self.transformations]))
# transformation_str = re.sub(r"[,\"\{\'\[\]]", "", transformation_str)
# self.logger.experiment.add_text(
# "configuration", str(config_str), global_step=0)
# self.logger.experiment.add_text("transformations", str(
# transformation_str), global_step=0)
self.epoch = 0
def on_train_epoch_start(self):
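        # lazily allocate the feature queue once epoch_queue_starts is reached; it stores
        # embeddings from previous batches (per assigned crop) so the Sinkhorn assignment
        # is computed over more samples than a single small batch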
if self.queue_length > 0:
if self.trainer.current_epoch >= self.epoch_queue_starts and self.queue is None:
self.queue = torch.zeros(
len(self.crops_for_assign),
self.queue_length // self.gpus, # change to nodes * gpus once multi-node
self.feat_dim,
)
if self.gpus > 0:
self.queue = self.queue.cuda()
self.use_the_queue = False
def on_train_epoch_end(self, outputs) -> None:
if self.queue is not None:
torch.save({"queue": self.queue}, self.queue_path)
def on_epoch_end(self):
self.epoch += 1
def on_after_backward(self):
if self.current_epoch < self.freeze_prototypes_epochs:
for name, p in self.model.named_parameters():
if "prototypes" in name:
p.grad = None
def shared_step(self, batch):
# if self.dataset == 'stl10':
# unlabeled_batch = batch[0]
# batch = unlabeled_batch
inputs, y = batch
# remove online train/eval transforms at this point
inputs = inputs[:-1]
# 1. normalize the prototypes
with torch.no_grad():
w = self.model.prototypes.weight.data.clone()
w = nn.functional.normalize(w, dim=1, p=2)
self.model.prototypes.weight.copy_(w)
# 2. multi-res forward passes
embedding, output = self.model(inputs)
embedding = embedding.detach()
bs = inputs[0].size(0)
# 3. swav loss computation
loss = 0
for i, crop_id in enumerate(self.crops_for_assign):
with torch.no_grad():
out = output[bs * crop_id: bs * (crop_id + 1)]
# 4. time to use the queue
if self.queue is not None:
if self.use_the_queue or not torch.all(self.queue[i, -1, :] == 0):
self.use_the_queue = True
out = torch.cat((torch.mm(
self.queue[i],
self.model.prototypes.weight.t()
), out))
# fill the queue
self.queue[i, bs:] = self.queue[i, :-bs].clone()
self.queue[i, :bs] = embedding[crop_id *
bs: (crop_id + 1) * bs]
# 5. get assignments
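                # Sinkhorn-Knopp turns exp(scores / epsilon) into an approximately
                # equipartitioned soft assignment; only the last bs rows belong to the
                # current batch (earlier rows come from the queue)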
q = torch.exp(out / self.epsilon).t()
q = self.get_assignments(q, self.sinkhorn_iterations)[-bs:]
# cluster assignment prediction
subloss = 0
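            # predict the codes q from every other crop; nmb_crops is expected to be an
            # int here (the total number of crops including the online-eval view stripped
            # off above), so np.sum(self.nmb_crops - 1) is the number of crops in `output`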
for v in np.delete(np.arange(np.sum(self.nmb_crops-1)), crop_id):
p = self.softmax(
output[bs * v: bs * (v + 1)] / self.temperature)
loss_value = q * torch.log(p)
subloss -= torch.mean(torch.sum(loss_value, dim=1))
loss += subloss / (np.sum(self.nmb_crops) - 1)
loss /= len(self.crops_for_assign)
return loss
def training_step(self, batch, batch_idx):
loss = self.shared_step(batch)
# self.log('train_loss', loss, on_step=True, on_epoch=False)
return loss
def validation_step(self, batch, batch_idx, dataloader_idx):
if dataloader_idx != 0:
return {}
loss = self.shared_step(batch)
# self.log('val_loss', loss, on_step=False, on_epoch=True)
results = {
'val_loss': loss,
}
return results
def validation_epoch_end(self, outputs):
# outputs[0] because we are using multiple datasets!
val_loss = mean(outputs[0], 'val_loss')
log = {
'val/val_loss': val_loss,
}
return {'val_loss': val_loss, 'log': log, 'progress_bar': log}
def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=['bias', 'bn']):
params = []
excluded_params = []
for name, param in named_params:
if not param.requires_grad:
continue
elif any(layer_name in name for layer_name in skip_list):
excluded_params.append(param)
else:
params.append(param)
return [
{'params': params, 'weight_decay': weight_decay},
{'params': excluded_params, 'weight_decay': 0.}
]
def configure_optimizers(self):
if self.exclude_bn_bias:
params = self.exclude_from_wt_decay(
self.named_parameters(),
weight_decay=self.weight_decay
)
else:
params = self.parameters()
if self.optim == 'sgd':
optimizer = torch.optim.SGD(
params,
lr=self.learning_rate,
momentum=0.9,
weight_decay=self.weight_decay
)
elif self.optim == 'adam':
optimizer = torch.optim.Adam(
params,
lr=self.learning_rate,
weight_decay=self.weight_decay
)
if self.lars_wrapper:
optimizer = LARSWrapper(
optimizer,
eta=0.001, # trust coefficient
clip=False
)
return optimizer
def optimizer_step(
self,
epoch: int = None,
batch_idx: int = None,
optimizer: Optimizer = None,
optimizer_idx: int = None,
optimizer_closure: Optional[Callable] = None,
on_tpu: bool = None,
using_native_amp: bool = None,
using_lbfgs: bool = None,
) -> None:
# warm-up + decay schedule placed here since LARSWrapper is not optimizer class
# adjust LR of optim contained within LARSWrapper
for param_group in optimizer.param_groups:
param_group["lr"] = self.lr_schedule[self.trainer.global_step]
# from lightning
if not isinstance(optimizer, LightningOptimizer):
# wraps into LightingOptimizer only for running step
optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
optimizer.step(closure=optimizer_closure)
def sinkhorn(self, Q, nmb_iters):
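        # single-process Sinkhorn-Knopp: alternately rescale the rows (prototypes) and
        # columns (samples) of Q until both marginals are (approximately) uniform, which
        # spreads the batch evenly over the prototypes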
with torch.no_grad():
sum_Q = torch.sum(Q)
Q /= sum_Q
K, B = Q.shape
if self.gpus > 0:
u = torch.zeros(K).cuda()
r = torch.ones(K).cuda() / K
c = torch.ones(B).cuda() / B
else:
u = torch.zeros(K)
r = torch.ones(K) / K
c = torch.ones(B) / B
for _ in range(nmb_iters):
u = torch.sum(Q, dim=1)
Q *= (r / u).unsqueeze(1)
Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
def distributed_sinkhorn(self, Q, nmb_iters):
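        # distributed variant: the marginals are accumulated with all_reduce so the
        # equipartition constraint holds over the global batch; note that the column
        # marginal divides by self.gpus, i.e. a single node (world_size == gpus) is assumed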
with torch.no_grad():
sum_Q = torch.sum(Q)
dist.all_reduce(sum_Q)
Q /= sum_Q
if self.gpus > 0:
u = torch.zeros(Q.shape[0]).cuda(non_blocking=True)
r = torch.ones(Q.shape[0]).cuda(non_blocking=True) / Q.shape[0]
c = torch.ones(Q.shape[1]).cuda(
non_blocking=True) / (self.gpus * Q.shape[1])
else:
u = torch.zeros(Q.shape[0])
r = torch.ones(Q.shape[0]) / Q.shape[0]
c = torch.ones(Q.shape[1]) / (self.gpus * Q.shape[1])
curr_sum = torch.sum(Q, dim=1)
dist.all_reduce(curr_sum)
for it in range(nmb_iters):
u = curr_sum
Q *= (r / u).unsqueeze(1)
Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
curr_sum = torch.sum(Q, dim=1)
dist.all_reduce(curr_sum)
return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
def type(self):
return self.model.features[0][0].weight.type()
def get_representations(self, x):
return self.model.features(x)
def get_model(self):
return self.model.model
def get_device(self):
return self.model.features[0][0].weight.device
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
# model params
parser.add_argument("--arch", default="resnet50",
type=str, help="convnet architecture")
# specify flags to store false
parser.add_argument("--first_conv", action='store_false')
parser.add_argument("--maxpool1", action='store_false')
parser.add_argument("--hidden_mlp", default=2048, type=int,
help="hidden layer dimension in projection head")
parser.add_argument("--feat_dim", default=128,
type=int, help="feature dimension")
parser.add_argument("--online_ft", action='store_true')
parser.add_argument("--fp32", action='store_true')
# transform params
parser.add_argument("--gaussian_blur",
action="store_true", help="add gaussian blur")
parser.add_argument("--jitter_strength", type=float,
default=1.0, help="jitter strength")
parser.add_argument("--dataset", type=str,
default="stl10", help="stl10, cifar10")
parser.add_argument("--data_dir", type=str,
default=".", help="path to download data")
parser.add_argument("--queue_path", type=str,
default="queue", help="path for queue")
parser.add_argument("--nmb_crops", type=int, default=[2, 4], nargs="+",
help="list of number of crops (example: [2, 6])")
parser.add_argument("--size_crops", type=int, default=[96, 36], nargs="+",
help="crops resolutions (example: [224, 96])")
parser.add_argument("--min_scale_crops", type=float, default=[0.33, 0.10], nargs="+",
help="argument in RandomResizedCrop (example: [0.14, 0.05])")
parser.add_argument("--max_scale_crops", type=float, default=[1, 0.33], nargs="+",
help="argument in RandomResizedCrop (example: [1., 0.14])")
# training params
parser.add_argument("--fast_dev_run", action='store_true')
parser.add_argument("--nodes", default=1, type=int,
help="number of nodes for training")
parser.add_argument("--gpus", default=1, type=int,
help="number of gpus to train on")
parser.add_argument("--num_workers", default=8,
type=int, help="num of workers per GPU")
parser.add_argument("--optimizer", default="adam",
type=str, help="choose between adam/sgd")
parser.add_argument("--lars_wrapper", action='store_true',
help="apple lars wrapper over optimizer used")
parser.add_argument('--exclude_bn_bias', action='store_true',
help="exclude bn/bias from weight decay")
parser.add_argument("--max_epochs", default=100,
type=int, help="number of total epochs to run")
parser.add_argument("--max_steps", default=-1,
type=int, help="max steps")
parser.add_argument("--warmup_epochs", default=10,
type=int, help="number of warmup epochs")
parser.add_argument("--batch_size", default=128,
type=int, help="batch size per gpu")
parser.add_argument("--weight_decay", default=1e-6,
type=float, help="weight decay")
parser.add_argument("--learning_rate", default=1e-3,
type=float, help="base learning rate")
parser.add_argument("--start_lr", default=0, type=float,
help="initial warmup learning rate")
parser.add_argument("--final_lr", type=float,
default=1e-6, help="final learning rate")
# swav params
parser.add_argument("--crops_for_assign", type=int, nargs="+", default=[0, 1],
help="list of crops id used for computing assignments")
parser.add_argument("--temperature", default=0.1, type=float,
help="temperature parameter in training loss")
parser.add_argument("--epsilon", default=0.05, type=float,
help="regularization parameter for Sinkhorn-Knopp algorithm")
parser.add_argument("--sinkhorn_iterations", default=3, type=int,
help="number of iterations in Sinkhorn-Knopp algorithm")
parser.add_argument("--nmb_prototypes", default=512,
type=int, help="number of prototypes")
parser.add_argument("--queue_length", type=int, default=0,
help="length of the queue (0 for no queue); must be divisible by total batch size")
parser.add_argument("--epoch_queue_starts", type=int, default=15,
help="from this epoch, we start using a queue")
parser.add_argument("--freeze_prototypes_epochs", default=1, type=int,
help="freeze the prototypes during this many epochs from the start")
return parser
def mean(res, key1, key2=None):
if key2 is not None:
return torch.stack([x[key1][key2] for x in res]).mean()
return torch.stack([x[key1] for x in res if type(x) == dict and key1 in x.keys()]).mean()
def parse_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('-t', '--trafos', nargs='+', help='add transformation to data augmentation pipeline',
default=["GaussianNoise", "ChannelResize", "RandomResizedCrop"])
# GaussianNoise
parser.add_argument(
'--gaussian_scale', help='std param for gaussian noise transformation', default=0.005, type=float)
# RandomResizedCrop
    parser.add_argument('--rr_crop_ratio_range', nargs='+',
                        help='ratio range for random resized crop transformation', default=[0.5, 1.0], type=float)
parser.add_argument(
'--output_size', help='output size for random resized crop transformation', default=250, type=int)
# DynamicTimeWarp
parser.add_argument(
'--warps', help='number of warps for dynamic time warp transformation', default=3, type=int)
parser.add_argument(
'--radius', help='radius of warps of dynamic time warp transformation', default=10, type=int)
# TimeWarp
parser.add_argument(
'--epsilon', help='epsilon param for time warp', default=10, type=float)
# ChannelResize
parser.add_argument('--magnitude_range', nargs='+',
help='range for scale param for ChannelResize transformation', default=[0.5, 2], type=float)
# Downsample
parser.add_argument(
'--downsample_ratio', help='downsample ratio for Downsample transformation', default=0.2, type=float)
# TimeOut
parser.add_argument('--to_crop_ratio_range', nargs='+',
help='ratio range for timeout transformation', default=[0.2, 0.4], type=float)
# resume training
parser.add_argument('--resume', action='store_true')
    parser.add_argument(
        '--gpus', help='number of gpus to use; use cpu if gpus=0', type=int, default=1)
parser.add_argument(
'--num_nodes', default=1, help='number of cluster nodes', type=int)
parser.add_argument(
'--distributed_backend', help='sets backend type')
parser.add_argument('--batch_size', type=int)
parser.add_argument('--epochs', type=int)
parser.add_argument('--debug', action='store_true')
parser.add_argument('--warm_up', default=1, type=int)
parser.add_argument('--precision', type=int)
parser.add_argument('--datasets', dest="target_folders",
nargs='+', help='used datasets for pretraining')
parser.add_argument('--log_dir', default="./experiment_logs")
parser.add_argument(
'--percentage', help='determines how much of the dataset shall be used during the pretraining', type=float, default=1.0)
parser.add_argument('--lr', type=float, help="learning rate")
parser.add_argument('--out_dim', type=int, help="output dimension of model")
parser.add_argument('--filter_cinc', default=False, action="store_true", help="only valid if cinc is selected: filter out the ptb data")
parser.add_argument('--base_model')
parser.add_argument('--widen',type=int, help="use wide xresnet1d50")
    parser.add_argument('--run_callbacks', default=False, action="store_true", help="run callbacks which assess linear evaluation and finetuning metrics during pretraining")
parser.add_argument('--checkpoint_path', default="")
return parser
def init_logger(config):
level = logging.INFO
if config['debug']:
level = logging.DEBUG
# remove all handlers to change basic configuration
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
if not os.path.isdir(config['log_dir']):
os.mkdir(config['log_dir'])
logging.basicConfig(filename=os.path.join(config['log_dir'], 'info.log'), level=level,
format='%(asctime)s %(name)s:%(lineno)s %(levelname)s: %(message)s ')
return logging.getLogger(__name__)
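# Merges the YAML config with the CLI arguments, instantiates the dataset wrapper (used
# here to resolve and log the transformation pipeline), and prepares the TensorBoard
# logger together with the run-specific log directory.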
def pretrain_routine(args):
t_params = {"gaussian_scale": args.gaussian_scale, "rr_crop_ratio_range": args.rr_crop_ratio_range, "output_size": args.output_size, "warps": args.warps, "radius": args.radius,
"epsilon": args.epsilon, "magnitude_range": args.magnitude_range, "downsample_ratio": args.downsample_ratio, "to_crop_ratio_range": args.to_crop_ratio_range,
"bw_cmax":0.1, "em_cmax":0.5, "pl_cmax":0.2, "bs_cmax":1}
transformations = args.trafos
checkpoint_config = os.path.join("checkpoints", "bolts_config.yaml")
config_file = checkpoint_config if args.resume and os.path.isfile(
checkpoint_config) else "bolts_config.yaml"
config = yaml.load(open(config_file, "r"), Loader=yaml.FullLoader)
args_dict = vars(args)
    for key in set(config.keys()).union(args_dict.keys()):
        # command-line arguments override the YAML config unless they were left unset (None)
        if key in args_dict and (args_dict[key] is not None or key not in config):
            config[key] = args_dict[key]
if args.target_folders is not None:
config["dataset"]["target_folders"] = args.target_folders
config["dataset"]["percentage"] = args.percentage if args.percentage is not None else config["dataset"]["percentage"]
config["dataset"]["filter_cinc"] = args.filter_cinc if args.filter_cinc is not None else config["dataset"]["filter_cinc"]
config["model"]["base_model"] = args.base_model if args.base_model is not None else config["model"]["base_model"]
config["model"]["widen"] = args.widen if args.widen is not None else config["model"]["widen"]
config["dataset"]["swav"] = True
config["dataset"]["nmb_crops"] = 7
config["eval_dataset"]["swav"] = True
config["eval_dataset"]["nmb_crops"] = 7
if args.out_dim is not None:
config["model"]["out_dim"] = args.out_dim
init_logger(config)
dataset = SimCLRDataSetWrapper(
config['batch_size'], **config['dataset'], transformations=transformations, t_params=t_params)
for i, t in enumerate(dataset.transformations):
logger.info(str(i) + ". Transformation: " +
str(t) + ": " + str(t.get_params()))
date = time.asctime()
label_to_num_classes = {"label_all": 71, "label_diag": 44, "label_form": 19,
"label_rhythm": 12, "label_diag_subclass": 23, "label_diag_superclass": 5}
    ptb_num_classes = label_to_num_classes[config["eval_dataset"]["ptb_xl_label"]]
abr = {"Transpose": "Tr", "TimeOut": "TO", "DynamicTimeWarp": "DTW", "RandomResizedCrop": "RRC", "ChannelResize": "ChR", "GaussianNoise": "GN",
"TimeWarp": "TW", "ToTensor": "TT", "GaussianBlur": "GB", "BaselineWander": "BlW", "PowerlineNoise": "PlN", "EMNoise": "EM", "BaselineShift": "BlS"}
trs = re.sub(r"[,'\]\[]", "", str([abr[str(tr)] if abr[str(tr)] not in [
"TT", "Tr"] else '' for tr in dataset.transformations]))
name = str(date) + "_" + method + "_" + str(
time.time_ns())[-3:] + "_" + trs[1:]
tb_logger = TensorBoardLogger(args.log_dir, name=name, version='')
config["log_dir"] = os.path.join(args.log_dir, name)
print(config)
return config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger
def aftertrain_routine(config, args, trainer, pl_model, datamodule, callbacks):
scores = {}
for ca in callbacks:
if isinstance(ca, SSLOnlineEvaluator):
scores[str(ca)] = {"macro": ca.best_macro}
results = {"config": config, "trafos": args.trafos, "scores": scores}
with open(os.path.join(config["log_dir"], "results.pkl"), 'wb') as handle:
pickle.dump(results, handle)
trainer.save_checkpoint(os.path.join(config["log_dir"], "checkpoints", "model.ckpt"))
with open(os.path.join(config["log_dir"], "config.txt"), "w") as text_file:
print(config, file=text_file)
def cli_main():
from pytorch_lightning import Trainer
from online_evaluator import SSLOnlineEvaluator
from ecg_datamodule import ECGDataModule
from clinical_ts.create_logger import create_logger
from os.path import exists
parser = ArgumentParser()
parser = parse_args(parser)
logger.info("parse arguments")
args = parser.parse_args()
config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger = pretrain_routine(args)
# data
ecg_datamodule = ECGDataModule(config, transformations, t_params)
callbacks = []
if args.run_callbacks:
# callback for online linear evaluation/fine-tuning
linear_evaluator = SSLOnlineEvaluator(drop_p=0,
z_dim=512, num_classes=ptb_num_classes, hidden_dim=None, lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"], mode="linear_evaluation", verbose=False)
fine_tuner = SSLOnlineEvaluator(drop_p=0,
z_dim=512, num_classes=ptb_num_classes, hidden_dim=None, lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"], mode="fine_tuning", verbose=False)
callbacks.append(linear_evaluator)
callbacks.append(fine_tuner)
# configure trainer
trainer = Trainer(logger=tb_logger, max_epochs=config["epochs"], gpus=args.gpus,
distributed_backend=args.distributed_backend, auto_lr_find=False, num_nodes=args.num_nodes, precision=config["precision"], callbacks=callbacks)
# pytorch lightning module
model = ResNetSimCLR(**config["model"])
pl_model = CustomSwAV(model, config["gpus"], ecg_datamodule.num_samples, config["batch_size"], config=config,
transformations=ecg_datamodule.transformations, nmb_crops=config["dataset"]["nmb_crops"])
# load checkpoint
if args.checkpoint_path != "":
if exists(args.checkpoint_path):
logger.info("Retrieve checkpoint from " + args.checkpoint_path)
            # load_from_checkpoint returns a new module instance, so keep the result
            pl_model = pl_model.load_from_checkpoint(args.checkpoint_path)
else:
raise("checkpoint does not exist")
# start training
trainer.fit(pl_model, ecg_datamodule)
aftertrain_routine(config, args, trainer, pl_model, ecg_datamodule, callbacks)
if __name__ == "__main__":
cli_main()