"""
Facebook (FAIR), released under MIT License
"""
import torch
import torchvision
from torch import nn, optim
import math
from algorithms.arch.resnet import loadResnetBackbone
import utilities.runUtils as rutl
## Code from the official Barlow Twins implementation, with the distributed-training blocks removed
##==================== Model ===============================================
class BarlowTwins(nn.Module):
    def __init__(self, featx_arch, projector_sizes,
                 batch_size, lmbd=0.0051, pretrained=None):
super().__init__()
rutl.START_SEED()
self.batch_size = batch_size
self.lmbd = lmbd
self.backbone, self.outfeatx_size = loadResnetBackbone(arch=featx_arch,
torch_pretrain=pretrained)
# backbone_out_shape + projector_dims
sizes = [self.outfeatx_size] + list(projector_sizes)
        # projector MLP: Linear + BN + ReLU blocks, with a final plain Linear layer
        layers = []
for i in range(len(sizes) - 2):
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
layers.append(nn.BatchNorm1d(sizes[i + 1]))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
self.projector = nn.Sequential(*layers)
# normalization layer for the representations z1 and z2
self.bn = nn.BatchNorm1d(sizes[-1], affine=False)
    def forward(self, y1, y2):
        # y1, y2: two augmented views of the same image batch
        z1 = self.projector(self.backbone(y1))
        z2 = self.projector(self.backbone(y2))
# empirical cross-correlation matrix
c = self.bn(z1).T @ self.bn(z2)
        # normalise the cross-correlation by the batch size; the all_reduce used
        # for multi-GPU training in the original implementation is removed here
        c.div_(self.batch_size)
        # torch.distributed.all_reduce(c)
        # invariance term: push the diagonal of c towards 1
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        # redundancy-reduction term: push the off-diagonal of c towards 0
        off_diag = self.off_diagonal(c).pow_(2).sum()
        loss = on_diag + self.lmbd * off_diag
return loss
    def off_diagonal(self, x):
        # return a flattened view of the off-diagonal elements of a square matrix;
        # dropping the last element and reshaping to (n - 1, n + 1) puts every
        # diagonal entry in column 0, so slicing that column off leaves only the
        # off-diagonal entries
        n, m = x.shape
        assert n == m
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
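
# Minimal usage sketch (illustrative only, not called anywhere in this module).
# It assumes 'resnet18' is an architecture string accepted by loadResnetBackbone;
# the projector widths below are the 8192-8192-8192 used in the Barlow Twins paper.
def _example_barlowtwins_loss():
    model = BarlowTwins(featx_arch='resnet18', projector_sizes=(8192, 8192, 8192),
                        batch_size=4, lmbd=0.0051, pretrained=None)
    view1 = torch.randn(4, 3, 224, 224)  # first augmented view of a batch
    view2 = torch.randn(4, 3, 224, 224)  # second augmented view of the same batch
    loss = model(view1, view2)
    return loss
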
##==================== OPTIMISER ===============================================
class LARS(optim.Optimizer):
""" From Barlows twin example
"""
def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
weight_decay_filter=False, lars_adaptation_filter=False):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
eta=eta, weight_decay_filter=weight_decay_filter,
lars_adaptation_filter=lars_adaptation_filter)
        # split parameters into weights and 1-D params (biases, BatchNorm);
        # the two groups are given separately scaled learning rates by
        # adjust_learning_rate() below
        param_weights = []
        param_biases = []
        for param in params:
            if param.ndim == 1:
                param_biases.append(param)
            else:
                param_weights.append(param)
        parameters = [{'params': param_weights}, {'params': param_biases}]
super().__init__(parameters, defaults)
def exclude_bias_and_norm(self, p):
return p.ndim == 1
@torch.no_grad()
def step(self):
for g in self.param_groups:
for p in g['params']:
dp = p.grad
if dp is None:
continue
if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p):
dp = dp.add(p, alpha=g['weight_decay'])
                if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p):
                    # LARS trust ratio: scale the update by eta * ||param|| / ||update||,
                    # falling back to 1 when either norm is zero
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['eta'] * param_norm / update_norm), one), one)
                    dp = dp.mul(q)
param_state = self.state[p]
if 'mu' not in param_state:
param_state['mu'] = torch.zeros_like(p)
mu = param_state['mu']
mu.mul_(g['momentum']).add_(dp)
p.add_(mu, alpha=-g['lr'])
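
# Minimal construction sketch (illustrative only, not called in this module).
# The hyperparameter values and the toy model below are examples, not values read
# from this repo; in training, the learning rate set here is overwritten per step
# by adjust_learning_rate() defined further down.
def _example_lars_setup():
    toy_model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
    optimizer = LARS(toy_model.parameters(), lr=0.2, weight_decay=1e-6,
                     weight_decay_filter=True, lars_adaptation_filter=True)
    loss = toy_model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()  # applies momentum plus LARS trust-ratio scaling to each param
    return optimizer
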
def adjust_learning_rate(args, optimizer, loader, step):
    """ Linear warmup for 10 epochs, then cosine decay of the learning rate to
    0.1% of its base value; weight and bias groups are scaled separately. """
    max_steps = args.epochs * len(loader)
    warmup_steps = 10 * len(loader)
    base_lr = args.batch_size / 256  # linear LR scaling with batch size
if step < warmup_steps:
lr = base_lr * step / warmup_steps
else:
step -= warmup_steps
max_steps -= warmup_steps
q = 0.5 * (1 + math.cos(math.pi * step / max_steps))
end_lr = base_lr * 0.001
lr = base_lr * q + end_lr * (1 - q)
    optimizer.param_groups[0]['lr'] = lr * args.learning_rate_weights  # weight params
    optimizer.param_groups[1]['lr'] = lr * args.learning_rate_biases   # bias / norm params
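
# End-to-end training-loop sketch (illustrative only, not called in this module).
# `args` is assumed to carry the fields read above (epochs, batch_size,
# learning_rate_weights, learning_rate_biases), and `loader` is assumed to yield
# two augmented views of each batch, as produced by the SSL data pipeline.
def _example_training_loop(args, model, loader):
    optimizer = LARS(model.parameters(), lr=0, weight_decay=1e-6,
                     weight_decay_filter=True, lars_adaptation_filter=True)
    for epoch in range(args.epochs):
        for batch_idx, (view1, view2) in enumerate(loader):
            step = epoch * len(loader) + batch_idx
            adjust_learning_rate(args, optimizer, loader, step)  # warmup + cosine decay
            optimizer.zero_grad()
            loss = model(view1, view2)
            loss.backward()
            optimizer.step()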