Diff of /algorithms/moco.py [000000] .. [a18f15]

Switch to side-by-side view

--- a
+++ b/algorithms/moco.py
@@ -0,0 +1,108 @@
+import torch
+import copy
+from torch import nn, optim
+from typing import List, Optional, Tuple, Union
+
+from algorithms.arch.resnet import loadResnetBackbone
+import utilities.runUtils as rutl
+
+def device_as(t1, t2):
+    """
+    Moves t1 to the device of t2
+    """
+    return t1.to(t2.device)
+
+def deactivate_requires_grad(params):
+    """Deactivates the requires_grad flag for all parameters.
+
+    """
+    for param in params:
+        param.requires_grad = False
+
+##==================== Model ===============================================
+
+class ProjectionHead(nn.Module):
+    """Base class for all projection and prediction heads.
+    Args:
+        blocks:
+            List of tuples, each denoting one block of the projection head MLP.
+            Each tuple reads (in_features, out_features, batch_norm_layer,
+            non_linearity_layer).
+    Examples:
+        >>> # the following projection head has two blocks
+        >>> # the first block uses batch norm an a ReLU non-linearity
+        >>> # the second block is a simple linear layer
+        >>> projection_head = ProjectionHead([
+        >>>     (256, 256, nn.BatchNorm1d(256), nn.ReLU()),
+        >>>     (256, 128, None, None)
+        >>> ])
+    """
+
+    def __init__(
+        self, blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]]
+    ):
+        super(ProjectionHead, self).__init__()
+
+        layers = []
+        for input_dim, output_dim, batch_norm, non_linearity in blocks:
+            use_bias = not bool(batch_norm)
+            layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
+            if batch_norm:
+                layers.append(batch_norm)
+            if non_linearity:
+                layers.append(non_linearity)
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x: torch.Tensor):
+        """Computes one forward pass through the projection head.
+        Args:
+            x:
+                Input of shape bsz x num_ftrs.
+        """
+        return self.layers(x)
+
+class MoCoProjectionHead(ProjectionHead):
+    """Projection head used for MoCo.
+
+    "(...) we replace the fc head in MoCo with a 2-layer MLP head (hidden layer
+    2048-d, with ReLU)" [0]
+
+    [0]: MoCo, 2020, https://arxiv.org/abs/1911.05722
+
+    """
+
+    def __init__(self,
+                 input_dim: int = 2048,
+                 hidden_dim: int = 2048,
+                 output_dim: int = 128):
+        super(MoCoProjectionHead, self).__init__([
+            (input_dim, hidden_dim, None, nn.ReLU()),
+            (hidden_dim, output_dim, None, None),
+        ])
+
+class MoCo(nn.Module):
+    def __init__(self, featx_arch, pretrained=None, backbone=None):
+        super().__init__()
+
+        if backbone is not None:
+            self.backbone = backbone
+        else:
+            self.backbone, outfeatx_size = loadResnetBackbone(arch=featx_arch,
+                                    torch_pretrain=pretrained)
+        self.projection_head = MoCoProjectionHead(outfeatx_size, 2048, 128)
+
+        self.backbone_momentum = copy.deepcopy(self.backbone)
+        self.projection_head_momentum = copy.deepcopy(self.projection_head)
+
+        deactivate_requires_grad(self.backbone_momentum.parameters())
+        deactivate_requires_grad(self.projection_head_momentum.parameters())
+
+    def forward(self, x):
+        query = self.backbone(x).flatten(start_dim=1)
+        query = self.projection_head(query)
+        return query
+
+    def forward_momentum(self, x):
+        key = self.backbone_momentum(x).flatten(start_dim=1)
+        key = self.projection_head_momentum(key).detach()
+        return key
\ No newline at end of file