# algorithms/moco.py

import torch
import copy
from torch import nn, optim
from typing import List, Optional, Tuple, Union
from algorithms.arch.resnet import loadResnetBackbone
import utilities.runUtils as rutl


def device_as(t1, t2):
    """Moves t1 to the device of t2."""
    return t1.to(t2.device)


def deactivate_requires_grad(params):
    """Deactivates the requires_grad flag for all parameters."""
    for param in params:
        param.requires_grad = False


##==================== Model ===============================================

class ProjectionHead(nn.Module):
    """Base class for all projection and prediction heads.

    Args:
        blocks:
            List of tuples, each denoting one block of the projection head
            MLP. Each tuple reads (in_features, out_features,
            batch_norm_layer, non_linearity_layer).

    Examples:
        >>> # the following projection head has two blocks
        >>> # the first block uses batch norm and a ReLU non-linearity
        >>> # the second block is a simple linear layer
        >>> projection_head = ProjectionHead([
        >>>     (256, 256, nn.BatchNorm1d(256), nn.ReLU()),
        >>>     (256, 128, None, None)
        >>> ])
    """

    def __init__(
        self,
        blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]],
    ):
        super(ProjectionHead, self).__init__()
        layers = []
        for input_dim, output_dim, batch_norm, non_linearity in blocks:
            # Bias is redundant when a batch norm layer follows the linear layer.
            use_bias = not bool(batch_norm)
            layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
            if batch_norm:
                layers.append(batch_norm)
            if non_linearity:
                layers.append(non_linearity)
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Computes one forward pass through the projection head.

        Args:
            x:
                Input of shape bsz x num_ftrs.
        """
        return self.layers(x)
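

# Illustrative shape check (assumption, not part of the original file), using
# the two-block head from the class docstring: 256-d features in, 128-d
# projections out.
#
#   head = ProjectionHead([(256, 256, nn.BatchNorm1d(256), nn.ReLU()),
#                          (256, 128, None, None)])
#   head(torch.randn(4, 256)).shape   # torch.Size([4, 128])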


class MoCoProjectionHead(ProjectionHead):
    """Projection head used for MoCo.

    "(...) we replace the fc head in MoCo with a 2-layer MLP head (hidden
    layer 2048-d, with ReLU)" [0]

    [0]: MoCo v2, 2020, https://arxiv.org/abs/2003.04297
    """

    def __init__(self,
                 input_dim: int = 2048,
                 hidden_dim: int = 2048,
                 output_dim: int = 128):
        super(MoCoProjectionHead, self).__init__([
            (input_dim, hidden_dim, None, nn.ReLU()),
            (hidden_dim, output_dim, None, None),
        ])
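

# Illustrative check (assumption, not part of the original file): with the
# defaults this head is the paper's 2048 -> 2048 -> 128 MLP with a ReLU
# between the two linear layers.
#
#   head = MoCoProjectionHead()
#   head(torch.randn(4, 2048)).shape   # torch.Size([4, 128])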


class MoCo(nn.Module):
    def __init__(self, featx_arch, pretrained=None, backbone=None,
                 outfeatx_size=2048):
        super().__init__()
        if backbone is not None:
            # Caller-supplied backbone; outfeatx_size must match its output
            # feature width.
            self.backbone = backbone
        else:
            self.backbone, outfeatx_size = loadResnetBackbone(
                arch=featx_arch, torch_pretrain=pretrained)
        self.projection_head = MoCoProjectionHead(outfeatx_size, 2048, 128)

        # Momentum (key) encoder: a frozen copy of the query encoder that is
        # updated by exponential moving average rather than by backprop.
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)
        deactivate_requires_grad(self.backbone_momentum.parameters())
        deactivate_requires_grad(self.projection_head_momentum.parameters())

    def forward(self, x):
        query = self.backbone(x).flatten(start_dim=1)
        query = self.projection_head(query)
        return query

    def forward_momentum(self, x):
        key = self.backbone_momentum(x).flatten(start_dim=1)
        key = self.projection_head_momentum(key).detach()
        return key
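

# ---------------------------------------------------------------------------
# Usage sketch (assumption: not part of the original file). MoCo keeps the
# momentum ("key") encoder in step with the query encoder via an exponential
# moving average instead of backprop. The EMA rule below follows the MoCo
# paper (m close to 1, e.g. 0.999); the helper name, the "resnet50" arch
# string passed to loadResnetBackbone, and the demo shapes are illustrative
# guesses, not part of this repo's API.

@torch.no_grad()
def update_momentum(params, params_momentum, m: float = 0.999):
    """EMA update: p_momentum <- m * p_momentum + (1 - m) * p."""
    for p, p_m in zip(params, params_momentum):
        p_m.data = m * p_m.data + (1.0 - m) * p.data


if __name__ == "__main__":
    model = MoCo(featx_arch="resnet50", pretrained=None)
    x_query = torch.randn(4, 3, 224, 224)   # one augmented view of a batch
    x_key = torch.randn(4, 3, 224, 224)     # a second augmented view

    query = model(x_query)                  # gradients flow through this branch
    key = model.forward_momentum(x_key)     # detached; no gradients

    # After each optimizer step, refresh the momentum encoder.
    update_momentum(model.backbone.parameters(),
                    model.backbone_momentum.parameters())
    update_momentum(model.projection_head.parameters(),
                    model.projection_head_momentum.parameters())
    print(query.shape, key.shape)           # torch.Size([4, 128]) twice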