[a18f15]: / algorithms / loss / ssl_losses.py

Download this file

323 lines (254 with data), 11.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
""" Negative Cosine Similarity Loss Function """
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
import torch
from torch.nn.functional import cosine_similarity
class NegativeCosineSimilarity(torch.nn.Module):
    """Negative cosine similarity loss, as used in the SimSiam [0] paper.

    The loss is the negated mean of the cosine similarities between
    corresponding entries of two tensors, computed along ``dim``.

    - [0] SimSiam, 2020, https://arxiv.org/abs/2011.10566

    Examples:
        >>> # initialize loss function
        >>> loss_fn = NegativeCosineSimilarity()
        >>>
        >>> # generate two representation tensors
        >>> # with batch size 10 and dimension 128
        >>> x0 = torch.randn(10, 128)
        >>> x1 = torch.randn(10, 128)
        >>>
        >>> # calculate loss
        >>> loss = loss_fn(x0, x1)
    """

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        """Store the arguments forwarded to ``cosine_similarity``.

        Args:
            dim (int, optional):
                Dimension along which cosine similarity is computed. Default: 1
            eps (float, optional):
                Small value to avoid division by zero. Default: 1e-8
        """
        super().__init__()
        self.dim, self.eps = dim, eps

    def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
        """Return the negated mean cosine similarity between ``x0`` and ``x1``."""
        similarity = cosine_similarity(x0, x1, dim=self.dim, eps=self.eps)
        return similarity.mean().neg()
""" Memory Bank Wrapper """
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
import functools
class MemoryBankModule(torch.nn.Module):
    """Memory bank base class.

    Loss functions can subclass this to get an optional bank of stored key
    vectors that ``forward`` hands back alongside the model output, e.g. for
    use as extra negative samples.

    Attributes:
        size:
            Number of keys the memory bank can store. If set to 0, the
            memory bank is not used.

    Examples:
        >>> class MyLossFunction(MemoryBankModule):
        >>>
        >>>     def __init__(self, memory_bank_size: int = 2 ** 16):
        >>>         super(MyLossFunction, self).__init__(memory_bank_size)
        >>>
        >>>     def forward(self, output: torch.Tensor,
        >>>                 labels: torch.Tensor = None):
        >>>
        >>>         output, negatives = super(
        >>>             MyLossFunction, self).forward(output)
        >>>
        >>>         if negatives is not None:
        >>>             # evaluate loss with negative samples
        >>>         else:
        >>>             # evaluate loss without negative samples
    """

    def __init__(self, size: int = 2**16):
        super().__init__()
        if size < 0:
            raise ValueError(
                f"Illegal memory bank size {size}, must be non-negative."
            )
        self.size = size
        # Non-persistent buffers: they move with the module on ``.to(device)``
        # but are kept out of checkpoints (``state_dict``).
        self.register_buffer(
            "bank", torch.empty(0, dtype=torch.float), persistent=False
        )
        self.register_buffer(
            "bank_ptr", torch.empty(0, dtype=torch.long), persistent=False
        )

    @torch.no_grad()
    def _init_memory_bank(self, dim: int):
        """Fill the bank with random unit-norm columns and reset the pointer.

        Args:
            dim:
                Dimensionality of the vectors stored in the bank.
        """
        # The bank is laid out as (dim, size): one key per column.
        random_keys = torch.randn(dim, self.size).type_as(self.bank)
        self.bank = torch.nn.functional.normalize(random_keys, dim=0)
        self.bank_ptr = torch.zeros(1).type_as(self.bank_ptr)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, batch: torch.Tensor):
        """Overwrite the oldest keys in the bank with ``batch``.

        Keys are written column-wise starting at ``bank_ptr``. If the batch
        reaches or runs past the end of the bank, only the part that fits is
        written (the rest of the batch is dropped) and the pointer wraps to 0.

        Args:
            batch:
                The latest batch of keys, one key per row.
        """
        start = int(self.bank_ptr)
        stop = start + batch.shape[0]
        if stop < self.size:
            self.bank[:, start:stop] = batch.T.detach()
            self.bank_ptr[0] = stop
            return
        self.bank[:, start:] = batch[: self.size - start].T.detach()
        self.bank_ptr[0] = 0

    def forward(
        self, output: torch.Tensor, labels: torch.Tensor = None, update: bool = False
    ):
        """Return ``output`` together with a snapshot of the memory bank.

        Args:
            output:
                The output of the model, a 2-D tensor (batch_size, dim).
            labels:
                Should always be None, will be ignored.
            update:
                If True, ``output`` is enqueued into the bank after the
                snapshot has been taken.

        Returns:
            A tuple ``(output, bank)``. ``bank`` is None when the memory bank
            size is 0; otherwise it is a detached copy of the bank, with shape
            (dim, size), taken before any update.
        """
        if self.size == 0:
            return output, None

        _, dim = output.shape
        # Lazily allocate the bank on first use, once the key dimension is known.
        if not self.bank.nelement():
            self._init_memory_bank(dim)

        snapshot = self.bank.clone().detach()
        # Enqueue only when requested (callers typically do so during training).
        if update:
            self._dequeue_and_enqueue(output)
        return output, snapshot
