--- /dev/null
+++ b/algorithms/byol.py
@@ -0,0 +1,200 @@
+import copy
+
+import torch
+from torch import nn
+from typing import List, Optional, Tuple
+
+from algorithms.arch.resnet import loadResnetBackbone
+import utilities.runUtils as rutl
+
+
+def device_as(t1, t2):
+    """Moves t1 to the device of t2."""
+    return t1.to(t2.device)
+
+
+##==================== Model ===============================================
+
+class ProjectionHead(nn.Module):
+    """Base class for all projection and prediction heads.
+
+    Args:
+        blocks:
+            List of tuples, each denoting one block of the projection head MLP.
+            Each tuple reads (in_features, out_features, batch_norm_layer,
+            non_linearity_layer).
+
+    Examples:
+        >>> # the following projection head has two blocks
+        >>> # the first block uses batch norm and a ReLU non-linearity
+        >>> # the second block is a simple linear layer
+        >>> projection_head = ProjectionHead([
+        >>>     (256, 256, nn.BatchNorm1d(256), nn.ReLU()),
+        >>>     (256, 128, None, None)
+        >>> ])
+    """
+
+    def __init__(
+        self, blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]]
+    ):
+        super(ProjectionHead, self).__init__()
+
+        layers = []
+        for input_dim, output_dim, batch_norm, non_linearity in blocks:
+            # A bias is redundant when the linear layer is followed by batch norm.
+            use_bias = not bool(batch_norm)
+            layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
+            if batch_norm:
+                layers.append(batch_norm)
+            if non_linearity:
+                layers.append(non_linearity)
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x: torch.Tensor):
+        """Computes one forward pass through the projection head.
+
+        Args:
+            x:
+                Input of shape bsz x num_ftrs.
+        """
+        return self.layers(x)
+
+
+class BYOLProjectionHead(ProjectionHead):
+    """Projection head used for BYOL.
+
+    "This MLP consists in a linear layer with output size 4096 followed by
+    batch normalization, rectified linear units (ReLU), and a final
+    linear layer with output dimension 256." [0]
+
+    [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733
+    """
+
+    def __init__(
+        self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256
+    ):
+        super(BYOLProjectionHead, self).__init__(
+            [
+                (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()),
+                (hidden_dim, output_dim, None, None),
+            ]
+        )
+
+
+class BYOLPredictionHead(ProjectionHead):
+    """Prediction head used for BYOL.
+
+    "This MLP consists in a linear layer with output size 4096 followed by
+    batch normalization, rectified linear units (ReLU), and a final
+    linear layer with output dimension 256." [0]
+
+    [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733
+    """
+
+    def __init__(
+        self, input_dim: int = 256, hidden_dim: int = 4096, output_dim: int = 256
+    ):
+        super(BYOLPredictionHead, self).__init__(
+            [
+                (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()),
+                (hidden_dim, output_dim, None, None),
+            ]
+        )
+
+
+def deactivate_requires_grad(model: nn.Module):
+    """Deactivates the requires_grad flag for all parameters of a model.
+
+    This has the same effect as permanently executing the model within a
+    `torch.no_grad()` context. Use this method to disable gradient computation
+    and therefore training for a model.
+
+    Examples:
+        >>> backbone = resnet18()
+        >>> deactivate_requires_grad(backbone)
+    """
+    for param in model.parameters():
+        param.requires_grad = False
+
+
+class BYOL(nn.Module):
+    def __init__(self, featx_arch, pretrained=None,
+                 backbone=None, backbone_out_dim=None):
+        super().__init__()
+
+        if backbone is not None:
+            # A custom backbone must be accompanied by its output feature
+            # dimension, since it cannot be inferred here.
+            if backbone_out_dim is None:
+                raise ValueError(
+                    "backbone_out_dim must be set when a custom backbone is passed")
+            self.backbone = backbone
+            outfeatx_size = backbone_out_dim
+        else:
+            self.backbone, outfeatx_size = loadResnetBackbone(
+                arch=featx_arch, torch_pretrain=pretrained)
+
+        # Online network: backbone -> projection head -> prediction head.
+        self.projection_head = BYOLProjectionHead(outfeatx_size, 1024, 256)
+        self.prediction_head = BYOLPredictionHead(256, 1024, 256)
+
+        # Target (momentum) network: a frozen copy of the backbone and
+        # projection head, updated outside this module (e.g. by an
+        # exponential moving average of the online weights).
+        self.backbone_momentum = copy.deepcopy(self.backbone)
+        self.projection_head_momentum = copy.deepcopy(self.projection_head)
+
+        deactivate_requires_grad(self.backbone_momentum)
+        deactivate_requires_grad(self.projection_head_momentum)
+
+    def forward(self, x):
+        # Online branch: returns the prediction p used in the BYOL loss.
+        y = self.backbone(x).flatten(start_dim=1)
+        z = self.projection_head(y)
+        p = self.prediction_head(z)
+        return p
+
+    def forward_momentum(self, x):
+        # Target branch: returns the detached projection z (no gradient flow).
+        y = self.backbone_momentum(x).flatten(start_dim=1)
+        z = self.projection_head_momentum(y)
+        z = z.detach()
+        return z
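+
+
+# ---------------------------------------------------------------------------
+# Minimal usage sketch (illustration only, not part of the training pipeline).
+# It shows how the online and momentum branches are typically combined with a
+# symmetric negative-cosine-similarity loss and how the target network can be
+# refreshed with an exponential moving average. The tiny conv backbone, the
+# 0.99 momentum and the random "views" below are placeholder assumptions, not
+# values taken from this project.
+if __name__ == "__main__":
+    toy_backbone = nn.Sequential(
+        nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
+        nn.ReLU(),
+        nn.AdaptiveAvgPool2d(1),
+    )
+    model = BYOL(featx_arch=None, backbone=toy_backbone, backbone_out_dim=32)
+
+    def byol_loss(p, z):
+        # 2 - 2 * cosine similarity, i.e. the squared distance between the
+        # L2-normalised prediction and target.
+        return 2 - 2 * nn.functional.cosine_similarity(p, z, dim=-1).mean()
+
+    x1 = torch.randn(4, 3, 64, 64)  # first augmented view (placeholder)
+    x2 = torch.randn(4, 3, 64, 64)  # second augmented view (placeholder)
+
+    loss = 0.5 * (byol_loss(model(x1), model.forward_momentum(x2))
+                  + byol_loss(model(x2), model.forward_momentum(x1)))
+    loss.backward()
+
+    # EMA update of the target (momentum) network, usually run once per step.
+    m = 0.99
+    with torch.no_grad():
+        for online, target in [
+            (model.backbone, model.backbone_momentum),
+            (model.projection_head, model.projection_head_momentum),
+        ]:
+            for p_o, p_t in zip(online.parameters(), target.parameters()):
+                p_t.data = m * p_t.data + (1.0 - m) * p_o.data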