Diff of /algorithms/barlowtwins.py [000000] .. [a18f15]

"""
Facebook (FAIR), released under MIT License
"""
import torch
import torchvision
from torch import nn, optim
import math

from algorithms.arch.resnet import loadResnetBackbone
import utilities.runUtils as rutl

## Code from the official Barlow Twins implementation, with the distributed-training blocks removed
##==================== Model ===============================================

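# BarlowTwins.forward(y1, y2) below returns the Barlow Twins loss for two
# augmented views of the same batch,
#
#     L = sum_i (1 - C_ii)^2  +  lmbd * sum_{i != j} C_ij^2
#
# where C is the (feature x feature) cross-correlation matrix of the
# batch-normalised projector outputs z1 and z2.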
class BarlowTwins(nn.Module):
    def __init__(self, featx_arch, projector_sizes,
                 batch_size, lmbd=0.0051, pretrained=None):
        super().__init__()
        rutl.START_SEED()

        self.batch_size = batch_size
        self.lmbd = lmbd

        self.backbone, self.outfeatx_size = loadResnetBackbone(arch=featx_arch,
                                                torch_pretrain=pretrained)

        # backbone_out_shape + projector_dims
        sizes = [self.outfeatx_size] + list(projector_sizes)
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)

        # normalization layer for the representations z1 and z2
        self.bn = nn.BatchNorm1d(sizes[-1], affine=False)

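    # Shape illustration (assumes a ResNet-50 backbone, i.e. outfeatx_size == 2048,
    # and the paper's projector_sizes = (8192, 8192, 8192); both are example values):
    #   projector = Linear(2048, 8192) -> BN -> ReLU
    #               -> Linear(8192, 8192) -> BN -> ReLU
    #               -> Linear(8192, 8192)
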
    def forward(self, y1, y2):
        z1 = self.projector(self.backbone(y1))
        z2 = self.projector(self.backbone(y2))

        # empirical cross-correlation matrix
        c = self.bn(z1).T @ self.bn(z2)

        # normalise by the batch size (the all-gpu reduction of the original code is removed here)
        c.div_(self.batch_size)
        # torch.distributed.all_reduce(c)

        # invariance term (diagonal pulled towards 1) + redundancy-reduction term
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = self.off_diagonal(c).pow_(2).sum()
        loss = on_diag + self.lmbd * off_diag
        return loss

    def off_diagonal(self, x):
        # return a flattened view of the off-diagonal elements of a square matrix
        n, m = x.shape
        assert n == m
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
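
    # Quick illustration of the reshape trick above (doctest-style sketch):
    # dropping the last element and viewing the flat tensor as (n-1, n+1) shifts
    # each row by one, so column 0 holds the diagonal and columns 1: the rest.
    #
    #   >>> x = torch.arange(9.).view(3, 3)
    #   >>> x.flatten()[:-1].view(2, 4)[:, 1:].flatten()
    #   tensor([1., 2., 3., 5., 6., 7.])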


##==================== OPTIMISER ===============================================

class LARS(optim.Optimizer):
    """ LARS optimizer, from the Barlow Twins reference implementation.
    """
    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=False, lars_adaptation_filter=False):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        eta=eta, weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        param_weights = []
        param_biases = []
        for param in params:
            if param.ndim == 1:
                param_biases.append(param)
            else:
                param_weights.append(param)
        parameters = [{'params': param_weights}, {'params': param_biases}]

        super().__init__(parameters, defaults)

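    # Note on the two parameter groups built above: group 0 holds weight tensors
    # (ndim > 1) and group 1 holds biases / norm parameters (ndim == 1);
    # adjust_learning_rate() further down relies on exactly this ordering when it
    # assigns per-group learning rates.
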
    def exclude_bias_and_norm(self, p):
        return p.ndim == 1

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                # L2 regularisation: add weight_decay * p to the gradient
                if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p):
                    dp = dp.add(p, alpha=g['weight_decay'])

                # LARS trust ratio q = eta * ||p|| / ||dp||, falling back to 1
                # when either norm is zero
                if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p):
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['eta'] * param_norm / update_norm), one), one)
                    dp = dp.mul(q)

                # heavy-ball momentum buffer, then SGD-style parameter update
                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)

                p.add_(mu, alpha=-g['lr'])
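
# Per-tensor update performed by LARS.step() above (weight decay and the trust
# ratio are skipped for bias/norm parameters when the corresponding filter is
# enabled):
#   g  <- grad + weight_decay * p
#   q  <- eta * ||p|| / ||g||     (1 if either norm is zero)
#   mu <- momentum * mu + q * g
#   p  <- p - lr * mu

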
def adjust_learning_rate(args, optimizer, loader, step):
    # linear warm-up over the first 10 epochs, then cosine decay of the learning
    # rate down to 0.1% of the base value; base_lr follows the linear batch-size
    # scaling rule (lr proportional to batch_size / 256)
    max_steps = args.epochs * len(loader)
    warmup_steps = 10 * len(loader)
    base_lr = args.batch_size / 256
    if step < warmup_steps:
        lr = base_lr * step / warmup_steps
    else:
        step -= warmup_steps
        max_steps -= warmup_steps
        q = 0.5 * (1 + math.cos(math.pi * step / max_steps))
        end_lr = base_lr * 0.001
        lr = base_lr * q + end_lr * (1 - q)
    # group 0: weights, group 1: biases (see LARS.__init__ above)
    optimizer.param_groups[0]['lr'] = lr * args.learning_rate_weights
    optimizer.param_groups[1]['lr'] = lr * args.learning_rate_biases
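

# Training-loop sketch (illustrative only; 'resnet50', the projector sizes and
# weight_decay value are assumed placeholders, `train_loader` is assumed to
# yield two augmented views per batch, and `args` carries epochs, batch_size,
# learning_rate_weights and learning_rate_biases):
#
#   model = BarlowTwins(featx_arch='resnet50',
#                       projector_sizes=(8192, 8192, 8192),
#                       batch_size=args.batch_size)
#   optimizer = LARS(model.parameters(), lr=0, weight_decay=1e-6,
#                    weight_decay_filter=True, lars_adaptation_filter=True)
#
#   for epoch in range(args.epochs):
#       for it, (y1, y2) in enumerate(train_loader):
#           step = epoch * len(train_loader) + it
#           adjust_learning_rate(args, optimizer, train_loader, step)
#           optimizer.zero_grad()
#           loss = model(y1, y2)
#           loss.backward()
#           optimizer.step()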