--- /dev/null
+++ b/algorithms/moco.py
@@ -0,0 +1,117 @@
+import torch
+import copy
+from torch import nn
+from typing import List, Optional, Tuple
+
+from algorithms.arch.resnet import loadResnetBackbone
+import utilities.runUtils as rutl
+
+def device_as(t1, t2):
+    """Moves t1 to the device of t2."""
+    return t1.to(t2.device)
+
+def deactivate_requires_grad(params):
+    """Deactivates the requires_grad flag for all parameters."""
+    for param in params:
+        param.requires_grad = False
+
+##==================== Model ===============================================
+
+class ProjectionHead(nn.Module):
+    """Base class for all projection and prediction heads.
+
+    Args:
+        blocks:
+            List of tuples, each denoting one block of the projection head MLP.
+            Each tuple reads (in_features, out_features, batch_norm_layer,
+            non_linearity_layer).
+
+    Examples:
+        >>> # the following projection head has two blocks;
+        >>> # the first block uses batch norm and a ReLU non-linearity,
+        >>> # the second block is a simple linear layer
+        >>> projection_head = ProjectionHead([
+        >>>     (256, 256, nn.BatchNorm1d(256), nn.ReLU()),
+        >>>     (256, 128, None, None)
+        >>> ])
+    """
+
+    def __init__(
+        self, blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]]
+    ):
+        super(ProjectionHead, self).__init__()
+
+        layers = []
+        for input_dim, output_dim, batch_norm, non_linearity in blocks:
+            # the linear layer needs no bias when batch norm follows it
+            use_bias = not bool(batch_norm)
+            layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
+            if batch_norm:
+                layers.append(batch_norm)
+            if non_linearity:
+                layers.append(non_linearity)
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x: torch.Tensor):
+        """Computes one forward pass through the projection head.
+
+        Args:
+            x:
+                Input of shape bsz x num_ftrs.
+        """
+        return self.layers(x)
+
+class MoCoProjectionHead(ProjectionHead):
+    """Projection head used for MoCo.
+
+    "(...) we replace the fc head in MoCo with a 2-layer MLP head (hidden
+    layer 2048-d, with ReLU)" [0]
+
+    [0]: MoCo, 2020, https://arxiv.org/abs/1911.05722
+    """
+
+    def __init__(self,
+                 input_dim: int = 2048,
+                 hidden_dim: int = 2048,
+                 output_dim: int = 128):
+        super(MoCoProjectionHead, self).__init__([
+            (input_dim, hidden_dim, None, nn.ReLU()),
+            (hidden_dim, output_dim, None, None),
+        ])
+
+class MoCo(nn.Module):
+    """MoCo model: an online encoder (backbone + projection head) and a
+    gradient-free momentum copy of it that encodes the keys."""
+
+    def __init__(self, featx_arch, pretrained=None, backbone=None,
+                 outfeatx_size=2048):
+        super().__init__()
+
+        if backbone is not None:
+            # a custom backbone must come with a matching outfeatx_size;
+            # otherwise the size is returned by loadResnetBackbone
+            self.backbone = backbone
+        else:
+            self.backbone, outfeatx_size = loadResnetBackbone(
+                arch=featx_arch, torch_pretrain=pretrained)
+        self.projection_head = MoCoProjectionHead(outfeatx_size, 2048, 128)
+
+        # the momentum encoder starts as a copy of the online encoder and
+        # receives no gradients; it is updated as a moving average instead
+        self.backbone_momentum = copy.deepcopy(self.backbone)
+        self.projection_head_momentum = copy.deepcopy(self.projection_head)
+
+        deactivate_requires_grad(self.backbone_momentum.parameters())
+        deactivate_requires_grad(self.projection_head_momentum.parameters())
+
+    def forward(self, x):
+        # query: online encoder, trained by backpropagation
+        query = self.backbone(x).flatten(start_dim=1)
+        query = self.projection_head(query)
+        return query
+
+    def forward_momentum(self, x):
+        # key: momentum encoder, detached so no gradient flows through it
+        key = self.backbone_momentum(x).flatten(start_dim=1)
+        key = self.projection_head_momentum(key).detach()
+        return key
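
The patch copies the encoder and freezes its gradients, but the momentum update itself is not part of this file. Below is a minimal sketch of that update, assuming a helper named update_momentum and the paper's default momentum m = 0.999; neither name nor value is taken from this repository. It implements the exponential moving average of MoCo (Eq. 2 in https://arxiv.org/abs/1911.05722), theta_k <- m * theta_k + (1 - m) * theta_q.

import torch

@torch.no_grad()
def update_momentum(model: torch.nn.Module, model_momentum: torch.nn.Module,
                    m: float = 0.999):
    # hypothetical helper, not part of the patch above: EMA update of the
    # momentum encoder from the online encoder (MoCo Eq. 2); deepcopy
    # guarantees that both parameter lists line up one-to-one
    for p, p_m in zip(model.parameters(), model_momentum.parameters()):
        p_m.data = m * p_m.data + (1.0 - m) * p.data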
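
For context, a hedged sketch of one training step with two augmented views x_q and x_k of the same batch: the momentum encoders are EMA-updated first, then an InfoNCE-style loss contrasts each query with its key. The queue of past negatives that full MoCo maintains is omitted here for brevity (the keys of the current batch act as negatives instead), and the arch string "resnet50" is only an assumed argument for loadResnetBackbone.

import torch.nn.functional as F

model = MoCo(featx_arch="resnet50")  # arch string assumed, see above
optimizer = torch.optim.SGD((p for p in model.parameters() if p.requires_grad),
                            lr=0.03, momentum=0.9)

def training_step(x_q, x_k, temperature=0.07):
    # update the momentum encoder before encoding the keys
    update_momentum(model.backbone, model.backbone_momentum)
    update_momentum(model.projection_head, model.projection_head_momentum)

    q = F.normalize(model(x_q), dim=1)                   # queries (N, 128)
    k = F.normalize(model.forward_momentum(x_k), dim=1)  # keys, detached

    # InfoNCE over the batch: positives sit on the diagonal of q @ k.T
    logits = q @ k.t() / temperature
    labels = torch.arange(q.size(0), device=q.device)
    loss = F.cross_entropy(logits, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()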