# algorithms/byol.py
import torch
import copy
from torch import nn, optim
from typing import List, Optional, Tuple, Union

from algorithms.arch.resnet import loadResnetBackbone
import utilities.runUtils as rutl

def device_as(t1, t2):
    """
    Moves t1 to the device of t2
    """
    return t1.to(t2.device)


##==================== Model ===============================================

class ProjectionHead(nn.Module):
    """Base class for all projection and prediction heads.
    Args:
        blocks:
            List of tuples, each denoting one block of the projection head MLP.
            Each tuple reads (in_features, out_features, batch_norm_layer,
            non_linearity_layer).
    Examples:
        >>> # the following projection head has two blocks
        >>> # the first block uses batch norm and a ReLU non-linearity
        >>> # the second block is a simple linear layer
        >>> projection_head = ProjectionHead([
        >>>     (256, 256, nn.BatchNorm1d(256), nn.ReLU()),
        >>>     (256, 128, None, None)
        >>> ])
    """

    def __init__(
        self, blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]]
    ):
        super(ProjectionHead, self).__init__()

        layers = []
        for input_dim, output_dim, batch_norm, non_linearity in blocks:
            use_bias = not bool(batch_norm)
            layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
            if batch_norm:
                layers.append(batch_norm)
            if non_linearity:
                layers.append(non_linearity)
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Computes one forward pass through the projection head.
        Args:
            x:
                Input of shape bsz x num_ftrs.
        """
        return self.layers(x)


class BYOLProjectionHead(ProjectionHead):
    """Projection head used for BYOL.
    "This MLP consists in a linear layer with output size 4096 followed by
    batch normalization, rectified linear units (ReLU), and a final
    linear layer with output dimension 256." [0]
    [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733
    """

    def __init__(
        self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256
    ):
        super(BYOLProjectionHead, self).__init__(
            [
                (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()),
                (hidden_dim, output_dim, None, None),
            ]
        )


class BYOLPredictionHead(ProjectionHead):
    """Prediction head used for BYOL.
    "This MLP consists in a linear layer with output size 4096 followed by
    batch normalization, rectified linear units (ReLU), and a final
    linear layer with output dimension 256." [0]
    [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733
    """

    def __init__(
        self, input_dim: int = 256, hidden_dim: int = 4096, output_dim: int = 256
    ):
        super(BYOLPredictionHead, self).__init__(
            [
                (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()),
                (hidden_dim, output_dim, None, None),
            ]
        )


def deactivate_requires_grad(model: nn.Module):
    """Deactivates the requires_grad flag for all parameters of a model.
    This has the same effect as permanently executing the model within a `torch.no_grad()`
    context. Use this method to disable gradient computation and therefore
    training for a model.
    Examples:
        >>> backbone = resnet18()
        >>> deactivate_requires_grad(backbone)
    """
    for param in model.parameters():
        param.requires_grad = False

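
# BYOL also needs an exponential-moving-average (EMA) update for the momentum
# networks created below; this file does not define one, so the helper here is
# a minimal sketch of the standard update. The name `update_momentum` and the
# default decay m=0.996 (the value used in the BYOL paper) are assumptions.
@torch.no_grad()
def update_momentum(model: nn.Module, model_ema: nn.Module, m: float = 0.996):
    """EMA update: ema_param <- m * ema_param + (1 - m) * online_param."""
    for param, param_ema in zip(model.parameters(), model_ema.parameters()):
        param_ema.data = m * param_ema.data + (1.0 - m) * param.data
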
class BYOL(nn.Module):
    def __init__(self, featx_arch, pretrained=None,
                 backbone=None, outfeatx_size=2048):
        super().__init__()

        if backbone is not None:
            # A custom backbone must come with its output feature dimension
            # via `outfeatx_size` (default 2048, e.g. ResNet-50).
            self.backbone = backbone
        else:
            self.backbone, outfeatx_size = loadResnetBackbone(
                arch=featx_arch, torch_pretrain=pretrained)
        self.projection_head = BYOLProjectionHead(outfeatx_size, 1024, 256)
        self.prediction_head = BYOLPredictionHead(256, 1024, 256)

        # Momentum (target) copies of the online backbone and projector;
        # they are never trained by backprop, only via EMA updates.
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)

        deactivate_requires_grad(self.backbone_momentum)
        deactivate_requires_grad(self.projection_head_momentum)

    def forward(self, x):
        # Online branch: backbone features -> projection -> prediction.
        y = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(y)
        p = self.prediction_head(z)
        return p

    def forward_momentum(self, x):
        # Target branch: momentum backbone -> momentum projection (no
        # prediction head). Detached so no gradients reach the momentum nets.
        y = self.backbone_momentum(x).flatten(start_dim=1)
        z = self.projection_head_momentum(y)
        z = z.detach()
        return z
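

##==================== Usage sketch ========================================
# Not part of the original file: an illustrative training step under the
# standard BYOL objective, i.e. the symmetrized negative cosine similarity
# between the online prediction and the detached momentum projection. The
# architecture string "resnet50", the optimizer settings, and the random
# tensors standing in for two augmented views are all assumptions.
if __name__ == "__main__":
    model = BYOL(featx_arch="resnet50")
    optimizer = optim.SGD(model.parameters(), lr=0.06, momentum=0.9)

    def byol_loss_fn(p, z):
        # Negative cosine similarity, averaged over the batch.
        return -nn.functional.cosine_similarity(p, z.detach(), dim=-1).mean()

    x1 = torch.randn(8, 3, 224, 224)  # stand-in for augmented view 1
    x2 = torch.randn(8, 3, 224, 224)  # stand-in for augmented view 2

    # EMA-update the momentum networks once per training step.
    update_momentum(model.backbone, model.backbone_momentum, m=0.996)
    update_momentum(model.projection_head, model.projection_head_momentum, m=0.996)

    p1, p2 = model(x1), model(x2)
    z1, z2 = model.forward_momentum(x1), model.forward_momentum(x2)
    loss = 0.5 * (byol_loss_fn(p1, z2) + byol_loss_fn(p2, z1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()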