Diff of /algorithms/simclr.py [000000] .. [a18f15]

Switch to side-by-side view

--- a
+++ b/algorithms/simclr.py
@@ -0,0 +1,169 @@
+import math
+import torch
+from torch import nn, optim
+
+from algorithms.arch.resnet import loadResnetBackbone
+import utilities.runUtils as rutl
+
+def device_as(t1, t2):
+    """
+    Moves t1 to the device of t2
+    """
+    return t1.to(t2.device)
+
+##==================== Model ===============================================
+  
+class ContrastiveLoss(nn.Module):
+    """
+    Vanilla Contrastive loss, also called InfoNceLoss as in SimCLR paper
+    """
+    def __init__(self, batch_size, temperature=0.5):
+        super().__init__()
+        self.batch_size = batch_size
+        self.temperature = temperature
+        self.mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool)).float()
+
+    def calc_similarity_batch(self, a, b):
+        representations = torch.cat([a, b], dim=0)
+        return nn.functional.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
+
+    def forward(self, proj_1, proj_2):
+        """
+        proj_1 and proj_2 are batched embeddings [batch, embedding_dim]
+        where corresponding indices are pairs
+        z_i, z_j in the SimCLR paper
+        """
+        batch_size = proj_1.shape[0]
+        z_i = nn.functional.normalize(proj_1, p=2, dim=1)
+        z_j = nn.functional.normalize(proj_2, p=2, dim=1)
+
+        similarity_matrix = self.calc_similarity_batch(z_i, z_j)
+
+        sim_ij = torch.diag(similarity_matrix, batch_size)
+        sim_ji = torch.diag(similarity_matrix, -batch_size)
+
+        positives = torch.cat([sim_ij, sim_ji], dim=0)
+        
+        nominator = torch.exp(positives / self.temperature)
+
+        denominator = device_as(self.mask, similarity_matrix) * torch.exp(similarity_matrix / self.temperature)
+
+        all_losses = -torch.log(nominator / torch.sum(denominator, dim=1))
+        loss = torch.sum(all_losses) / (2 * self.batch_size)
+        return loss  
+
+class SimCLR(nn.Module):
+    def __init__(self, featx_arch, projector_sizes,
+                 batch_size, temp, pretrained=None):
+        super().__init__()
+        rutl.START_SEED()
+
+        mlp_dim = projector_sizes[0]
+        embedding_size = projector_sizes[1]
+        
+        self.batch_size = batch_size
+        self.temp = temp
+        self.backbone, outfeatx_size = loadResnetBackbone(arch=featx_arch,
+                                            torch_pretrain=pretrained)
+        # add mlp projection head
+        self.projector = nn.Sequential(
+            nn.Linear(in_features=outfeatx_size, out_features=mlp_dim),
+            nn.BatchNorm1d(mlp_dim),
+            nn.ReLU(),
+            nn.Linear(in_features=mlp_dim, out_features=embedding_size),
+            # nn.BatchNorm1d(embedding_size),
+        )
+    def forward(self, y1, y2):
+        z1 = self.projector(self.backbone(y1))
+        z2 = self.projector(self.backbone(y2))
+        loss = ContrastiveLoss(self.batch_size, self.temp)
+        return loss(z1, z2)
+
+##==================== OPTIMISER ===============================================
+
+class LARS(optim.Optimizer):
+    def __init__(
+        self,
+        params,
+        lr,
+        momentum=0,
+        dampening=0,
+        weight_decay=0,
+        nesterov=False,
+        trust_coefficient=0.001,
+        eps=1e-8,
+    ):
+
+        defaults = dict(
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            trust_coefficient=trust_coefficient,
+            eps=eps,
+        )
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+
+        for group in self.param_groups:
+            group.setdefault("nesterov", False)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        # exclude scaling for params with 0 weight decay
+        for group in self.param_groups:
+            weight_decay = group["weight_decay"]
+            momentum = group["momentum"]
+            dampening = group["dampening"]
+            nesterov = group["nesterov"]
+
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+
+                d_p = p.grad
+                p_norm = torch.norm(p.data)
+                g_norm = torch.norm(p.grad.data)
+
+                # lars scaling + weight decay part
+                if weight_decay != 0:
+                    if p_norm != 0 and g_norm != 0:
+                        lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
+                        lars_lr *= group["trust_coefficient"]
+
+                        d_p = d_p.add(p, alpha=weight_decay)
+                        d_p *= lars_lr
+
+                # sgd part
+                if momentum != 0:
+                    param_state = self.state[p]
+                    if "momentum_buffer" not in param_state:
+                        buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
+                    else:
+                        buf = param_state["momentum_buffer"]
+                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+                    if nesterov:
+                        d_p = d_p.add(buf, alpha=momentum)
+                    else:
+                        d_p = buf
+
+                p.add_(d_p, alpha=-group["lr"])
+
+        return loss
\ No newline at end of file