|
a |
|
b/algorithms/vicreg.py |
|
|
1 |
import os, sys |
|
|
2 |
import math |
|
|
3 |
import torch |
|
|
4 |
from torch import nn, optim |
|
|
5 |
from torch.nn import functional as torch_F |
|
|
6 |
|
|
|
7 |
sys.path.append(os.getcwd()) |
|
|
8 |
from algorithms.arch.resnet import loadResnetBackbone |
|
|
9 |
|
|
|
10 |
## Code adapted from the official VICReg implementation, with the distributed-training blocks removed |
|
|
11 |
##==================== Model =============================================== |
|
|
12 |
|
|
|
13 |
class VICReg(nn.Module):
    """Backbone + MLP projector trained with a three-term VICReg loss.

    ``forward`` embeds two batches with the same backbone/projector and
    returns ``sim_coeff * repr_loss + std_coeff * std_loss +
    cov_coeff * cov_loss`` where

    * ``repr_loss`` is the MSE between the two embeddings,
    * ``std_loss`` is a hinge on the per-feature standard deviation
      (pushing it towards at least 1),
    * ``cov_loss`` is the sum of squared off-diagonal covariance entries
      divided by the embedding width.

    Args:
        featx_arch: architecture name forwarded to ``loadResnetBackbone``.
        projector_sizes: widths of the projector layers; the last entry is
            the embedding width (``num_features``).
        batch_size: kept and stored for backward compatibility. The loss
            no longer relies on it; the covariance is normalised by the
            real number of rows in each input batch.
        sim_coeff: weight of the MSE term.
        std_coeff: weight of the standard-deviation term.
        cov_coeff: weight of the covariance term.
        featx_pretrain: forwarded to ``loadResnetBackbone`` as
            ``torch_pretrain``.
    """

    def __init__(self, featx_arch, projector_sizes,
                 batch_size, sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0,
                 featx_pretrain=None,
                 ):
        super().__init__()

        self.sim_coeff = sim_coeff
        self.std_coeff = std_coeff
        self.cov_coeff = cov_coeff
        self.batch_size = batch_size

        self.num_features = projector_sizes[-1]
        self.backbone, out_featx_size = loadResnetBackbone(
            arch=featx_arch, torch_pretrain=featx_pretrain)
        self.projector = self.load_ProjectorNet(out_featx_size, projector_sizes)

    def forward(self, x, y):
        """Return the scalar loss for two batches ``x`` and ``y``."""
        x = self.projector(self.backbone(x))
        y = self.projector(self.backbone(y))

        repr_loss = torch_F.mse_loss(x, y)

        # Centre every feature over the batch dimension.
        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        # The small constant keeps sqrt differentiable when a feature collapses.
        std_x = torch.sqrt(x.var(dim=0) + 0.0001)
        std_y = torch.sqrt(y.var(dim=0) + 0.0001)
        std_loss = torch.mean(torch_F.relu(1 - std_x)) / 2 + \
            torch.mean(torch_F.relu(1 - std_y)) / 2

        # Normalise by the number of rows actually present rather than the
        # configured batch size, so smaller or larger batches are not
        # mis-scaled by a constant factor.
        cov_x = (x.T @ x) / (x.shape[0] - 1)
        cov_y = (y.T @ y) / (y.shape[0] - 1)
        cov_loss = self.off_diagonal(cov_x).pow(2).sum().div(
            self.num_features
        ) + self.off_diagonal(cov_y).pow(2).sum().div(self.num_features)

        loss = (
            self.sim_coeff * repr_loss
            + self.std_coeff * std_loss
            + self.cov_coeff * cov_loss
        )
        return loss

    def load_ProjectorNet(self, outfeatx_size, projector_sizes):
        """Build the projector MLP.

        Layer widths are ``[outfeatx_size] + projector_sizes``. Every layer
        except the last is Linear (no bias) -> BatchNorm1d -> ReLU; the last
        is a plain bias-free Linear.
        """
        # backbone_out_shape + projector_dims
        sizes = [outfeatx_size] + list(projector_sizes)
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        projector = nn.Sequential(*layers)
        return projector

    def off_diagonal(self, x):
        """Return the off-diagonal entries of square matrix ``x`` as a 1-D tensor."""
        n, m = x.shape
        assert n == m
        # Dropping the last element and viewing as (n-1, n+1) places every
        # diagonal entry in column 0, which the slice then discards.
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
|
|
73 |
|
|
|
74 |
|
|
|
75 |
|
|
|
76 |
##==================== OPTIMISER =============================================== |
|
|
77 |
|
|
|
78 |
class LARS(optim.Optimizer):
    """Momentum SGD whose update is rescaled per parameter tensor.

    For each parameter the (optionally weight-decayed) gradient is
    multiplied by ``eta * ||p|| / ||grad||`` when both norms are non-zero,
    accumulated into a momentum buffer, and applied with step size ``lr``.

    Args:
        params: iterable of parameters or param-group dicts.
        lr: step size applied to the momentum buffer.
        weight_decay: coefficient of the ``p`` term added to the gradient.
        momentum: decay factor of the momentum buffer.
        eta: coefficient of the norm-ratio rescaling.
        weight_decay_filter: when true, 1-D parameters skip weight decay.
        lars_adaptation_filter: when true, 1-D parameters skip the
            norm-ratio rescaling.
    """

    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=True, lars_adaptation_filter=True,
                 ):
        defaults = dict(lr=lr, weight_decay=weight_decay,
                        momentum=momentum, eta=eta,
                        weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter,
                        )
        ## NOTE: BT uses separate params handling of weights and biases here
        super().__init__(params, defaults)

    def exclude_bias_and_norm(self, p):
        """Return True for 1-D parameters (the ones the filters exclude)."""
        return p.ndim == 1

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimisation step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss, as in ``torch.optim.Optimizer.step``.
                It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without one.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure needs autograd.
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            for p in g["params"]:
                dp = p.grad

                if dp is None:
                    continue

                if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p):
                    dp = dp.add(p, alpha=g['weight_decay'])

                if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p):
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Fall back to a factor of 1 when either norm is zero so
                    # the ratio is never 0/0 or x/0.
                    q = torch.where(
                        param_norm > 0.0,
                        torch.where(
                            update_norm > 0, (g["eta"] * param_norm / update_norm), one
                        ),
                        one,
                    )
                    dp = dp.mul(q)

                param_state = self.state[p]
                if "mu" not in param_state:
                    param_state["mu"] = torch.zeros_like(p)
                mu = param_state["mu"]
                mu.mul_(g["momentum"]).add_(dp)

                p.add_(mu, alpha=-g["lr"])

        return loss
