[a18f15]: algorithms/byol.py


import torch
import copy
from torch import nn, optim
from typing import List, Optional, Tuple, Union

from algorithms.arch.resnet import loadResnetBackbone
import utilities.runUtils as rutl


def device_as(t1, t2):
    """Moves t1 to the device of t2."""
    return t1.to(t2.device)

##==================== Model ===============================================

class ProjectionHead(nn.Module):
    """Base class for all projection and prediction heads.

    Args:
        blocks:
            List of tuples, each denoting one block of the projection head MLP.
            Each tuple reads (in_features, out_features, batch_norm_layer,
            non_linearity_layer).

    Examples:
        >>> # the following projection head has two blocks
        >>> # the first block uses batch norm and a ReLU non-linearity
        >>> # the second block is a simple linear layer
        >>> projection_head = ProjectionHead([
        >>>     (256, 256, nn.BatchNorm1d(256), nn.ReLU()),
        >>>     (256, 128, None, None)
        >>> ])
    """

    def __init__(
        self, blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]]
    ):
        super(ProjectionHead, self).__init__()
        layers = []
        for input_dim, output_dim, batch_norm, non_linearity in blocks:
            use_bias = not bool(batch_norm)
            layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
            if batch_norm:
                layers.append(batch_norm)
            if non_linearity:
                layers.append(non_linearity)
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Computes one forward pass through the projection head.

        Args:
            x:
                Input of shape bsz x num_ftrs.
        """
        return self.layers(x)

class BYOLProjectionHead(ProjectionHead):
    """Projection head used for BYOL.

    "This MLP consists in a linear layer with output size 4096 followed by
    batch normalization, rectified linear units (ReLU), and a final
    linear layer with output dimension 256." [0]

    [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733
    """

    def __init__(
        self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256
    ):
        super(BYOLProjectionHead, self).__init__(
            [
                (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()),
                (hidden_dim, output_dim, None, None),
            ]
        )


class BYOLPredictionHead(ProjectionHead):
    """Prediction head used for BYOL.

    "This MLP consists in a linear layer with output size 4096 followed by
    batch normalization, rectified linear units (ReLU), and a final
    linear layer with output dimension 256." [0]

    [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733
    """

    def __init__(
        self, input_dim: int = 256, hidden_dim: int = 4096, output_dim: int = 256
    ):
        super(BYOLPredictionHead, self).__init__(
            [
                (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()),
                (hidden_dim, output_dim, None, None),
            ]
        )


def deactivate_requires_grad(model: nn.Module):
    """Deactivates the requires_grad flag for all parameters of a model.

    This has the same effect as permanently executing the model within a
    `torch.no_grad()` context. Use this method to disable gradient computation
    and therefore training for a model.

    Examples:
        >>> backbone = resnet18()
        >>> deactivate_requires_grad(backbone)
    """
    for param in model.parameters():
        param.requires_grad = False
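

# NOTE: a minimal sketch of the exponential-moving-average update that BYOL
# pairs with the frozen momentum networks below. The actual update most likely
# lives in this repo's training loop; the name `update_momentum` and this
# signature are assumptions, not part of the original file.
@torch.no_grad()
def update_momentum(model: nn.Module, model_ema: nn.Module, m: float = 0.996):
    """Updates `model_ema` in place as an EMA of `model`:

    ema_param <- m * ema_param + (1 - m) * param
    """
    for param, param_ema in zip(model.parameters(), model_ema.parameters()):
        param_ema.data = m * param_ema.data + (1.0 - m) * param.data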


class BYOL(nn.Module):
    def __init__(self, featx_arch, pretrained=None, backbone=None,
                 outfeatx_size=2048):
        super().__init__()
        if backbone is not None:
            # A custom backbone must come with its output feature size,
            # passed via `outfeatx_size` (defaults to 2048, e.g. ResNet-50);
            # otherwise the projection head input dimension would be undefined.
            self.backbone = backbone
        else:
            self.backbone, outfeatx_size = loadResnetBackbone(
                arch=featx_arch, torch_pretrain=pretrained)

        self.projection_head = BYOLProjectionHead(outfeatx_size, 1024, 256)
        self.prediction_head = BYOLPredictionHead(256, 1024, 256)

        # Momentum (target) networks are detached copies that receive no gradients.
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)
        deactivate_requires_grad(self.backbone_momentum)
        deactivate_requires_grad(self.projection_head_momentum)

    def forward(self, x):
        y = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(y)
        p = self.prediction_head(z)
        return p

    def forward_momentum(self, x):
        y = self.backbone_momentum(x).flatten(start_dim=1)
        z = self.projection_head_momentum(y)
        z = z.detach()
        return z
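

# ---------------------------------------------------------------------------
# NOTE: a minimal usage sketch, not part of the original file. It assumes the
# standard symmetric BYOL objective (negative cosine similarity between online
# predictions and detached momentum projections). `byol_loss_fn`, the arch
# string "resnet50", and the random input tensors are illustrative only.
if __name__ == "__main__":
    import torch.nn.functional as F

    def byol_loss_fn(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # 2 - 2 * cosine similarity, averaged over the batch
        return 2 - 2 * F.cosine_similarity(p, z.detach(), dim=-1).mean()

    model = BYOL(featx_arch="resnet50")
    x1 = torch.randn(4, 3, 224, 224)  # first augmented view of the batch
    x2 = torch.randn(4, 3, 224, 224)  # second augmented view of the batch

    p1, p2 = model(x1), model(x2)                                    # online branch
    z1, z2 = model.forward_momentum(x1), model.forward_momentum(x2)  # target branch
    loss = 0.5 * (byol_loss_fn(p1, z2) + byol_loss_fn(p2, z1))       # symmetric loss
    loss.backward()

    # EMA update of the momentum (target) networks after the optimiser step.
    update_momentum(model.backbone, model.backbone_momentum, m=0.996)
    update_momentum(model.projection_head, model.projection_head_momentum, m=0.996)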