import math
import torch
from torch import nn, optim
from algorithms.arch.resnet import loadResnetBackbone
import utilities.runUtils as rutl
def device_as(t1, t2):
"""
Moves t1 to the device of t2
"""
return t1.to(t2.device)
##==================== Model ===============================================
class ContrastiveLoss(nn.Module):
"""
    Vanilla contrastive loss, also called the InfoNCE loss, as used in the SimCLR paper
"""
def __init__(self, batch_size, temperature=0.5):
super().__init__()
self.batch_size = batch_size
self.temperature = temperature
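        # float mask with zeros on the main diagonal of the (2N x 2N)
        # similarity matrix, used to drop each sample's similarity with
        # itself from the denominator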
self.mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool)).float()
    def calc_similarity_batch(self, a, b):
        """
        Returns the (2N x 2N) cosine-similarity matrix of the concatenated
        batch [a; b], where N is the batch size.
        """
        representations = torch.cat([a, b], dim=0)
        return nn.functional.cosine_similarity(
            representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
def forward(self, proj_1, proj_2):
"""
proj_1 and proj_2 are batched embeddings [batch, embedding_dim]
where corresponding indices are pairs
z_i, z_j in the SimCLR paper
"""
batch_size = proj_1.shape[0]
z_i = nn.functional.normalize(proj_1, p=2, dim=1)
z_j = nn.functional.normalize(proj_2, p=2, dim=1)
similarity_matrix = self.calc_similarity_batch(z_i, z_j)
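        # matching views of the same image sit batch_size apart in the
        # concatenated batch, so positive pairs lie on the +batch_size and
        # -batch_size diagonals of the (2N x 2N) similarity matrix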
sim_ij = torch.diag(similarity_matrix, batch_size)
sim_ji = torch.diag(similarity_matrix, -batch_size)
positives = torch.cat([sim_ij, sim_ji], dim=0)
        numerator = torch.exp(positives / self.temperature)
        denominator = device_as(self.mask, similarity_matrix) * torch.exp(similarity_matrix / self.temperature)
        all_losses = -torch.log(numerator / torch.sum(denominator, dim=1))
loss = torch.sum(all_losses) / (2 * self.batch_size)
return loss
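
# Usage sketch (illustration only; the batch size and embedding dimension
# below are arbitrary assumptions, not values used elsewhere in this repo):
#
#   criterion = ContrastiveLoss(batch_size=8, temperature=0.5)
#   z1 = torch.randn(8, 128)   # projections of the first augmented view
#   z2 = torch.randn(8, 128)   # projections of the second augmented view
#   loss = criterion(z1, z2)   # scalar InfoNCE loss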
class SimCLR(nn.Module):
def __init__(self, featx_arch, projector_sizes,
batch_size, temp, pretrained=None):
super().__init__()
rutl.START_SEED()
mlp_dim = projector_sizes[0]
embedding_size = projector_sizes[1]
        self.batch_size = batch_size
        self.temp = temp
        # build the contrastive criterion once; re-creating it every forward
        # pass would rebuild the (2N x 2N) mask needlessly
        self.criterion = ContrastiveLoss(batch_size, temperature=temp)
self.backbone, outfeatx_size = loadResnetBackbone(arch=featx_arch,
torch_pretrain=pretrained)
# add mlp projection head
self.projector = nn.Sequential(
nn.Linear(in_features=outfeatx_size, out_features=mlp_dim),
nn.BatchNorm1d(mlp_dim),
nn.ReLU(),
nn.Linear(in_features=mlp_dim, out_features=embedding_size),
# nn.BatchNorm1d(embedding_size),
)
def forward(self, y1, y2):
z1 = self.projector(self.backbone(y1))
z2 = self.projector(self.backbone(y2))
        return self.criterion(z1, z2)
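
# Usage sketch (illustration only; "resnet50", the projector sizes, and the
# input resolution are assumptions, not values prescribed by this repo):
#
#   model = SimCLR(featx_arch="resnet50", projector_sizes=(2048, 128),
#                  batch_size=4, temp=0.5, pretrained=None)
#   view1 = torch.randn(4, 3, 224, 224)   # first augmented view of a batch
#   view2 = torch.randn(4, 3, 224, 224)   # second augmented view of the same batch
#   loss = model(view1, view2)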
##==================== OPTIMISER ===============================================
class LARS(optim.Optimizer):
def __init__(
self,
params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
trust_coefficient=0.001,
eps=1e-8,
):
defaults = dict(
lr=lr,
momentum=momentum,
dampening=dampening,
weight_decay=weight_decay,
nesterov=nesterov,
trust_coefficient=trust_coefficient,
eps=eps,
)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("nesterov", False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
# exclude scaling for params with 0 weight decay
for group in self.param_groups:
weight_decay = group["weight_decay"]
momentum = group["momentum"]
dampening = group["dampening"]
nesterov = group["nesterov"]
for p in group["params"]:
if p.grad is None:
continue
d_p = p.grad
p_norm = torch.norm(p.data)
g_norm = torch.norm(p.grad.data)
                # LARS scaling + weight decay:
                # local lr = trust_coefficient * ||w|| / (||g|| + weight_decay * ||w|| + eps)
if weight_decay != 0:
if p_norm != 0 and g_norm != 0:
lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"])
lars_lr *= group["trust_coefficient"]
d_p = d_p.add(p, alpha=weight_decay)
d_p *= lars_lr
# sgd part
if momentum != 0:
param_state = self.state[p]
if "momentum_buffer" not in param_state:
buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
else:
buf = param_state["momentum_buffer"]
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
if nesterov:
d_p = d_p.add(buf, alpha=momentum)
else:
d_p = buf
p.add_(d_p, alpha=-group["lr"])
return loss
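

# Minimal sanity check for the LARS optimizer (illustration only, not part of
# the training pipeline). A real SimCLR setup would typically also exclude
# biases and BatchNorm parameters from weight decay via separate param groups.
if __name__ == "__main__":
    torch.manual_seed(0)
    toy_model = nn.Linear(16, 1)
    optimizer = LARS(toy_model.parameters(), lr=0.1,
                     momentum=0.9, weight_decay=1e-6)
    data, target = torch.randn(32, 16), torch.randn(32, 1)
    for _ in range(5):
        optimizer.zero_grad()
        pred = toy_model(data)
        step_loss = nn.functional.mse_loss(pred, target)
        step_loss.backward()
        optimizer.step()
        print(f"toy LARS step loss: {step_loss.item():.4f}")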