|
|
125 |
|
|
|
126 |
|
|
|
127 |
def adjust_learning_rate(args, optimizer, loader, step, warmup_epochs=10):
    """Set and return the learning rate for ``step``.

    The schedule is a linear warm-up from 0 to the peak LR over
    ``warmup_epochs`` epochs, followed by a cosine decay from the peak LR
    down to ``0.001 * peak`` at ``args.epochs * len(loader)`` steps.
    The peak LR is ``args.base_lr * args.batch_size / 256``.

    Args:
        args: object exposing ``epochs``, ``base_lr`` and ``batch_size``.
        optimizer: object whose ``param_groups`` entries get ``"lr"`` set.
        loader: sized iterable; ``len(loader)`` is the steps per epoch.
        step: global step index.
        warmup_epochs: length of the linear warm-up, in epochs.

    Returns:
        The learning rate written into every param group.
    """
    steps_per_epoch = len(loader)
    max_steps = args.epochs * steps_per_epoch
    warmup_steps = warmup_epochs * steps_per_epoch
    ## NOTE: BT does not bother with this base-LR scaling
    base_lr = args.base_lr * args.batch_size / 256
    if step < warmup_steps:
        lr = base_lr * step / warmup_steps
    else:
        decay_steps = max_steps - warmup_steps
        if decay_steps > 0:
            progress = (step - warmup_steps) / decay_steps
        else:
            # No decay phase exists (warm-up covers the whole run, or the
            # loader is empty): treat the schedule as finished instead of
            # dividing by zero.
            progress = 1.0
        q = 0.5 * (1 + math.cos(math.pi * progress))
        end_lr = base_lr * 0.001
        lr = base_lr * q + end_lr * (1 - q)
    ## Every param group receives the same LR here
    ## (original note: BT handles weights and biases separately)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return lr
|
|
144 |
|
|
|
145 |
|
|
|
146 |
|
|
|
147 |
|
|
|
148 |
##==================== DEBUG =============================================== |
|
|
149 |
|
|
|
150 |
if __name__ == "__main__":
    # Debug entry point: build a model and print its torchinfo summary
    # for two dummy input batches.
    from torchinfo import summary

    # Use one value for both the model's configured batch size and the
    # dummy inputs so the two cannot drift apart.
    debug_batch = 16
    view_shape = (debug_batch, 3, 200, 200)

    model = VICReg(featx_arch='efficientnet_b0',
                   projector_sizes=[8192, 8192, 8192],
                   batch_size=debug_batch,
                   featx_pretrain=None)
    summary(model, [view_shape, view_shape])
    # print(model)