import torch
import copy
from torch import nn, optim
from typing import List, Optional, Tuple, Union
from algorithms.arch.resnet import loadResnetBackbone
import utilities.runUtils as rutl


def device_as(t1, t2):
"""
Moves t1 to the device of t2
"""
    return t1.to(t2.device)


##==================== Model ===============================================

class ProjectionHead(nn.Module):
"""Base class for all projection and prediction heads.
Args:
blocks:
List of tuples, each denoting one block of the projection head MLP.
Each tuple reads (in_features, out_features, batch_norm_layer,
non_linearity_layer).
Examples:
>>> # the following projection head has two blocks
        >>> # the first block uses batch norm and a ReLU non-linearity
>>> # the second block is a simple linear layer
>>> projection_head = ProjectionHead([
>>> (256, 256, nn.BatchNorm1d(256), nn.ReLU()),
>>> (256, 128, None, None)
>>> ])
"""
def __init__(
self, blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]]
):
super(ProjectionHead, self).__init__()
layers = []
for input_dim, output_dim, batch_norm, non_linearity in blocks:
use_bias = not bool(batch_norm)
layers.append(nn.Linear(input_dim, output_dim, bias=use_bias))
if batch_norm:
layers.append(batch_norm)
if non_linearity:
layers.append(non_linearity)
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
"""Computes one forward pass through the projection head.
Args:
x:
Input of shape bsz x num_ftrs.
"""
        return self.layers(x)


class BYOLProjectionHead(ProjectionHead):
"""Projection head used for BYOL.
"This MLP consists in a linear layer with output size 4096 followed by
batch normalization, rectified linear units (ReLU), and a final
linear layer with output dimension 256." [0]
[0]: BYOL, 2020, https://arxiv.org/abs/2006.07733
"""
def __init__(
self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256
):
super(BYOLProjectionHead, self).__init__(
[
(input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()),
(hidden_dim, output_dim, None, None),
]
        )


class BYOLPredictionHead(ProjectionHead):
"""Prediction head used for BYOL.
"This MLP consists in a linear layer with output size 4096 followed by
batch normalization, rectified linear units (ReLU), and a final
linear layer with output dimension 256." [0]
[0]: BYOL, 2020, https://arxiv.org/abs/2006.07733
"""
def __init__(
self, input_dim: int = 256, hidden_dim: int = 4096, output_dim: int = 256
):
super(BYOLPredictionHead, self).__init__(
[
(input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()),
(hidden_dim, output_dim, None, None),
]
)
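
# Illustrative shape check for the two heads above (doctest-style sketch,
# assuming a ResNet-50 style feature width of 2048; dims match the paper):
# >>> proj = BYOLProjectionHead(2048, 4096, 256)
# >>> pred = BYOLPredictionHead(256, 4096, 256)
# >>> feats = torch.randn(8, 2048)
# >>> pred(proj(feats)).shape
# torch.Size([8, 256])
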
def deactivate_requires_grad(model: nn.Module):
"""Deactivates the requires_grad flag for all parameters of a model.
This has the same effect as permanently executing the model within a `torch.no_grad()`
context. Use this method to disable gradient computation and therefore
training for a model.
Examples:
>>> backbone = resnet18()
>>> deactivate_requires_grad(backbone)
"""
for param in model.parameters():
param.requires_grad = False
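
# The momentum (EMA) copies created in BYOL.__init__ below must be refreshed
# after each optimizer step. The helper below is a minimal sketch of that
# update; the name `update_momentum` and the decay value 0.996 (the base value
# used in the BYOL paper) are illustrative, not an API defined in this repo.
def update_momentum(model: nn.Module, model_ema: nn.Module, m: float = 0.996):
    """Updates the parameters of model_ema in place as an exponential moving
    average of the parameters of model."""
    for param_ema, param in zip(model_ema.parameters(), model.parameters()):
        param_ema.data = m * param_ema.data + (1.0 - m) * param.data
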
class BYOL(nn.Module):
    def __init__(self, featx_arch, pretrained=None, backbone=None,
                 outfeatx_size=2048):
        super().__init__()
        if backbone is not None:
            # A custom backbone must come with its output feature width;
            # `outfeatx_size` defaults to 2048 (the final width of ResNet-50).
            self.backbone = backbone
        else:
            self.backbone, outfeatx_size = loadResnetBackbone(
                arch=featx_arch, torch_pretrain=pretrained)

        # Online projector and predictor (hidden width 1024 here, smaller than
        # the 4096 quoted in the BYOL paper).
        self.projection_head = BYOLProjectionHead(outfeatx_size, 1024, 256)
        self.prediction_head = BYOLPredictionHead(256, 1024, 256)

        # Target network: frozen EMA copies of the backbone and projector.
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)
        deactivate_requires_grad(self.backbone_momentum)
        deactivate_requires_grad(self.projection_head_momentum)

    def forward(self, x):
        # Online network: backbone features -> projection -> prediction.
        y = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(y)
        p = self.prediction_head(z)
        return p

    def forward_momentum(self, x):
        # Target (momentum) network: backbone -> projection, detached so no
        # gradients flow into the EMA copies.
        y = self.backbone_momentum(x).flatten(start_dim=1)
        z = self.projection_head_momentum(y)
        z = z.detach()
        return z
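

# How the pieces above fit together in a training step (sketch only; the real
# training loop lives elsewhere in this repo). It assumes two augmented views
# x0/x1 of the same batch and the symmetric negative-cosine loss from the BYOL
# paper. The names `byol_loss` and `train_step` are illustrative.
def byol_loss(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Negative cosine similarity between online predictions p and (detached)
    target projections z, scaled as in the BYOL paper: 2 - 2 * cos(p, z)."""
    p = nn.functional.normalize(p, dim=-1)
    z = nn.functional.normalize(z, dim=-1)
    return 2 - 2 * (p * z).sum(dim=-1).mean()


def train_step(model: BYOL, optimizer: optim.Optimizer,
               x0: torch.Tensor, x1: torch.Tensor) -> float:
    # Refresh the frozen target network from the online network.
    update_momentum(model.backbone, model.backbone_momentum, m=0.996)
    update_momentum(model.projection_head, model.projection_head_momentum,
                    m=0.996)

    # Symmetrized loss: each view is predicted from the other view's target.
    p0, p1 = model(x0), model(x1)
    z0, z1 = model.forward_momentum(x0), model.forward_momentum(x1)
    loss = 0.5 * (byol_loss(p0, z1) + byol_loss(p1, z0))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()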