
algorithms/loss/ssl_losses.py
""" Negative Cosine Similarity Loss Function """

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

import torch
from torch.nn.functional import cosine_similarity


class NegativeCosineSimilarity(torch.nn.Module):
    """Implementation of the Negative Cosine Similarity used in the SimSiam[0] paper.

    [0] SimSiam, 2020, https://arxiv.org/abs/2011.10566

    Examples:
        >>> # initialize loss function
        >>> loss_fn = NegativeCosineSimilarity()
        >>>
        >>> # generate two representation tensors
        >>> # with batch size 10 and dimension 128
        >>> x0 = torch.randn(10, 128)
        >>> x1 = torch.randn(10, 128)
        >>>
        >>> # calculate loss
        >>> loss = loss_fn(x0, x1)
    """

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        """Same parameters as in torch.nn.CosineSimilarity

        Args:
            dim (int, optional):
                Dimension where cosine similarity is computed. Default: 1
            eps (float, optional):
                Small value to avoid division by zero. Default: 1e-8
        """
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
        return -cosine_similarity(x0, x1, self.dim, self.eps).mean()
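
# A minimal usage sketch (not part of the original module): SimSiam applies this
# loss symmetrically between predictions and stop-gradient projections. The tensors
# below are random placeholders standing in for the outputs of a projection head (z)
# and a prediction head (p).
def _example_simsiam_style_loss() -> torch.Tensor:
    loss_fn = NegativeCosineSimilarity()
    z0, z1 = torch.randn(10, 128), torch.randn(10, 128)
    p0, p1 = torch.randn(10, 128), torch.randn(10, 128)
    # symmetric loss with stop-gradient on the projections, as in SimSiam
    return 0.5 * (loss_fn(p0, z1.detach()) + loss_fn(p1, z0.detach()))
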
""" Memory Bank Wrapper """
42
43
# Copyright (c) 2020. Lightly AG and its affiliates.
44
# All Rights Reserved
45
46
import functools
47
48
class MemoryBankModule(torch.nn.Module):
49
    """Memory bank implementation
50
51
    This is a parent class to all loss functions implemented by the lightly
52
    Python package. This way, any loss can be used with a memory bank if
53
    desired.
54
55
    Attributes:
56
        size:
57
            Number of keys the memory bank can store. If set to 0,
58
            memory bank is not used.
59
60
    Examples:
61
        >>> class MyLossFunction(MemoryBankModule):
62
        >>>
63
        >>>     def __init__(self, memory_bank_size: int = 2 ** 16):
64
        >>>         super(MyLossFunction, self).__init__(memory_bank_size)
65
        >>>
66
        >>>     def forward(self, output: torch.Tensor,
67
        >>>                 labels: torch.Tensor = None):
68
        >>>
69
        >>>         output, negatives = super(
70
        >>>             MyLossFunction, self).forward(output)
71
        >>>
72
        >>>         if negatives is not None:
73
        >>>             # evaluate loss with negative samples
74
        >>>         else:
75
        >>>             # evaluate loss without negative samples
76
77
    """
78
79
    def __init__(self, size: int = 2**16):
80
        super(MemoryBankModule, self).__init__()
81
82
        if size < 0:
83
            msg = f"Illegal memory bank size {size}, must be non-negative."
84
            raise ValueError(msg)
85
86
        self.size = size
87
        self.register_buffer(
88
            "bank", tensor=torch.empty(0, dtype=torch.float), persistent=False
89
        )
90
        self.register_buffer(
91
            "bank_ptr", tensor=torch.empty(0, dtype=torch.long), persistent=False
92
        )
93
94
    @torch.no_grad()
    def _init_memory_bank(self, dim: int):
        """Initialize the memory bank if it's empty

        Args:
            dim:
                The dimension of the keys that are stored in the bank.

        """
        # Create the memory bank. The buffers were registered with
        # persistent=False in __init__ (unlike the register_buffer usage in
        # the MoCo repo, https://github.com/facebookresearch/moco), so the
        # bank does not pollute our checkpoints.
        self.bank = torch.randn(dim, self.size).type_as(self.bank)
        self.bank = torch.nn.functional.normalize(self.bank, dim=0)
        self.bank_ptr = torch.zeros(1).type_as(self.bank_ptr)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, batch: torch.Tensor):
        """Dequeue the oldest batch and add the latest one

        Args:
            batch:
                The latest batch of keys to add to the memory bank.

        """
        batch_size = batch.shape[0]
        ptr = int(self.bank_ptr)

        if ptr + batch_size >= self.size:
            # not enough space left: fill the bank up to the end and wrap the
            # pointer around (keys that do not fit are dropped)
            self.bank[:, ptr:] = batch[: self.size - ptr].T.detach()
            self.bank_ptr[0] = 0
        else:
            self.bank[:, ptr : ptr + batch_size] = batch.T.detach()
            self.bank_ptr[0] = ptr + batch_size

    def forward(
        self, output: torch.Tensor, labels: torch.Tensor = None, update: bool = False
    ):
        """Query memory bank for additional negative samples

        Args:
            output:
                The output of the model.
            labels:
                Should always be None, will be ignored.
            update:
                If True, the memory bank is updated with the current output.

        Returns:
            The output if the memory bank is of size 0, otherwise the output
            and the entries from the memory bank.

        """

        # no memory bank, return the output
        if self.size == 0:
            return output, None

        _, dim = output.shape

        # initialize the memory bank if it is not already done
        if self.bank.nelement() == 0:
            self._init_memory_bank(dim)

        # query and update memory bank
        bank = self.bank.clone().detach()

        # only update the memory bank if we will later do a backward pass (gradient)
        if update:
            self._dequeue_and_enqueue(output)

        return output, bank

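# A small sketch (not part of the original module) of how the memory bank behaves
# when queried directly: with size > 0 it returns a detached bank of shape
# (dim, size) and, when update=True, enqueues the current batch.
def _example_memory_bank_roundtrip() -> None:
    bank_module = MemoryBankModule(size=16)
    batch = torch.nn.functional.normalize(torch.randn(4, 32), dim=1)
    output, negatives = bank_module(batch, update=True)
    assert output.shape == (4, 32)
    assert negatives.shape == (32, 16)  # (embedding_size, memory_bank_size)
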
""" Contrastive Loss Functions """
168
169
# Copyright (c) 2020. Lightly AG and its affiliates.
170
# All Rights Reserved
171
172
# from torch import distributed as torch_dist
173
# from torch import nn
174
175
class NTXentLoss(MemoryBankModule):
176
    """Implementation of the Contrastive Cross Entropy Loss.
177
178
    This implementation follows the SimCLR[0] paper. If you enable the memory
179
    bank by setting the `memory_bank_size` value > 0 the loss behaves like
180
    the one described in the MoCo[1] paper.
181
182
    - [0] SimCLR, 2020, https://arxiv.org/abs/2002.05709
183
    - [1] MoCo, 2020, https://arxiv.org/abs/1911.05722
184
185
    Attributes:
186
        temperature:
187
            Scale logits by the inverse of the temperature.
188
        memory_bank_size:
189
            Number of negative samples to store in the memory bank.
190
            Use 0 for SimCLR. For MoCo we typically use numbers like 4096 or 65536.
191
        gather_distributed:
192
            If True then negatives from all gpus are gathered before the
193
            loss calculation. This flag has no effect if memory_bank_size > 0.
194
195
    Raises:
196
        ValueError: If abs(temperature) < 1e-8 to prevent divide by zero.
197
198
    Examples:
199
200
        >>> # initialize loss function without memory bank
201
        >>> loss_fn = NTXentLoss(memory_bank_size=0)
202
        >>>
203
        >>> # generate two random transforms of images
204
        >>> t0 = transforms(images)
205
        >>> t1 = transforms(images)
206
        >>>
207
        >>> # feed through SimCLR or MoCo model
208
        >>> batch = torch.cat((t0, t1), dim=0)
209
        >>> output = model(batch)
210
        >>>
211
        >>> # calculate loss
212
        >>> loss = loss_fn(output)
213
214
    """
215
216
    def __init__(
        self,
        temperature: float = 0.5,
        memory_bank_size: int = 4096,
    ):
        super(NTXentLoss, self).__init__(size=memory_bank_size)
        self.temperature = temperature
        self.cross_entropy = torch.nn.CrossEntropyLoss(reduction="mean")
        self.eps = 1e-8

        if abs(self.temperature) < self.eps:
            raise ValueError(
                "Illegal temperature: abs({}) < 1e-8".format(self.temperature)
            )

    def forward(self, out0: torch.Tensor, out1: torch.Tensor):
        """Forward pass through Contrastive Cross-Entropy Loss.

        If used with a memory bank, the samples from the memory bank are used
        as negative examples. Otherwise, within-batch samples are used as
        negative samples.

        Args:
            out0:
                Output projections of the first set of transformed images.
                Shape: (batch_size, embedding_size)
            out1:
                Output projections of the second set of transformed images.
                Shape: (batch_size, embedding_size)

        Returns:
            Contrastive Cross Entropy Loss value.

        """

        device = out0.device
        batch_size, _ = out0.shape

        # normalize the output to length 1
        out0 = torch.nn.functional.normalize(out0, dim=1)
        out1 = torch.nn.functional.normalize(out1, dim=1)

        # ask memory bank for negative samples and extend it with out1 if
        # out0 requires a gradient, otherwise keep the same vectors in the
        # memory bank (this allows for keeping the memory bank constant e.g.
        # for evaluating the loss on the test set)
        # out1: shape: (batch_size, embedding_size)
        # negatives: shape: (embedding_size, memory_bank_size)
        out1, negatives = super(NTXentLoss, self).forward(
            out1, update=out0.requires_grad
        )

        # We use the cosine similarity, which is a dot product (einsum) here,
        # as all vectors are already normalized to unit length.
        # Notation in einsum: n = batch_size, c = embedding_size and k = memory_bank_size.

        if negatives is not None:
            # use negatives from the memory bank
            negatives = negatives.to(device)

            # sim_pos is of shape (batch_size, 1) and sim_pos[i] denotes the similarity
            # of the i-th sample in the batch to its positive pair
            sim_pos = torch.einsum("nc,nc->n", out0, out1).unsqueeze(-1)

            # sim_neg is of shape (batch_size, memory_bank_size) and sim_neg[i, j] denotes
            # the similarity of the i-th sample to the j-th negative sample
            sim_neg = torch.einsum("nc,ck->nk", out0, negatives)

            # set the labels to the first "class", i.e. sim_pos,
            # so that it is maximized in relation to sim_neg
            logits = torch.cat([sim_pos, sim_neg], dim=1) / self.temperature
            labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long)

        else:
            # use the other samples from the batch as negatives
            # and create a diagonal mask that only selects similarities between
            # views of the same image

            # single process (no distributed gathering of negatives)
            out0_large = out0
            out1_large = out1
            diag_mask = torch.eye(batch_size, device=out0.device, dtype=torch.bool)

            # calculate similarities
            # here n = m = batch_size and the resulting matrices have shape (n, m)
            logits_00 = torch.einsum("nc,mc->nm", out0, out0_large) / self.temperature
            logits_01 = torch.einsum("nc,mc->nm", out0, out1_large) / self.temperature
            logits_10 = torch.einsum("nc,mc->nm", out1, out0_large) / self.temperature
            logits_11 = torch.einsum("nc,mc->nm", out1, out1_large) / self.temperature

            # remove similarities between the same views of the same image
            logits_00 = logits_00[~diag_mask].view(batch_size, -1)
            logits_11 = logits_11[~diag_mask].view(batch_size, -1)

            # concatenate logits
            # the final logits tensor has shape (2*n, 2*m - 1)
            logits_0100 = torch.cat([logits_01, logits_00], dim=1)
            logits_1011 = torch.cat([logits_10, logits_11], dim=1)
            logits = torch.cat([logits_0100, logits_1011], dim=0)

            # create labels: the positive pair for sample i sits at column i
            labels = torch.arange(batch_size, device=device, dtype=torch.long)
            labels = labels.repeat(2)

        loss = self.cross_entropy(logits, labels)

        return loss
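

# A minimal end-to-end sketch (not part of the original module): computing the
# NT-Xent loss once without a memory bank (SimCLR-style, in-batch negatives) and
# once with a memory bank (MoCo-style negatives). The random tensors stand in for
# projection-head outputs; the temperatures are illustrative values only.
def _example_ntxent_loss() -> None:
    out0 = torch.randn(8, 128, requires_grad=True)
    out1 = torch.randn(8, 128, requires_grad=True)

    # SimCLR-style: negatives come from the other samples in the batch
    simclr_loss = NTXentLoss(temperature=0.5, memory_bank_size=0)(out0, out1)

    # MoCo-style: negatives come from the memory bank, which is updated with out1
    moco_loss = NTXentLoss(temperature=0.1, memory_bank_size=4096)(out0, out1)

    print(float(simclr_loss), float(moco_loss))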