|
algorithms/barlowtwins.py
|
|
"""
Barlow Twins self-supervised learning model and LARS optimiser.
Adapted from the official implementation by Facebook AI Research (FAIR),
released under the MIT License.
"""
import torch
import torchvision
from torch import nn, optim
import math

from algorithms.arch.resnet import loadResnetBackbone
import utilities.runUtils as rutl

## Code from the official Barlow Twins implementation, with the distributed-training blocks removed
##==================== Model ===============================================
|
|
class BarlowTwins(nn.Module):
    def __init__(self, featx_arch, projector_sizes,
                 batch_size, lmbd=0.0051, pretrained=None):
        super().__init__()
        rutl.START_SEED()

        self.batch_size = batch_size
        self.lmbd = lmbd

        self.backbone, self.outfeatx_size = loadResnetBackbone(
            arch=featx_arch, torch_pretrain=pretrained)

        # backbone_out_shape + projector_dims
        sizes = [self.outfeatx_size] + list(projector_sizes)
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)

        # normalization layer for the representations z1 and z2
        self.bn = nn.BatchNorm1d(sizes[-1], affine=False)
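
    # The forward pass below computes the Barlow Twins loss: the on-diagonal
    # term pulls the cross-correlation between the two views' embeddings
    # toward the identity matrix (invariance), while the off-diagonal term
    # pushes the remaining entries toward zero (redundancy reduction), with
    # lmbd trading off the two terms.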
    def forward(self, y1, y2):
        z1 = self.projector(self.backbone(y1))
        z2 = self.projector(self.backbone(y2))

        # empirical cross-correlation matrix
        c = self.bn(z1).T @ self.bn(z2)

        # normalise by the batch size; the distributed original also summed
        # the matrix across GPUs here via all_reduce
        c.div_(self.batch_size)
        # torch.distributed.all_reduce(c)

        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = self.off_diagonal(c).pow_(2).sum()
        loss = on_diag + self.lmbd * off_diag
        return loss
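
    # off_diagonal works because, for an (n, n) matrix flattened to length n*n,
    # dropping the last element and viewing the rest as (n-1, n+1) puts every
    # diagonal element in column 0, so slicing that column away leaves exactly
    # the off-diagonal entries.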
    def off_diagonal(self, x):
        # return a flattened view of the off-diagonal elements of a square matrix
        n, m = x.shape
        assert n == m
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


##==================== OPTIMISER ===============================================
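
# LARS scales each parameter's update by a layer-wise "trust ratio"
# eta * ||param|| / ||update||, so layers with large weights but small
# gradients still move at a sensible rate; the scaled update is then combined
# with weight decay and momentum. The two filter flags, when enabled, exempt
# biases and norm parameters (ndim == 1) from weight decay and adaptation.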
class LARS(optim.Optimizer):
    """LARS optimiser, from the official Barlow Twins example."""

    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=False, lars_adaptation_filter=False):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        eta=eta, weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        # split parameters into weights (ndim > 1) and biases/norm params (ndim == 1)
        param_weights = []
        param_biases = []
        for param in params:
            if param.ndim == 1:
                param_biases.append(param)
            else:
                param_weights.append(param)
        parameters = [{'params': param_weights}, {'params': param_biases}]

        super().__init__(parameters, defaults)

    def exclude_bias_and_norm(self, p):
        return p.ndim == 1

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p):
                    dp = dp.add(p, alpha=g['weight_decay'])

                if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p):
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['eta'] * param_norm / update_norm),
                                                one),
                                    one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)

                p.add_(mu, alpha=-g['lr'])
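

# Learning-rate schedule used with LARS: linear warm-up over the first 10
# epochs to base_lr = batch_size / 256, then cosine decay towards 0.1% of
# base_lr; the weight and bias parameter groups get separate scale factors
# (args.learning_rate_weights and args.learning_rate_biases).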
def adjust_learning_rate(args, optimizer, loader, step):
    max_steps = args.epochs * len(loader)
    warmup_steps = 10 * len(loader)
    base_lr = args.batch_size / 256
    if step < warmup_steps:
        lr = base_lr * step / warmup_steps
    else:
        step -= warmup_steps
        max_steps -= warmup_steps
        q = 0.5 * (1 + math.cos(math.pi * step / max_steps))
        end_lr = base_lr * 0.001
        lr = base_lr * q + end_lr * (1 - q)
    optimizer.param_groups[0]['lr'] = lr * args.learning_rate_weights
    optimizer.param_groups[1]['lr'] = lr * args.learning_rate_biases
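

# Minimal usage sketch (illustrative only): the backbone name, projector sizes
# and hyper-parameter values below are assumptions for demonstration, not
# settings prescribed by this repository.
if __name__ == "__main__":
    batch_size = 32

    model = BarlowTwins(featx_arch='resnet18',               # assumed arch name
                        projector_sizes=[2048, 2048, 2048],  # assumed projector dims
                        batch_size=batch_size)
    # fixed lr for this sketch; in training, lr would start at 0 and be driven
    # per step by adjust_learning_rate(args, optimizer, loader, step), where
    # args carries epochs, batch_size, learning_rate_weights, learning_rate_biases
    optimizer = LARS(model.parameters(), lr=0.2, weight_decay=1e-6)

    # two augmented views of the same batch (random tensors stand in for images)
    y1 = torch.randn(batch_size, 3, 224, 224)
    y2 = torch.randn(batch_size, 3, 224, 224)

    loss = model(y1, y2)
    loss.backward()
    optimizer.step()
    print(float(loss))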
|