Diff of /algorithms/byol.py [000000] .. [a18f15]

Switch to side-by-side view

--- a
+++ b/algorithms/byol.py
@@ -0,0 +1,136 @@
+import torch
+import copy
+from torch import nn, optim
+from typing import List, Optional, Tuple, Union
+
+from algorithms.arch.resnet import loadResnetBackbone
+import utilities.runUtils as rutl
+
+def device_as(t1, t2):
+    """
+    Moves t1 to the device of t2
+    """
+    return t1.to(t2.device)
+
+
+##==================== Model ===============================================
+
+class ProjectionHead(nn.Module):
+    """Base class for all projection and prediction heads.
+    Args:
+        blocks:
+            List of tuples, each denoting one block of the projection head MLP.
+            Each tuple reads (in_features, out_features, batch_norm_layer,
+            non_linearity_layer).
+    Examples:
+        >>> # the following projection head has two blocks
+        >>> # the first block uses batch norm an a ReLU non-linearity
+        >>> # the second block is a simple linear layer
+        >>> projection_head = ProjectionHead([
+        >>>     (256, 256, nn.BatchNorm1d(256), nn.ReLU()),
+        >>>     (256, 128, None, None)
+        >>> ])
+    """
+
+    def __init__(
+        self, blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]]
+    ):
+        super(ProjectionHead, self).__init__()
+
+        layers = []
+        for input_dim, output_dim, batch_norm, non_linearity in blocks:
+            use_bias = not bool(batch_norm)
+            layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
+            if batch_norm:
+                layers.append(batch_norm)
+            if non_linearity:
+                layers.append(non_linearity)
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x: torch.Tensor):
+        """Computes one forward pass through the projection head.
+        Args:
+            x:
+                Input of shape bsz x num_ftrs.
+        """
+        return self.layers(x)
+
+
+class BYOLProjectionHead(ProjectionHead):
+    """Projection head used for BYOL.
+    "This MLP consists in a linear layer with output size 4096 followed by
+    batch normalization, rectified linear units (ReLU), and a final
+    linear layer with output dimension 256." [0]
+    [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733
+    """
+
+    def __init__(
+        self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256
+    ):
+        super(BYOLProjectionHead, self).__init__(
+            [
+                (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()),
+                (hidden_dim, output_dim, None, None),
+            ]
+        )
+
+
+class BYOLPredictionHead(ProjectionHead):
+    """Prediction head used for BYOL.
+    "This MLP consists in a linear layer with output size 4096 followed by
+    batch normalization, rectified linear units (ReLU), and a final
+    linear layer with output dimension 256." [0]
+    [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733
+    """
+
+    def __init__(
+        self, input_dim: int = 256, hidden_dim: int = 4096, output_dim: int = 256
+    ):
+        super(BYOLPredictionHead, self).__init__(
+            [
+                (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()),
+                (hidden_dim, output_dim, None, None),
+            ]
+        )
+
+def deactivate_requires_grad(model: nn.Module):
+    """Deactivates the requires_grad flag for all parameters of a model.
+    This has the same effect as permanently executing the model within a `torch.no_grad()`
+    context. Use this method to disable gradient computation and therefore
+    training for a model.
+    Examples:
+        >>> backbone = resnet18()
+        >>> deactivate_requires_grad(backbone)
+    """
+    for param in model.parameters():
+        param.requires_grad = False
+
+class BYOL(nn.Module):
+    def __init__(self, featx_arch, pretrained=None, backbone=None):
+        super().__init__()
+
+        if backbone is not None:
+            self.backbone = backbone
+        else:
+            self.backbone, outfeatx_size = loadResnetBackbone(arch=featx_arch,
+                                    torch_pretrain=pretrained)
+        self.projection_head = BYOLProjectionHead(outfeatx_size, 1024, 256)
+        self.prediction_head = BYOLPredictionHead(256, 1024, 256)
+
+        self.backbone_momentum = copy.deepcopy(self.backbone)
+        self.projection_head_momentum = copy.deepcopy(self.projection_head)
+
+        deactivate_requires_grad(self.backbone_momentum)
+        deactivate_requires_grad(self.projection_head_momentum)
+
+    def forward(self, x):
+        y = self.backbone(x).flatten(start_dim=1)
+        z = self.projection_head(y)
+        p = self.prediction_head(z)
+        return p
+
+    def forward_momentum(self, x):
+        y = self.backbone_momentum(x).flatten(start_dim=1)
+        z = self.projection_head_momentum(y)
+        z = z.detach()
+        return z
\ No newline at end of file