""" Contrastive Loss Functions """
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
# from torch import distributed as torch_dist
# from torch import nn
class NTXentLoss(MemoryBankModule):
    """Implementation of the Contrastive Cross Entropy Loss.

    This implementation follows the SimCLR[0] paper. If you enable the memory
    bank by setting the `memory_bank_size` value > 0, the loss behaves like
    the one described in the MoCo[1] paper.

    - [0] SimCLR, 2020, https://arxiv.org/abs/2002.05709
    - [1] MoCo, 2020, https://arxiv.org/abs/1911.05722

    Attributes:
        temperature:
            Scale logits by the inverse of the temperature.
        memory_bank_size:
            Number of negative samples to store in the memory bank.
            Use 0 for SimCLR. For MoCo we typically use numbers like 4096 or
            65536. Note that the default is 4096, so the memory bank is
            enabled unless 0 is passed explicitly.

    Raises:
        ValueError: If abs(temperature) < 1e-8 to prevent divide by zero.

    Examples:
        >>> # initialize loss function without memory bank
        >>> loss_fn = NTXentLoss(memory_bank_size=0)
        >>>
        >>> # generate two random transforms of images
        >>> t0 = transforms(images)
        >>> t1 = transforms(images)
        >>>
        >>> # feed both views through a SimCLR or MoCo model
        >>> out0 = model(t0)
        >>> out1 = model(t1)
        >>>
        >>> # calculate loss
        >>> loss = loss_fn(out0, out1)
    """

    def __init__(
        self,
        temperature: float = 0.5,
        memory_bank_size: int = 4096,
    ):
        super(NTXentLoss, self).__init__(size=memory_bank_size)
        self.temperature = temperature
        self.cross_entropy = torch.nn.CrossEntropyLoss(reduction="mean")
        self.eps = 1e-8

        if abs(self.temperature) < self.eps:
            raise ValueError(
                "Illegal temperature: abs({}) < 1e-8".format(self.temperature)
            )

    def forward(self, out0: torch.Tensor, out1: torch.Tensor):
        """Forward pass through Contrastive Cross-Entropy Loss.

        If used with a memory bank, the samples from the memory bank are used
        as negative examples. Otherwise, within-batch samples are used as
        negative samples.

        Args:
            out0:
                Output projections of the first set of transformed images.
                Shape: (batch_size, embedding_size)
            out1:
                Output projections of the second set of transformed images.
                Shape: (batch_size, embedding_size)

        Returns:
            Contrastive Cross Entropy Loss value.
        """
        # Unpacking also rejects inputs that are not 2-D.
        batch_size, _ = out0.shape

        # Normalize to unit length so that dot products are cosine similarities.
        out0 = torch.nn.functional.normalize(out0, dim=1)
        out1 = torch.nn.functional.normalize(out1, dim=1)

        # Ask the memory bank for negatives. The bank is extended with out1
        # only if out0 requires a gradient (i.e. during training). Otherwise
        # it stays unchanged, e.g. when evaluating the loss on a test set.
        # out1: (batch_size, embedding_size)
        # negatives: (embedding_size, memory_bank_size), or None without a bank
        out1, negatives = super(NTXentLoss, self).forward(
            out1, update=out0.requires_grad
        )

        if negatives is not None:
            logits, labels = self._memory_bank_logits(out0, out1, negatives)
        else:
            logits, labels = self._in_batch_logits(out0, out1, batch_size)
        return self.cross_entropy(logits, labels)

    def _memory_bank_logits(
        self, out0: torch.Tensor, out1: torch.Tensor, negatives: torch.Tensor
    ):
        """Build (logits, labels) using memory-bank entries as negatives."""
        device = out0.device
        negatives = negatives.to(device)
        # Einsum notation: n = batch_size, c = embedding_size, k = memory_bank_size.
        # sim_pos[i] is the similarity of out0[i] to its positive out1[i]: (n, 1).
        sim_pos = torch.einsum("nc,nc->n", out0, out1).unsqueeze(-1)
        # sim_neg[i, j] is the similarity of out0[i] to the j-th bank entry: (n, k).
        sim_neg = torch.einsum("nc,ck->nk", out0, negatives)
        # The positive sits in column 0, so every target "class" is 0.
        logits = torch.cat([sim_pos, sim_neg], dim=1) / self.temperature
        labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long)
        return logits, labels

    def _in_batch_logits(
        self, out0: torch.Tensor, out1: torch.Tensor, batch_size: int
    ):
        """Build (logits, labels) using the other samples in the batch as negatives."""
        device = out0.device
        # Diagonal mask selecting each view's similarity with itself.
        diag_mask = torch.eye(batch_size, device=device, dtype=torch.bool)

        # Pairwise similarities, each of shape (batch_size, batch_size).
        logits_00 = torch.einsum("nc,mc->nm", out0, out0) / self.temperature
        logits_01 = torch.einsum("nc,mc->nm", out0, out1) / self.temperature
        logits_10 = torch.einsum("nc,mc->nm", out1, out0) / self.temperature
        logits_11 = torch.einsum("nc,mc->nm", out1, out1) / self.temperature

        # Remove self-similarities. An explicit width (instead of -1) keeps
        # this working for batch_size == 1, where the masked tensor is empty
        # and view(..., -1) would be ambiguous.
        logits_00 = logits_00[~diag_mask].view(batch_size, batch_size - 1)
        logits_11 = logits_11[~diag_mask].view(batch_size, batch_size - 1)

        # Final logits have shape (2 * batch_size, 2 * batch_size - 1). The
        # positive of row i (and of row batch_size + i) is in column i.
        logits = torch.cat(
            [
                torch.cat([logits_01, logits_00], dim=1),
                torch.cat([logits_10, logits_11], dim=1),
            ],
            dim=0,
        )
        labels = torch.arange(batch_size, device=device, dtype=torch.long).repeat(2)
        return logits, labels