""" Negative Cosine Similarity Loss Function """
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
import torch
from torch.nn.functional import cosine_similarity
class NegativeCosineSimilarity(torch.nn.Module):
"""Implementation of the Negative Cosine Simililarity used in the SimSiam[0] paper.
[0] SimSiam, 2020, https://arxiv.org/abs/2011.10566
Examples:
>>> # initialize loss function
>>> loss_fn = NegativeCosineSimilarity()
>>>
>>> # generate two representation tensors
>>> # with batch size 10 and dimension 128
>>> x0 = torch.randn(10, 128)
>>> x1 = torch.randn(10, 128)
>>>
>>> # calculate loss
>>> loss = loss_fn(x0, x1)
"""
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
"""Same parameters as in torch.nn.CosineSimilarity
Args:
dim (int, optional):
Dimension where cosine similarity is computed. Default: 1
eps (float, optional):
Small value to avoid division by zero. Default: 1e-8
"""
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
return -cosine_similarity(x0, x1, self.dim, self.eps).mean()
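
# Illustrative sketch (not part of the original module): how the negative
# cosine similarity is typically applied in SimSiam, where the loss is
# symmetrized and the target branch is detached (stop-gradient). The function
# and argument names (p = prediction head outputs, z = projection head
# outputs) are chosen here purely for illustration.
def simsiam_loss_sketch(
    p0: torch.Tensor, z0: torch.Tensor, p1: torch.Tensor, z1: torch.Tensor
) -> torch.Tensor:
    """Symmetrized SimSiam objective: 0.5 * (D(p0, z1) + D(p1, z0))."""
    loss_fn = NegativeCosineSimilarity()
    # stop-gradient on the targets, as described in the SimSiam paper
    return 0.5 * (loss_fn(p0, z1.detach()) + loss_fn(p1, z0.detach()))
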
""" Memory Bank Wrapper """
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
from typing import Optional
class MemoryBankModule(torch.nn.Module):
"""Memory bank implementation
This is a parent class to all loss functions implemented by the lightly
Python package. This way, any loss can be used with a memory bank if
desired.
Attributes:
size:
Number of keys the memory bank can store. If set to 0,
memory bank is not used.
Examples:
>>> class MyLossFunction(MemoryBankModule):
>>>
>>> def __init__(self, memory_bank_size: int = 2 ** 16):
>>> super(MyLossFunction, self).__init__(memory_bank_size)
>>>
>>> def forward(self, output: torch.Tensor,
>>> labels: torch.Tensor = None):
>>>
>>> output, negatives = super(
>>> MyLossFunction, self).forward(output)
>>>
>>> if negatives is not None:
>>> # evaluate loss with negative samples
>>> else:
>>> # evaluate loss without negative samples
"""
def __init__(self, size: int = 2**16):
super(MemoryBankModule, self).__init__()
if size < 0:
msg = f"Illegal memory bank size {size}, must be non-negative."
raise ValueError(msg)
self.size = size
self.register_buffer(
"bank", tensor=torch.empty(0, dtype=torch.float), persistent=False
)
self.register_buffer(
"bank_ptr", tensor=torch.empty(0, dtype=torch.long), persistent=False
)
@torch.no_grad()
def _init_memory_bank(self, dim: int):
"""Initialize the memory bank if it's empty
Args:
dim:
                The dimension of the keys that are stored in the bank.
"""
# create memory bank
# we could use register buffers like in the moco repo
# https://github.com/facebookresearch/moco but we don't
# want to pollute our checkpoints
self.bank = torch.randn(dim, self.size).type_as(self.bank)
self.bank = torch.nn.functional.normalize(self.bank, dim=0)
self.bank_ptr = torch.zeros(1).type_as(self.bank_ptr)
@torch.no_grad()
def _dequeue_and_enqueue(self, batch: torch.Tensor):
"""Dequeue the oldest batch and add the latest one
Args:
batch:
The latest batch of keys to add to the memory bank.
"""
batch_size = batch.shape[0]
ptr = int(self.bank_ptr)
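        # the bank is used as a ring buffer: if the incoming batch does not
        # fit into the remaining columns, only the first (size - ptr) samples
        # are written and the pointer wraps back to the start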
if ptr + batch_size >= self.size:
self.bank[:, ptr:] = batch[: self.size - ptr].T.detach()
self.bank_ptr[0] = 0
else:
self.bank[:, ptr : ptr + batch_size] = batch.T.detach()
self.bank_ptr[0] = ptr + batch_size
    def forward(
        self,
        output: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        update: bool = False,
    ):
"""Query memory bank for additional negative samples
Args:
output:
The output of the model.
labels:
                Should always be None, will be ignored.
            update:
                If True, the memory bank is updated with the current output.
Returns:
            A tuple (output, negatives), where negatives is None if the memory
            bank is of size 0, and otherwise contains the entries from the
            memory bank.
"""
# no memory bank, return the output
if self.size == 0:
return output, None
_, dim = output.shape
# initialize the memory bank if it is not already done
if self.bank.nelement() == 0:
self._init_memory_bank(dim)
# query and update memory bank
bank = self.bank.clone().detach()
# only update memory bank if we later do backward pass (gradient)
if update:
self._dequeue_and_enqueue(output)
return output, bank
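
# Illustrative sketch (not part of the original module): a minimal loss that
# subclasses MemoryBankModule and uses the stored entries as negatives. The
# class name and the toy objective are made up purely to show the expected
# call pattern; a real loss would replace the body of forward().
class _MemoryBankUsageSketch(MemoryBankModule):
    def __init__(self, memory_bank_size: int = 4096):
        super().__init__(size=memory_bank_size)

    def forward(self, output: torch.Tensor) -> torch.Tensor:
        output = torch.nn.functional.normalize(output, dim=1)
        # query the bank; negatives has shape (embedding_size, memory_bank_size)
        # and the bank is only updated when gradients will flow
        output, negatives = super().forward(output, update=output.requires_grad)
        if negatives is None:
            # memory bank disabled (size=0): nothing to compare against
            return output.new_zeros(())
        # toy objective: mean similarity between the batch and the negatives
        return torch.einsum("nc,ck->nk", output, negatives.to(output.device)).mean()
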
""" Contrastive Loss Functions """
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
class NTXentLoss(MemoryBankModule):
"""Implementation of the Contrastive Cross Entropy Loss.
This implementation follows the SimCLR[0] paper. If you enable the memory
bank by setting the `memory_bank_size` value > 0 the loss behaves like
the one described in the MoCo[1] paper.
- [0] SimCLR, 2020, https://arxiv.org/abs/2002.05709
- [1] MoCo, 2020, https://arxiv.org/abs/1911.05722
Attributes:
temperature:
Scale logits by the inverse of the temperature.
memory_bank_size:
Number of negative samples to store in the memory bank.
Use 0 for SimCLR. For MoCo we typically use numbers like 4096 or 65536.
Raises:
ValueError: If abs(temperature) < 1e-8 to prevent divide by zero.
Examples:
>>> # initialize loss function without memory bank
>>> loss_fn = NTXentLoss(memory_bank_size=0)
>>>
>>> # generate two random transforms of images
>>> t0 = transforms(images)
>>> t1 = transforms(images)
>>>
>>> # feed through SimCLR or MoCo model
>>> batch = torch.cat((t0, t1), dim=0)
>>> output = model(batch)
>>>
>>> # calculate loss
>>> loss = loss_fn(output)
"""
def __init__(
self,
temperature: float = 0.5,
memory_bank_size: int = 4096,
):
super(NTXentLoss, self).__init__(size=memory_bank_size)
self.temperature = temperature
self.cross_entropy = torch.nn.CrossEntropyLoss(reduction="mean")
self.eps = 1e-8
if abs(self.temperature) < self.eps:
raise ValueError(
"Illegal temperature: abs({}) < 1e-8".format(self.temperature)
)
def forward(self, out0: torch.Tensor, out1: torch.Tensor):
"""Forward pass through Contrastive Cross-Entropy Loss.
If used with a memory bank, the samples from the memory bank are used
as negative examples. Otherwise, within-batch samples are used as
negative samples.
Args:
out0:
Output projections of the first set of transformed images.
Shape: (batch_size, embedding_size)
out1:
Output projections of the second set of transformed images.
Shape: (batch_size, embedding_size)
Returns:
Contrastive Cross Entropy Loss value.
"""
device = out0.device
batch_size, _ = out0.shape
# normalize the output to length 1
out0 = torch.nn.functional.normalize(out0, dim=1)
out1 = torch.nn.functional.normalize(out1, dim=1)
        # ask memory bank for negative samples and extend it with out1 if
        # out0 requires a gradient, otherwise keep the same vectors in the
        # memory bank (this allows for keeping the memory bank constant e.g.
        # for evaluating the loss on the test set)
# out1: shape: (batch_size, embedding_size)
# negatives: shape: (embedding_size, memory_bank_size)
out1, negatives = super(NTXentLoss, self).forward(
out1, update=out0.requires_grad
)
# We use the cosine similarity, which is a dot product (einsum) here,
# as all vectors are already normalized to unit length.
# Notation in einsum: n = batch_size, c = embedding_size and k = memory_bank_size.
if negatives is not None:
# use negatives from memory bank
negatives = negatives.to(device)
# sim_pos is of shape (batch_size, 1) and sim_pos[i] denotes the similarity
# of the i-th sample in the batch to its positive pair
sim_pos = torch.einsum("nc,nc->n", out0, out1).unsqueeze(-1)
# sim_neg is of shape (batch_size, memory_bank_size) and sim_neg[i,j] denotes the similarity
# of the i-th sample to the j-th negative sample
sim_neg = torch.einsum("nc,ck->nk", out0, negatives)
# set the labels to the first "class", i.e. sim_pos,
# so that it is maximized in relation to sim_neg
logits = torch.cat([sim_pos, sim_neg], dim=1) / self.temperature
labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long)
else:
            # use the other samples from the batch as negatives
# and create diagonal mask that only selects similarities between
# views of the same image
# single process
out0_large = out0
out1_large = out1
diag_mask = torch.eye(batch_size, device=out0.device, dtype=torch.bool)
            # calculate similarities
            # here both n and m equal the batch size (single process)
            # the resulting tensors have shape (n, m)
logits_00 = torch.einsum("nc,mc->nm", out0, out0_large) / self.temperature
logits_01 = torch.einsum("nc,mc->nm", out0, out1_large) / self.temperature
logits_10 = torch.einsum("nc,mc->nm", out1, out0_large) / self.temperature
logits_11 = torch.einsum("nc,mc->nm", out1, out1_large) / self.temperature
            # remove similarities between the same views of the same image
logits_00 = logits_00[~diag_mask].view(batch_size, -1)
logits_11 = logits_11[~diag_mask].view(batch_size, -1)
# concatenate logits
# the logits tensor in the end has shape (2*n, 2*m-1)
logits_0100 = torch.cat([logits_01, logits_00], dim=1)
logits_1011 = torch.cat([logits_10, logits_11], dim=1)
logits = torch.cat([logits_0100, logits_1011], dim=0)
# create labels
labels = torch.arange(batch_size, device=device, dtype=torch.long)
labels = labels.repeat(2)
loss = self.cross_entropy(logits, labels)
return loss
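
# Illustrative sketch (not part of the original module): exercising NTXentLoss
# on random embeddings, once without and once with a memory bank. Batch size,
# embedding dimension and temperatures below are arbitrary and only show the
# call pattern. For reference, the per-sample objective is
#   loss_i = -log( exp(sim(z_i, z_j) / t) / sum_k exp(sim(z_i, z_k) / t) )
# where z_j is the positive of z_i and the sum in the denominator runs over
# the positive and all negatives.
if __name__ == "__main__":
    out0 = torch.randn(8, 128, requires_grad=True)
    out1 = torch.randn(8, 128, requires_grad=True)

    # SimCLR-style: negatives are the other samples in the batch
    simclr_loss = NTXentLoss(temperature=0.5, memory_bank_size=0)
    print("SimCLR-style loss:", simclr_loss(out0, out1).item())

    # MoCo-style: negatives come from the memory bank
    moco_loss = NTXentLoss(temperature=0.1, memory_bank_size=4096)
    print("MoCo-style loss:", moco_loss(out0, out1).item())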