Diff of /algorithms/barlowtwins.py [000000] .. [a18f15]


--- a
+++ b/algorithms/barlowtwins.py
@@ -0,0 +1,134 @@
+"""
+Facebook (FAIR), released under MIT License
+
+"""
+import torch
+import torchvision
+from torch import nn, optim
+import math
+
+from algorithms.arch.resnet import loadResnetBackbone
+import utilities.runUtils as rutl
+
+## Code from the official Barlow Twins implementation, with the distributed-training blocks removed
+##==================== Model ===============================================
+
+class BarlowTwins(nn.Module):
+    def __init__(self, featx_arch, projector_sizes,
+                 batch_size, lmbd=0.0051, pretrained=None):
+        super().__init__()
+        rutl.START_SEED()
+
+        self.batch_size = batch_size
+        self.lmbd = lmbd
+
+        self.backbone, self.outfeatx_size = loadResnetBackbone(arch=featx_arch,
+                                            torch_pretrain=pretrained)
+
+        # backbone_out_shape + projector_dims
+        sizes = [self.outfeatx_size] + list(projector_sizes)
+        layers = []
+        for i in range(len(sizes) - 2):
+            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
+            layers.append(nn.BatchNorm1d(sizes[i + 1]))
+            layers.append(nn.ReLU(inplace=True))
+        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
+        self.projector = nn.Sequential(*layers)
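+        # e.g. with a ResNet-50 backbone (outfeatx_size = 2048) and
+        # projector_sizes = (8192, 8192, 8192) as in the paper, this builds
+        # Linear(2048, 8192) -> BN -> ReLU -> Linear(8192, 8192) -> BN -> ReLU
+        # -> Linear(8192, 8192), with no bias on any Linear layer.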
+
+        # normalization layer for the representations z1 and z2
+        self.bn = nn.BatchNorm1d(sizes[-1], affine=False)
+
+    def forward(self, y1, y2):
+        z1 = self.projector(self.backbone(y1))
+        z2 = self.projector(self.backbone(y2))
+
+        # empirical cross-correlation matrix
+        c = self.bn(z1).T @ self.bn(z2)
+
+        # normalise by the batch size (single-process version; the original
+        # all_reduce across GPUs is left disabled below)
+        c.div_(self.batch_size)
+        # torch.distributed.all_reduce(c)
+
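+        # Barlow Twins objective: push the diagonal of c towards 1
+        # (invariance across the two views) and the off-diagonal entries
+        # towards 0 (redundancy reduction), weighted by lmbd.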
+        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
+        off_diag = self.off_diagonal(c).pow_(2).sum()
+        loss = on_diag + self.lmbd * off_diag
+        return loss
+
+
+    def off_diagonal(self, x):
+        # return a flattened view of the off-diagonal elements of a square matrix
+        n, m = x.shape
+        assert n == m
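+        # e.g. for n = 3: flatten()[:-1] drops the last diagonal element,
+        # view(2, 4) places a diagonal element at the start of each row,
+        # and [:, 1:] removes it, leaving exactly the 6 off-diagonal entries.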
+        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
+
+
+##==================== OPTIMISER ===============================================
+
+class LARS(optim.Optimizer):
+    """ From Barlows twin example
+    """
+    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
+                 weight_decay_filter=False, lars_adaptation_filter=False):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
+                        eta=eta, weight_decay_filter=weight_decay_filter,
+                        lars_adaptation_filter=lars_adaptation_filter)
+        param_weights = []
+        param_biases = []
+        for param in params:
+            if param.ndim == 1:
+                param_biases.append(param)
+            else:
+                param_weights.append(param)
+        parameters = [{'params': param_weights}, {'params': param_biases}]
+
+        super().__init__(parameters, defaults)
+
+    def exclude_bias_and_norm(self, p):
+        return p.ndim == 1
+
+    @torch.no_grad()
+    def step(self):
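+        # For each parameter: optionally add weight decay, rescale the update
+        # by the trust ratio q = eta * ||p|| / ||dp||, accumulate it into the
+        # momentum buffer mu, then apply p <- p - lr * mu. With the filter
+        # flags enabled, 1-D params (biases, BN) skip decay and rescaling.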
+        for g in self.param_groups:
+            for p in g['params']:
+                dp = p.grad
+
+                if dp is None:
+                    continue
+
+                if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p):
+                    dp = dp.add(p, alpha=g['weight_decay'])
+
+                if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p):
+                    param_norm = torch.norm(p)
+                    update_norm = torch.norm(dp)
+                    one = torch.ones_like(param_norm)
+                    q = torch.where(param_norm > 0.,
+                                    torch.where(update_norm > 0,
+                                                (g['eta'] * param_norm / update_norm), one), one)
+                    dp = dp.mul(q)
+
+                param_state = self.state[p]
+                if 'mu' not in param_state:
+                    param_state['mu'] = torch.zeros_like(p)
+                mu = param_state['mu']
+                mu.mul_(g['momentum']).add_(dp)
+
+                p.add_(mu, alpha=-g['lr'])
+
+
+
+def adjust_learning_rate(args, optimizer, loader, step):
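+    # Linear warmup over the first 10 epochs, then cosine decay from
+    # base_lr = batch_size / 256 down to 0.001 * base_lr; the schedule is
+    # scaled separately for the weight group (param_groups[0]) and the
+    # bias group (param_groups[1]).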
+    max_steps = args.epochs * len(loader)
+    warmup_steps = 10 * len(loader)
+    base_lr = args.batch_size / 256
+    if step < warmup_steps:
+        lr = base_lr * step / warmup_steps
+    else:
+        step -= warmup_steps
+        max_steps -= warmup_steps
+        q = 0.5 * (1 + math.cos(math.pi * step / max_steps))
+        end_lr = base_lr * 0.001
+        lr = base_lr * q + end_lr * (1 - q)
+    optimizer.param_groups[0]['lr'] = lr * args.learning_rate_weights
+    optimizer.param_groups[1]['lr'] = lr * args.learning_rate_biases
+
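
Taken together, the pieces above are driven by a training loop that builds the model, wraps its parameters in LARS, and calls adjust_learning_rate every step. The sketch below is one minimal way to wire them up; it is an illustration, not part of this file. The backbone name, projector sizes, batch size, weight decay, and learning-rate scales are assumed values, the dummy loader stands in for a real dataloader of augmented view pairs, and loadResnetBackbone is assumed to accept 'resnet50' as elsewhere in this repository.

# Illustrative training-loop sketch (assumptions noted above; not part of barlowtwins.py).
import argparse
import torch

from algorithms.barlowtwins import BarlowTwins, LARS, adjust_learning_rate

args = argparse.Namespace(
    epochs=100, batch_size=64,          # assumed values
    learning_rate_weights=0.2,          # lr scale for weight tensors (param_groups[0])
    learning_rate_biases=0.0048)        # lr scale for biases / BN params (param_groups[1])

model = BarlowTwins(featx_arch='resnet50',               # assumed arch name
                    projector_sizes=(8192, 8192, 8192),  # projector dims as in the paper
                    batch_size=args.batch_size)

# LARS splits the parameters into weights and biases internally; with the
# filter flags enabled, 1-D parameters skip weight decay and LARS adaptation.
optimizer = LARS(model.parameters(), lr=0, weight_decay=1e-6,
                 weight_decay_filter=True, lars_adaptation_filter=True)

# Dummy loader of random view pairs, standing in for a real augmented dataloader.
loader = [(torch.randn(args.batch_size, 3, 224, 224),
           torch.randn(args.batch_size, 3, 224, 224)) for _ in range(4)]

model.train()
for epoch in range(args.epochs):
    for i, (y1, y2) in enumerate(loader):
        step = epoch * len(loader) + i
        adjust_learning_rate(args, optimizer, loader, step)
        optimizer.zero_grad()
        loss = model(y1, y2)
        loss.backward()
        optimizer.step()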