--- /dev/null
+++ b/algorithms/barlowtwins.py
@@ -0,0 +1,134 @@
+"""
+Barlow Twins model and LARS optimizer, adapted from the official
+Facebook AI Research (FAIR) implementation, released under the MIT License.
+"""
+import math
+
+import torch
+from torch import nn, optim
+
+from algorithms.arch.resnet import loadResnetBackbone
+import utilities.runUtils as rutl
+
+## Code from the official Barlow Twins implementation, with the
+## distributed-training blocks removed.
+##==================== Model ===============================================
+
+class BarlowTwins(nn.Module):
+    def __init__(self, featx_arch, projector_sizes,
+                 batch_size, lmbd=0.0051, pretrained=None):
+        super().__init__()
+        rutl.START_SEED()
+
+        self.batch_size = batch_size
+        self.lmbd = lmbd
+
+        self.backbone, self.outfeatx_size = loadResnetBackbone(
+            arch=featx_arch, torch_pretrain=pretrained)
+
+        # projector MLP: backbone output size followed by the projector dims
+        sizes = [self.outfeatx_size] + list(projector_sizes)
+        layers = []
+        for i in range(len(sizes) - 2):
+            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
+            layers.append(nn.BatchNorm1d(sizes[i + 1]))
+            layers.append(nn.ReLU(inplace=True))
+        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
+        self.projector = nn.Sequential(*layers)
+
+        # normalization layer for the representations z1 and z2
+        self.bn = nn.BatchNorm1d(sizes[-1], affine=False)
+
+    def forward(self, y1, y2):
+        z1 = self.projector(self.backbone(y1))
+        z2 = self.projector(self.backbone(y2))
+
+        # empirical cross-correlation matrix
+        c = self.bn(z1).T @ self.bn(z2)
+
+        # normalise by batch size (single-GPU version; the official code
+        # additionally all-reduces c across GPUs)
+        c.div_(self.batch_size)
+        # torch.distributed.all_reduce(c)
+
+        # push diagonal entries towards 1 (invariance term) and off-diagonal
+        # entries towards 0 (redundancy-reduction term)
+        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
+        off_diag = self.off_diagonal(c).pow_(2).sum()
+        loss = on_diag + self.lmbd * off_diag
+        return loss
+
+    def off_diagonal(self, x):
+        # return a flattened view of the off-diagonal elements of a square
+        # matrix: dropping the last element and viewing as (n-1, n+1) shifts
+        # each row by one, so the diagonal lands in column 0 and is sliced away
+        n, m = x.shape
+        assert n == m
+        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
+
+
+##==================== OPTIMISER ===============================================
+
+class LARS(optim.Optimizer):
+    """LARS optimizer from the official Barlow Twins implementation.
+
+    Expects an iterable of parameter tensors and splits them into two
+    param groups: weights (ndim > 1) first, then biases/norm params.
+    """
+    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
+                 weight_decay_filter=False, lars_adaptation_filter=False):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
+                        eta=eta, weight_decay_filter=weight_decay_filter,
+                        lars_adaptation_filter=lars_adaptation_filter)
+        param_weights = []
+        param_biases = []
+        for param in params:
+            if param.ndim == 1:
+                param_biases.append(param)
+            else:
+                param_weights.append(param)
+        parameters = [{'params': param_weights}, {'params': param_biases}]
+
+        super().__init__(parameters, defaults)
+
+    def exclude_bias_and_norm(self, p):
+        return p.ndim == 1
+
+    @torch.no_grad()
+    def step(self):
+        for g in self.param_groups:
+            for p in g['params']:
+                dp = p.grad
+
+                if dp is None:
+                    continue
+
+                if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p):
+                    dp = dp.add(p, alpha=g['weight_decay'])
+
+                if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p):
+                    # layer-wise trust ratio: eta * ||p|| / ||dp||, guarded
+                    # against zero norms
+                    param_norm = torch.norm(p)
+                    update_norm = torch.norm(dp)
+                    one = torch.ones_like(param_norm)
+                    q = torch.where(param_norm > 0.,
+                                    torch.where(update_norm > 0,
+                                                (g['eta'] * param_norm / update_norm),
+                                                one), one)
+                    dp = dp.mul(q)
+
+                param_state = self.state[p]
+                if 'mu' not in param_state:
+                    param_state['mu'] = torch.zeros_like(p)
+                mu = param_state['mu']
+                mu.mul_(g['momentum']).add_(dp)
+
+                p.add_(mu, alpha=-g['lr'])
+
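+
+## Usage sketch (illustrative; the hyper-parameter values are assumptions,
+## with the filter flags and weight decay following the official Barlow Twins
+## defaults):
+##
+##   optimizer = LARS(model.parameters(), lr=0,  # lr is overwritten per step
+##                    weight_decay=1e-6,
+##                    weight_decay_filter=True,
+##                    lars_adaptation_filter=True)
+##
+## LARS always builds exactly two param groups -- weights first, then biases --
+## and adjust_learning_rate() below relies on that ordering when it writes
+## param_groups[0] and param_groups[1].
+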
+def adjust_learning_rate(args, optimizer, loader, step):
+    """Linear warmup for 10 epochs, then cosine decay to base_lr / 1000,
+    with separate lr scales for the weight and bias param groups."""
+    max_steps = args.epochs * len(loader)
+    warmup_steps = 10 * len(loader)
+    base_lr = args.batch_size / 256
+    if step < warmup_steps:
+        lr = base_lr * step / warmup_steps
+    else:
+        step -= warmup_steps
+        max_steps -= warmup_steps
+        q = 0.5 * (1 + math.cos(math.pi * step / max_steps))
+        end_lr = base_lr * 0.001
+        lr = base_lr * q + end_lr * (1 - q)
+    optimizer.param_groups[0]['lr'] = lr * args.learning_rate_weights
+    optimizer.param_groups[1]['lr'] = lr * args.learning_rate_biases
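+
+
+if __name__ == "__main__":
+    # Minimal single-GPU/CPU smoke test (illustrative only): the arch name,
+    # projector sizes, batch size and epoch count are placeholder assumptions;
+    # the lr scales follow the official Barlow Twins defaults.
+    from argparse import Namespace
+
+    args = Namespace(epochs=100, batch_size=8,
+                     learning_rate_weights=0.2, learning_rate_biases=0.0048)
+    model = BarlowTwins(featx_arch='resnet18',
+                        projector_sizes=[512, 512, 512],
+                        batch_size=args.batch_size)
+    optimizer = LARS(model.parameters(), lr=0, weight_decay=1e-6,
+                     weight_decay_filter=True, lars_adaptation_filter=True)
+
+    loader = [None] * 10   # stands in for a DataLoader; only len() is used
+    adjust_learning_rate(args, optimizer, loader, step=50)
+
+    y1 = torch.randn(args.batch_size, 3, 224, 224)   # two augmented views
+    y2 = torch.randn(args.batch_size, 3, 224, 224)
+    loss = model(y1, y2)
+    loss.backward()
+    optimizer.step()
+    print(f"loss: {loss.item():.4f}")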