--- a
+++ b/algorithms/loss/ssl_losses.py
@@ -0,0 +1,323 @@
+""" Negative Cosine Similarity Loss Function """
+
+# Copyright (c) 2020. Lightly AG and its affiliates.
+# All Rights Reserved
+
+from typing import Optional
+
+import torch
+from torch.nn.functional import cosine_similarity
+
+
+class NegativeCosineSimilarity(torch.nn.Module):
+    """Implementation of the Negative Cosine Simililarity used in the SimSiam[0] paper.
+    [0] SimSiam, 2020, https://arxiv.org/abs/2011.10566
+    Examples:
+        >>> # initialize loss function
+        >>> loss_fn = NegativeCosineSimilarity()
+        >>>
+        >>> # generate two representation tensors
+        >>> # with batch size 10 and dimension 128
+        >>> x0 = torch.randn(10, 128)
+        >>> x1 = torch.randn(10, 128)
+        >>>
+        >>> # calculate loss
+        >>> loss = loss_fn(x0, x1)
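+        >>>
+        >>> # in SimSiam the loss is typically applied symmetrically to the
+        >>> # predictor outputs p0, p1 and the detached projections z0, z1
+        >>> # (names here are illustrative):
+        >>> # loss = 0.5 * (loss_fn(p0, z1.detach()) + loss_fn(p1, z0.detach()))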
+    """
+
+    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
+        """Same parameters as in torch.nn.CosineSimilarity
+        Args:
+            dim (int, optional):
+                Dimension where cosine similarity is computed. Default: 1
+            eps (float, optional):
+                Small value to avoid division by zero. Default: 1e-8
+        """
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+
+    def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
+        return -cosine_similarity(x0, x1, self.dim, self.eps).mean()
+
+""" Memory Bank Wrapper """
+
+
+class MemoryBankModule(torch.nn.Module):
+    """Memory bank implementation
+
+    This is a parent class to all loss functions implemented by the lightly
+    Python package. This way, any loss can be used with a memory bank if
+    desired.
+
+    Attributes:
+        size:
+            Number of keys the memory bank can store. If set to 0,
+            memory bank is not used.
+
+    Examples:
+        >>> class MyLossFunction(MemoryBankModule):
+        >>>
+        >>>     def __init__(self, memory_bank_size: int = 2 ** 16):
+        >>>         super(MyLossFunction, self).__init__(memory_bank_size)
+        >>>
+        >>>     def forward(self, output: torch.Tensor,
+        >>>                 labels: torch.Tensor = None):
+        >>>
+        >>>         output, negatives = super(
+        >>>             MyLossFunction, self).forward(output)
+        >>>
+        >>>         if negatives is not None:
+        >>>             # evaluate loss with negative samples
+        >>>             ...
+        >>>         else:
+        >>>             # evaluate loss without negative samples
+        >>>             ...
+
+    """
+
+    def __init__(self, size: int = 2**16):
+        super(MemoryBankModule, self).__init__()
+
+        if size < 0:
+            msg = f"Illegal memory bank size {size}, must be non-negative."
+            raise ValueError(msg)
+
+        self.size = size
+        self.register_buffer(
+            "bank", tensor=torch.empty(0, dtype=torch.float), persistent=False
+        )
+        self.register_buffer(
+            "bank_ptr", tensor=torch.empty(0, dtype=torch.long), persistent=False
+        )
+
+    @torch.no_grad()
+    def _init_memory_bank(self, dim: int):
+        """Initialize the memory bank if it's empty
+
+        Args:
+            dim:
+                The dimension of the keys that are stored in the bank.
+
+        """
+        # create memory bank
+        # we could use register buffers like in the moco repo
+        # https://github.com/facebookresearch/moco but we don't
+        # want to pollute our checkpoints
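+        # the bank has shape (dim, size); after normalization along dim=0,
+        # each column holds one unit-length key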
+        self.bank = torch.randn(dim, self.size).type_as(self.bank)
+        self.bank = torch.nn.functional.normalize(self.bank, dim=0)
+        self.bank_ptr = torch.zeros(1).type_as(self.bank_ptr)
+
+    @torch.no_grad()
+    def _dequeue_and_enqueue(self, batch: torch.Tensor):
+        """Dequeue the oldest batch and add the latest one
+
+        Args:
+            batch:
+                The latest batch of keys to add to the memory bank.
+
+        """
+        batch_size = batch.shape[0]
+        ptr = int(self.bank_ptr)
+
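+        # if the batch does not fully fit before the end of the bank, only the
+        # first (self.size - ptr) samples are stored and the pointer wraps back
+        # to zero; the remaining samples of this batch are dropped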
+        if ptr + batch_size >= self.size:
+            self.bank[:, ptr:] = batch[: self.size - ptr].T.detach()
+            self.bank_ptr[0] = 0
+        else:
+            self.bank[:, ptr : ptr + batch_size] = batch.T.detach()
+            self.bank_ptr[0] = ptr + batch_size
+
+    def forward(
+        self,
+        output: torch.Tensor,
+        labels: Optional[torch.Tensor] = None,
+        update: bool = False,
+    ):
+        """Query memory bank for additional negative samples
+
+        Args:
+            output:
+                The output of the model.
+            labels:
+                Should always be None, will be ignored.
+            update:
+                If True, the memory bank is updated with the current batch.
+
+        Returns:
+            The output if the memory bank is of size 0, otherwise the output
+            and the entries from the memory bank.
+
+        """
+
+        # no memory bank, return the output
+        if self.size == 0:
+            return output, None
+
+        _, dim = output.shape
+
+        # initialize the memory bank if it is not already done
+        if self.bank.nelement() == 0:
+            self._init_memory_bank(dim)
+
+        # query and update memory bank
+        bank = self.bank.clone().detach()
+
+        # only update the memory bank if requested by the caller
+        # (typically only during training, when gradients will be computed)
+        if update:
+            self._dequeue_and_enqueue(output)
+
+        return output, bank
+
+
+""" Contrastive Loss Functions """
+
+
+class NTXentLoss(MemoryBankModule):
+    """Implementation of the Contrastive Cross Entropy Loss.
+
+    This implementation follows the SimCLR[0] paper. If you enable the memory
+    bank by setting `memory_bank_size` > 0, the loss behaves like the one
+    described in the MoCo[1] paper.
+
+    - [0] SimCLR, 2020, https://arxiv.org/abs/2002.05709
+    - [1] MoCo, 2020, https://arxiv.org/abs/1911.05722
+
+    Attributes:
+        temperature:
+            Scale logits by the inverse of the temperature.
+        memory_bank_size:
+            Number of negative samples to store in the memory bank.
+            Use 0 for SimCLR. For MoCo we typically use numbers like 4096 or 65536.
+
+    Raises:
+        ValueError: If abs(temperature) < 1e-8 to prevent divide by zero.
+
+    Examples:
+
+        >>> # initialize loss function without memory bank
+        >>> loss_fn = NTXentLoss(memory_bank_size=0)
+        >>>
+        >>> # generate two random transforms of images
+        >>> t0 = transforms(images)
+        >>> t1 = transforms(images)
+        >>>
+        >>> # feed through SimCLR or MoCo model
+        >>> out0 = model(t0)
+        >>> out1 = model(t1)
+        >>>
+        >>> # calculate loss
+        >>> loss = loss_fn(out0, out1)
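+        >>>
+        >>> # or, with a memory bank of negatives (MoCo-style)
+        >>> loss_fn = NTXentLoss(memory_bank_size=4096)
+        >>> loss = loss_fn(out0, out1)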
+
+    """
+
+    def __init__(
+        self,
+        temperature: float = 0.5,
+        memory_bank_size: int = 4096,
+    ):
+        super(NTXentLoss, self).__init__(size=memory_bank_size)
+        self.temperature = temperature
+        self.cross_entropy = torch.nn.CrossEntropyLoss(reduction="mean")
+        self.eps = 1e-8
+
+        if abs(self.temperature) < self.eps:
+            raise ValueError(
+                "Illegal temperature: abs({}) < 1e-8".format(self.temperature)
+            )
+
+    def forward(self, out0: torch.Tensor, out1: torch.Tensor):
+        """Forward pass through Contrastive Cross-Entropy Loss.
+
+        If used with a memory bank, the samples from the memory bank are used
+        as negative examples. Otherwise, within-batch samples are used as
+        negative samples.
+
+        Args:
+            out0:
+                Output projections of the first set of transformed images.
+                Shape: (batch_size, embedding_size)
+            out1:
+                Output projections of the second set of transformed images.
+                Shape: (batch_size, embedding_size)
+
+        Returns:
+            Contrastive Cross Entropy Loss value.
+
+        """
+
+        device = out0.device
+        batch_size, _ = out0.shape
+
+        # normalize the output to length 1
+        out0 = torch.nn.functional.normalize(out0, dim=1)
+        out1 = torch.nn.functional.normalize(out1, dim=1)
+
+        # ask memory bank for negative samples and extend it with out1 if
+        # out1 requires a gradient, otherwise keep the same vectors in the
+        # memory bank (this allows for keeping the memory bank constant e.g.
+        # for evaluating the loss on the test set)
+        # out1: shape: (batch_size, embedding_size)
+        # negatives: shape: (embedding_size, memory_bank_size)
+        out1, negatives = super(NTXentLoss, self).forward(
+            out1, update=out0.requires_grad
+        )
+
+        # We use the cosine similarity, which is a dot product (einsum) here,
+        # as all vectors are already normalized to unit length.
+        # Notation in einsum: n = batch_size, c = embedding_size and k = memory_bank_size.
+
+        if negatives is not None:
+            # use negatives from memory bank
+            negatives = negatives.to(device)
+
+            # sim_pos is of shape (batch_size, 1) and sim_pos[i] denotes the similarity
+            # of the i-th sample in the batch to its positive pair
+            sim_pos = torch.einsum("nc,nc->n", out0, out1).unsqueeze(-1)
+
+            # sim_neg is of shape (batch_size, memory_bank_size) and sim_neg[i,j] denotes the similarity
+            # of the i-th sample to the j-th negative sample
+            sim_neg = torch.einsum("nc,ck->nk", out0, negatives)
+
+            # set the labels to the first "class", i.e. sim_pos,
+            # so that it is maximized in relation to sim_neg
+            logits = torch.cat([sim_pos, sim_neg], dim=1) / self.temperature
+            labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long)
+
+        else:
+            # use the other samples from the batch as negatives
+            # and create a diagonal mask that only selects similarities between
+            # views of the same image
+
+            # single process: the "large" tensors are simply the local batch
+            out0_large = out0
+            out1_large = out1
+            diag_mask = torch.eye(batch_size, device=out0.device, dtype=torch.bool)
+
+            # calculate similarities
+            # here n = m = batch_size (single process, no cross-gpu gathering)
+            # the resulting matrices have shape (n, m)
+            logits_00 = torch.einsum("nc,mc->nm", out0, out0_large) / self.temperature
+            logits_01 = torch.einsum("nc,mc->nm", out0, out1_large) / self.temperature
+            logits_10 = torch.einsum("nc,mc->nm", out1, out0_large) / self.temperature
+            logits_11 = torch.einsum("nc,mc->nm", out1, out1_large) / self.temperature
+
+            # remove self-similarities (each view compared with itself)
+            logits_00 = logits_00[~diag_mask].view(batch_size, -1)
+            logits_11 = logits_11[~diag_mask].view(batch_size, -1)
+
+            # concatenate logits
+            # the logits tensor in the end has shape (2*n, 2*m-1)
+            logits_0100 = torch.cat([logits_01, logits_00], dim=1)
+            logits_1011 = torch.cat([logits_10, logits_11], dim=1)
+            logits = torch.cat([logits_0100, logits_1011], dim=0)
+
+            # create labels
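+            # for row i in the first half (out0), the positive similarity sits
+            # in column i of logits_01, i.e. at index i of the concatenated
+            # row; the same holds for the second half (out1) via logits_10,
+            # so the targets are simply 0..batch_size-1 repeated twice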
+            labels = torch.arange(batch_size, device=device, dtype=torch.long)
+            labels = labels.repeat(2)
+
+        loss = self.cross_entropy(logits, labels)
+
+        return loss
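+
+
+if __name__ == "__main__":
+    # Minimal smoke-test sketch (illustrative only): checks that the losses
+    # defined above return finite scalars on random embeddings. Shapes and
+    # hyperparameters are arbitrary example values, not recommendations.
+    torch.manual_seed(0)
+    z0 = torch.randn(8, 32)
+    z1 = torch.randn(8, 32)
+
+    # SimSiam-style negative cosine similarity
+    print("negative cosine similarity:", NegativeCosineSimilarity()(z0, z1).item())
+
+    # NT-Xent with in-batch negatives (SimCLR) and with a memory bank (MoCo)
+    print("NT-Xent, in-batch negatives:", NTXentLoss(memory_bank_size=0)(z0, z1).item())
+    print("NT-Xent, memory bank:", NTXentLoss(memory_bank_size=4096)(z0, z1).item())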