Diff of /algorithms/vicreg.py [000000] .. [a18f15]

import os, sys
import math
import torch
from torch import nn, optim
from torch.nn import functional as torch_F

sys.path.append(os.getcwd())
from algorithms.arch.resnet import loadResnetBackbone

## Code from the official VICReg implementation with the distributed-training blocks removed
##==================== Model ===============================================

class VICReg(nn.Module):
    def __init__(self, featx_arch, projector_sizes,
                    batch_size, sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0,
                    featx_pretrain=None,
                 ):
        super().__init__()

        self.sim_coeff    = sim_coeff
        self.std_coeff    = std_coeff
        self.cov_coeff    = cov_coeff
        self.batch_size   = batch_size

        self.num_features = projector_sizes[-1]
        self.backbone, out_featx_size = loadResnetBackbone(
                            arch=featx_arch, torch_pretrain=featx_pretrain)
        self.projector = self.load_ProjectorNet(out_featx_size, projector_sizes)


    def forward(self, x, y):
        x = self.projector(self.backbone(x))
        y = self.projector(self.backbone(y))

        ## Invariance term: MSE between the embeddings of the two views
        repr_loss = torch_F.mse_loss(x, y)

        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)
        ## Variance term: hinge loss keeping the std of each embedding dimension above 1
        std_x = torch.sqrt(x.var(dim=0) + 0.0001)
        std_y = torch.sqrt(y.var(dim=0) + 0.0001)
        std_loss = torch.mean(torch_F.relu(1 - std_x)) / 2 + \
                    torch.mean(torch_F.relu(1 - std_y)) / 2

        ## Covariance term: penalise off-diagonal covariance entries to decorrelate dimensions
        cov_x = (x.T @ x) / (self.batch_size - 1)
        cov_y = (y.T @ y) / (self.batch_size - 1)
        cov_loss = self.off_diagonal(cov_x).pow_(2).sum().div(
            self.num_features
        ) + self.off_diagonal(cov_y).pow_(2).sum().div(self.num_features)

        loss = (
            self.sim_coeff * repr_loss
            + self.std_coeff * std_loss
            + self.cov_coeff * cov_loss
        )
        return loss

    def load_ProjectorNet(self, outfeatx_size, projector_sizes):
        # backbone_out_shape + projector_dims
        sizes = [outfeatx_size] + list(projector_sizes)
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        projector = nn.Sequential(*layers)
        return projector

    def off_diagonal(self, x):
        ## Return the flattened off-diagonal elements of a square matrix
        n, m = x.shape
        assert n == m
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

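
## A minimal sanity-check sketch of the off_diagonal() trick above (the helper
## name below is hypothetical, added only for illustration): viewing the
## flattened (n, n) matrix without its last element as (n - 1, n + 1) shifts
## each row by one position, so column 0 carries the diagonal and the
## remaining columns carry exactly the off-diagonal entries.
def _off_diagonal_example():
    m = torch.arange(9.0).reshape(3, 3)          # [[0,1,2],[3,4,5],[6,7,8]]
    off = m.flatten()[:-1].view(2, 4)[:, 1:].flatten()
    assert torch.equal(off, torch.tensor([1.0, 2.0, 3.0, 5.0, 6.0, 7.0]))
    return off

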
##==================== OPTIMISER ===============================================

class LARS(optim.Optimizer):
    def __init__( self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
            weight_decay_filter=True,  lars_adaptation_filter=True,
                ):
        defaults = dict( lr=lr, weight_decay=weight_decay,
            momentum=momentum, eta=eta,
            weight_decay_filter=weight_decay_filter,
            lars_adaptation_filter=lars_adaptation_filter,
        )
        ## BT uses separate param handling of weights and biases here
        super().__init__(params, defaults)

    def exclude_bias_and_norm(self, p):
        ## Biases and norm parameters are 1-D tensors
        return p.ndim == 1

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g["params"]:
                dp = p.grad

                if dp is None:
                    continue

                ## Weight decay (skipped for bias/norm params when the filter is on)
                if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p):
                    dp = dp.add(p, alpha=g['weight_decay'])

                ## LARS adaptation: scale the update by the layer-wise trust
                ## ratio eta * ||p|| / ||dp|| (1 if either norm is zero)
                if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p):
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(
                        param_norm > 0.0,
                        torch.where(
                            update_norm > 0, (g["eta"] * param_norm / update_norm), one
                        ),
                        one,
                    )
                    dp = dp.mul(q)

                ## SGD momentum on the adapted update
                param_state = self.state[p]
                if "mu" not in param_state:
                    param_state["mu"] = torch.zeros_like(p)
                mu = param_state["mu"]
                mu.mul_(g["momentum"]).add_(dp)

                p.add_(mu, alpha=-g["lr"])

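
## A minimal construction sketch for the optimiser above (the weight_decay
## value and the helper name are assumptions for illustration); lr can start
## at 0 because adjust_learning_rate() below rewrites param_group["lr"] on
## every training step.
def _lars_setup_example(model):
    return LARS(model.parameters(), lr=0, weight_decay=1e-6,
                weight_decay_filter=True, lars_adaptation_filter=True)

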
def adjust_learning_rate(args, optimizer, loader, step):
    ## Linear warm-up for the first 10 epochs, then cosine decay towards
    ## 0.1% of the batch-size-scaled base LR
    max_steps = args.epochs * len(loader)
    warmup_steps = 10 * len(loader)
    ## BT does not bother with a base LR
    base_lr = args.base_lr * args.batch_size / 256
    if step < warmup_steps:
        lr = base_lr * step / warmup_steps
    else:
        step -= warmup_steps
        max_steps -= warmup_steps
        q = 0.5 * (1 + math.cos(math.pi * step / max_steps))
        end_lr = base_lr * 0.001
        lr = base_lr * q + end_lr * (1 - q)
    ## Handles weights and biases separately
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return lr

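
## A small illustrative sketch of the schedule above (the helper name and the
## 100-epoch / 50-batch / base_lr numbers are assumptions for this example):
## samples the LR during warm-up and along the cosine decay.
def _lr_schedule_example():
    from types import SimpleNamespace
    args   = SimpleNamespace(epochs=100, base_lr=0.2, batch_size=256)
    loader = [None] * 50                        # stands in for a DataLoader
    opt    = optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.0)
    return [adjust_learning_rate(args, opt, loader, step)
            for step in (0, 250, 500, 2500, 4999)]

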
##==================== DEBUG ===============================================

if __name__ == "__main__":

    from torchinfo import summary

    model = VICReg( featx_arch='efficientnet_b0',
                    projector_sizes=[8192, 8192, 8192],
                    batch_size=16,   # must match the batch dimension used below
                    featx_pretrain=None)
    summary(model, [(16, 3, 200, 200), (16, 3, 200, 200)])
    # print(model)
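
    ## Illustrative single training step wiring the model, LARS and the LR
    ## schedule together (a sketch: `dummy_loader`, the `args` namespace and
    ## the hyper-parameter values are placeholders assumed for this demo).
    from types import SimpleNamespace

    dummy_loader = [None] * 10                  # stands in for a DataLoader
    args = SimpleNamespace(epochs=100, base_lr=0.2, batch_size=16)
    optimizer = LARS(model.parameters(), lr=0, weight_decay=1e-6)

    x1 = torch.randn(16, 3, 200, 200)           # two augmented views of
    x2 = torch.randn(16, 3, 200, 200)           # the same image batch

    lr = adjust_learning_rate(args, optimizer, dummy_loader, step=1)
    optimizer.zero_grad()
    loss = model(x1, x2)
    loss.backward()
    optimizer.step()
    print(f"lr={lr:.5f}  loss={loss.item():.3f}")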