# algorithms/moco.py
import torch
import copy
from torch import nn, optim
from typing import List, Optional, Tuple, Union

from algorithms.arch.resnet import loadResnetBackbone
import utilities.runUtils as rutl

def device_as(t1, t2):
    """
    Moves t1 to the device of t2
    """
    return t1.to(t2.device)


def deactivate_requires_grad(params):
    """Deactivates the requires_grad flag for all parameters."""
    for param in params:
        param.requires_grad = False

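
# --- Added sketch (not part of the original file) --------------------------
# MoCo keeps the momentum branch frozen (see deactivate_requires_grad above)
# and refreshes it with an exponential moving average (EMA) of the online
# weights instead of backprop. The helper below is a minimal sketch of that
# EMA step; the name `update_momentum` and the default m=0.999 are
# illustrative assumptions, not taken from this repository.
@torch.no_grad()
def update_momentum(model: nn.Module, model_ema: nn.Module, m: float = 0.999):
    """EMA update: theta_ema <- m * theta_ema + (1 - m) * theta."""
    for p, p_ema in zip(model.parameters(), model_ema.parameters()):
        p_ema.data = m * p_ema.data + (1.0 - m) * p.data
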
##==================== Model ===============================================

class ProjectionHead(nn.Module):
    """Base class for all projection and prediction heads.

    Args:
        blocks:
            List of tuples, each denoting one block of the projection head MLP.
            Each tuple reads (in_features, out_features, batch_norm_layer,
            non_linearity_layer).

    Examples:
        >>> # the following projection head has two blocks
        >>> # the first block uses batch norm and a ReLU non-linearity
        >>> # the second block is a simple linear layer
        >>> projection_head = ProjectionHead([
        >>>     (256, 256, nn.BatchNorm1d(256), nn.ReLU()),
        >>>     (256, 128, None, None)
        >>> ])
    """

    def __init__(
        self, blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]]
    ):
        super(ProjectionHead, self).__init__()

        layers = []
        for input_dim, output_dim, batch_norm, non_linearity in blocks:
            # The linear bias is redundant when a batch-norm layer follows it.
            use_bias = not bool(batch_norm)
            layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
            if batch_norm:
                layers.append(batch_norm)
            if non_linearity:
                layers.append(non_linearity)
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Computes one forward pass through the projection head.

        Args:
            x:
                Input of shape bsz x num_ftrs.
        """
        return self.layers(x)


class MoCoProjectionHead(ProjectionHead):
    """Projection head used for MoCo.

    "(...) we replace the fc head in MoCo with a 2-layer MLP head (hidden layer
    2048-d, with ReLU)" [0]

    [0]: MoCo v2, 2020, https://arxiv.org/abs/2003.04297
    """

    def __init__(self,
                 input_dim: int = 2048,
                 hidden_dim: int = 2048,
                 output_dim: int = 128):
        super(MoCoProjectionHead, self).__init__([
            (input_dim, hidden_dim, None, nn.ReLU()),
            (hidden_dim, output_dim, None, None),
        ])

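
# --- Added usage sketch (not part of the original file) ---------------------
# Doctest-style illustration of the 2-layer MLP head described above,
# assuming a ResNet-50-style 2048-d backbone output:
#
# >>> head = MoCoProjectionHead(input_dim=2048, hidden_dim=2048, output_dim=128)
# >>> z = head(torch.randn(8, 2048))
# >>> z.shape
# torch.Size([8, 128])
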
class MoCo(nn.Module):
    def __init__(self, featx_arch, pretrained=None, backbone=None,
                 outfeatx_size=2048):
        super().__init__()

        if backbone is not None:
            # A custom backbone's output feature size cannot be inferred here,
            # so `outfeatx_size` must be set to match it (the default of 2048
            # is an assumption matching a ResNet-50-style output).
            self.backbone = backbone
        else:
            self.backbone, outfeatx_size = loadResnetBackbone(
                arch=featx_arch, torch_pretrain=pretrained)
        self.projection_head = MoCoProjectionHead(outfeatx_size, 2048, 128)

        # Momentum (key) branch: a frozen copy of the online (query) branch,
        # updated outside of backprop rather than by gradients.
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)

        deactivate_requires_grad(self.backbone_momentum.parameters())
        deactivate_requires_grad(self.projection_head_momentum.parameters())

    def forward(self, x):
        # Query path: online backbone + projection head.
        query = self.backbone(x).flatten(start_dim=1)
        query = self.projection_head(query)
        return query

    def forward_momentum(self, x):
        # Key path: momentum backbone + momentum head; detached so no
        # gradients flow into the frozen branch.
        key = self.backbone_momentum(x).flatten(start_dim=1)
        key = self.projection_head_momentum(key).detach()
        return key
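

# --- Added usage sketch (not part of the original file) ---------------------
# Minimal illustration of how the two forward paths are typically combined in
# a MoCo-style training step. The arch string "resnet50", the random "views"
# x_q / x_k, the `update_momentum` helper sketched above, the 0.07 temperature,
# and the in-batch InfoNCE wiring are all illustrative assumptions; the real
# training loop and loss live elsewhere in this repository.
if __name__ == "__main__":
    model = MoCo(featx_arch="resnet50")

    x_q = torch.randn(4, 3, 224, 224)   # query view (augmentation 1)
    x_k = torch.randn(4, 3, 224, 224)   # key view (augmentation 2)

    # EMA-refresh the frozen momentum branch before computing keys.
    update_momentum(model.backbone, model.backbone_momentum, m=0.999)
    update_momentum(model.projection_head, model.projection_head_momentum, m=0.999)

    q = nn.functional.normalize(model(x_q), dim=1)                    # B x 128
    k = nn.functional.normalize(model.forward_momentum(x_k), dim=1)   # B x 128, detached

    # In-batch InfoNCE-style logits: the positive pair sits on the diagonal.
    logits = q @ k.t() / 0.07
    labels = torch.arange(q.size(0))
    loss = nn.functional.cross_entropy(logits, labels)
    print(f"contrastive loss: {loss.item():.4f}")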