--- a/algorithms/simclr.py
+++ b/algorithms/simclr.py
@@ -0,0 +1,169 @@
+import math
+import torch
+from torch import nn, optim
+
+from algorithms.arch.resnet import loadResnetBackbone
+import utilities.runUtils as rutl
+
+def device_as(t1, t2):
+    """
+    Moves t1 to the device of t2
+    """
+    return t1.to(t2.device)
+
+##==================== Model ===============================================
+
+class ContrastiveLoss(nn.Module):
+    """
+    Vanilla contrastive loss, also called InfoNCE loss, as in the SimCLR paper
+    """
+    def __init__(self, batch_size, temperature=0.5):
+        super().__init__()
+        self.batch_size = batch_size
+        self.temperature = temperature
+        # mask out the self-similarity terms of the 2N x 2N similarity matrix;
+        # assumes a fixed batch size, so drop the last incomplete batch
+        self.mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool)).float()
+
+    def calc_similarity_batch(self, a, b):
+        representations = torch.cat([a, b], dim=0)
+        return nn.functional.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
+
+    def forward(self, proj_1, proj_2):
+        """
+        proj_1 and proj_2 are batched embeddings [batch, embedding_dim]
+        where corresponding indices are pairs
+        z_i, z_j in the SimCLR paper
+        """
+        batch_size = proj_1.shape[0]
+        z_i = nn.functional.normalize(proj_1, p=2, dim=1)
+        z_j = nn.functional.normalize(proj_2, p=2, dim=1)
+
+        similarity_matrix = self.calc_similarity_batch(z_i, z_j)
+
+        # positive pairs sit on the +/- batch_size off-diagonals
+        sim_ij = torch.diag(similarity_matrix, batch_size)
+        sim_ji = torch.diag(similarity_matrix, -batch_size)
+
+        positives = torch.cat([sim_ij, sim_ji], dim=0)
+
+        numerator = torch.exp(positives / self.temperature)
+
+        denominator = device_as(self.mask, similarity_matrix) * torch.exp(similarity_matrix / self.temperature)
+
+        all_losses = -torch.log(numerator / torch.sum(denominator, dim=1))
+        loss = torch.sum(all_losses) / (2 * self.batch_size)
+        return loss
+
+class SimCLR(nn.Module):
+    def __init__(self, featx_arch, projector_sizes,
+                 batch_size, temp, pretrained=None):
+        super().__init__()
+        rutl.START_SEED()
+
+        mlp_dim = projector_sizes[0]
+        embedding_size = projector_sizes[1]
+
+        self.batch_size = batch_size
+        self.temp = temp
+        self.backbone, outfeatx_size = loadResnetBackbone(arch=featx_arch,
+                                                          torch_pretrain=pretrained)
+        # add mlp projection head
+        self.projector = nn.Sequential(
+            nn.Linear(in_features=outfeatx_size, out_features=mlp_dim),
+            nn.BatchNorm1d(mlp_dim),
+            nn.ReLU(),
+            nn.Linear(in_features=mlp_dim, out_features=embedding_size),
+            # nn.BatchNorm1d(embedding_size),
+        )
+        # build the loss once instead of constructing it on every forward pass
+        self.criterion = ContrastiveLoss(self.batch_size, self.temp)
+
+    def forward(self, y1, y2):
+        z1 = self.projector(self.backbone(y1))
+        z2 = self.projector(self.backbone(y2))
+        return self.criterion(z1, z2)
+
+##==================== OPTIMISER ===============================================
+
+class LARS(optim.Optimizer):
+    def __init__(
+        self,
+        params,
+        lr,
+        momentum=0,
+        dampening=0,
+        weight_decay=0,
+        nesterov=False,
+        trust_coefficient=0.001,
+        eps=1e-8,
+    ):
+
+        defaults = dict(
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            trust_coefficient=trust_coefficient,
+            eps=eps,
+        )
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+
+        for group in self.param_groups:
+            group.setdefault("nesterov", False)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # exclude scaling for params with 0 weight decay + for group in self.param_groups: + weight_decay = group["weight_decay"] + momentum = group["momentum"] + dampening = group["dampening"] + nesterov = group["nesterov"] + + for p in group["params"]: + if p.grad is None: + continue + + d_p = p.grad + p_norm = torch.norm(p.data) + g_norm = torch.norm(p.grad.data) + + # lars scaling + weight decay part + if weight_decay != 0: + if p_norm != 0 and g_norm != 0: + lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"]) + lars_lr *= group["trust_coefficient"] + + d_p = d_p.add(p, alpha=weight_decay) + d_p *= lars_lr + + # sgd part + if momentum != 0: + param_state = self.state[p] + if "momentum_buffer" not in param_state: + buf = param_state["momentum_buffer"] = torch.clone(d_p).detach() + else: + buf = param_state["momentum_buffer"] + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + if nesterov: + d_p = d_p.add(buf, alpha=momentum) + else: + d_p = buf + + p.add_(d_p, alpha=-group["lr"]) + + return loss \ No newline at end of file