Diff of /algorithms/simclr.py [000000] .. [a18f15]

import math
import torch
from torch import nn, optim

from algorithms.arch.resnet import loadResnetBackbone
import utilities.runUtils as rutl

def device_as(t1, t2):
    """
    Moves t1 to the device of t2
    """
    return t1.to(t2.device)

##==================== Model ===============================================

class ContrastiveLoss(nn.Module):
    """
    Vanilla contrastive loss (NT-Xent), also called InfoNCE loss as in the
    SimCLR paper.
    """
    def __init__(self, batch_size, temperature=0.5):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        # Mask that zeroes out the 2N self-similarity terms on the diagonal.
        # It is built for a fixed batch size, so the final incomplete batch
        # should be dropped (e.g. drop_last=True in the DataLoader).
        self.mask = (~torch.eye(batch_size * 2, batch_size * 2,
                                dtype=torch.bool)).float()

    def calc_similarity_batch(self, a, b):
        # Pairwise cosine similarity over the concatenated 2N projections.
        representations = torch.cat([a, b], dim=0)
        return nn.functional.cosine_similarity(representations.unsqueeze(1),
                                               representations.unsqueeze(0),
                                               dim=2)

    def forward(self, proj_1, proj_2):
        """
        proj_1 and proj_2 are batched embeddings [batch, embedding_dim]
        where corresponding indices are the positive pairs
        z_i, z_j in the SimCLR paper.
        """
        batch_size = proj_1.shape[0]
        z_i = nn.functional.normalize(proj_1, p=2, dim=1)
        z_j = nn.functional.normalize(proj_2, p=2, dim=1)

        similarity_matrix = self.calc_similarity_batch(z_i, z_j)

        # Positive-pair similarities sit on the +/- batch_size off-diagonals
        # of the (2N x 2N) similarity matrix.
        sim_ij = torch.diag(similarity_matrix, batch_size)
        sim_ji = torch.diag(similarity_matrix, -batch_size)
        positives = torch.cat([sim_ij, sim_ji], dim=0)

        numerator = torch.exp(positives / self.temperature)
        # The mask removes each sample's similarity with itself from the
        # denominator sum over the remaining 2N-1 samples.
        denominator = (device_as(self.mask, similarity_matrix)
                       * torch.exp(similarity_matrix / self.temperature))

        all_losses = -torch.log(numerator / torch.sum(denominator, dim=1))
        loss = torch.sum(all_losses) / (2 * self.batch_size)
        return loss

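# --- Illustrative usage sketch (not part of the original file) ----------------
# A minimal check of ContrastiveLoss on random projections; the batch size of 4
# and the embedding dimension of 128 below are arbitrary assumptions.
def _demo_contrastive_loss():
    loss_fn = ContrastiveLoss(batch_size=4, temperature=0.5)
    proj_1 = torch.randn(4, 128)    # projections of the first augmented views
    proj_2 = torch.randn(4, 128)    # projections of the second augmented views
    return loss_fn(proj_1, proj_2)  # scalar NT-Xent loss over the 8 samples
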
class SimCLR(nn.Module):
    def __init__(self, featx_arch, projector_sizes,
                 batch_size, temp, pretrained=None):
        super().__init__()
        rutl.START_SEED()

        mlp_dim = projector_sizes[0]
        embedding_size = projector_sizes[1]

        self.batch_size = batch_size
        self.temp = temp

        # ResNet feature extractor (backbone) and its output feature size
        self.backbone, outfeatx_size = loadResnetBackbone(
            arch=featx_arch, torch_pretrain=pretrained)

        # add mlp projection head
        self.projector = nn.Sequential(
            nn.Linear(in_features=outfeatx_size, out_features=mlp_dim),
            nn.BatchNorm1d(mlp_dim),
            nn.ReLU(),
            nn.Linear(in_features=mlp_dim, out_features=embedding_size),
            # nn.BatchNorm1d(embedding_size),
        )

        # build the criterion once instead of re-creating it every forward pass
        self.criterion = ContrastiveLoss(self.batch_size, self.temp)

    def forward(self, y1, y2):
        # project both augmented views and compute the contrastive loss
        z1 = self.projector(self.backbone(y1))
        z2 = self.projector(self.backbone(y2))
        return self.criterion(z1, z2)

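# --- Illustrative usage sketch (not part of the original file) ----------------
# One SimCLR forward/backward step on two randomly "augmented" views. The
# 'resnet18' arch string, the (512, 128) projector sizes and the input shape
# are assumptions about loadResnetBackbone and the training config.
def _demo_simclr_step():
    model = SimCLR(featx_arch="resnet18", projector_sizes=(512, 128),
                   batch_size=8, temp=0.5, pretrained=None)
    y1 = torch.randn(8, 3, 224, 224)  # first augmented view of the batch
    y2 = torch.randn(8, 3, 224, 224)  # second augmented view of the batch
    loss = model(y1, y2)
    loss.backward()                   # gradients flow to backbone + projector
    return loss.item()
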
##==================== OPTIMISER ===============================================

class LARS(optim.Optimizer):
    """
    Layer-wise Adaptive Rate Scaling (LARS): SGD with momentum where each
    parameter's update is rescaled by the ratio of the parameter norm to the
    gradient norm (the trust ratio), as commonly used for large-batch
    self-supervised training.
    """
    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        trust_coefficient=0.001,
        eps=1e-8,
    ):
        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            trust_coefficient=trust_coefficient,
            eps=eps,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

        for group in self.param_groups:
            group.setdefault("nesterov", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # exclude scaling for params with 0 weight decay
        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                d_p = p.grad
                p_norm = torch.norm(p.data)
                g_norm = torch.norm(p.grad.data)

                # lars scaling + weight decay part
                if weight_decay != 0:
                    if p_norm != 0 and g_norm != 0:
                        lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
                        lars_lr *= group["trust_coefficient"]

                        d_p = d_p.add(p, alpha=weight_decay)
                        d_p *= lars_lr

                # sgd part
                if momentum != 0:
                    param_state = self.state[p]
                    if "momentum_buffer" not in param_state:
                        buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
                    else:
                        buf = param_state["momentum_buffer"]
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.add_(d_p, alpha=-group["lr"])

        return loss

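# --- Illustrative usage sketch (not part of the original file) ----------------
# Building LARS with the usual SimCLR-style parameter groups, where biases and
# BatchNorm parameters get zero weight decay (and are therefore skipped by the
# LARS scaling branch above). The lr / weight_decay values are assumptions.
def _demo_lars_setup(model):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if param.ndim <= 1 or name.endswith(".bias"):
            no_decay.append(param)   # biases and norm-layer parameters
        else:
            decay.append(param)      # weight matrices / conv kernels
    return LARS(
        [{"params": decay, "weight_decay": 1e-6},
         {"params": no_decay, "weight_decay": 0.0}],
        lr=0.3, momentum=0.9, trust_coefficient=0.001,
    )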