"""
File: DnR/dnr.py

We build our architecture on top of the anchor neighbourhoods (ANs) proposed in

@InProceedings{huang2018and,
  title={Unsupervised Deep Learning by Neighbourhood Discovery},
  author={Jiabo Huang, Qi Dong, Shaogang Gong and Xiatian Zhu},
  booktitle={Proceedings of the International Conference on Machine Learning (ICML)},
  year={2019},
}

The code is available online at https://github.com/Raymond-sci/AND

"""
14
15
16
from torch.autograd import Function
17
18
import os
19
import torchvision.models as models
20
import math
21
import torch
22
import torch.nn.functional as F
23
import torch.nn as nn
24
25
26
def resnet18(pretrained=False):
    """Build a torchvision ResNet-18 with pooling and classifier disabled.

    Both ``avgpool`` and ``fc`` are swapped for ``Identity`` modules, so
    neither the global pooling nor the classification layer alters the
    convolutional features.

    Arguments:
        pretrained {bool} -- forwarded to ``models.resnet18`` (default: {False})

    Returns:
        nn.Module -- the modified ResNet-18
    """
    model = models.resnet18(pretrained=pretrained)
    model.fc = Identity()
    model.avgpool = Identity()
    return model
31
32
def resnet34(pretrained=False):
    """Build a torchvision ResNet-34 with pooling and classifier disabled.

    Both ``avgpool`` and ``fc`` are swapped for ``Identity`` modules, so
    neither the global pooling nor the classification layer alters the
    convolutional features.

    Arguments:
        pretrained {bool} -- forwarded to ``models.resnet34`` (default: {False})

    Returns:
        nn.Module -- the modified ResNet-34
    """
    model = models.resnet34(pretrained=pretrained)
    model.fc = Identity()
    model.avgpool = Identity()
    return model
37
38
39
def resnet50(pretrained=False):
    """Build a torchvision ResNet-50 with pooling and classifier disabled.

    Both ``avgpool`` and ``fc`` are swapped for ``Identity`` modules, so
    neither the global pooling nor the classification layer alters the
    convolutional features.

    Arguments:
        pretrained {bool} -- forwarded to ``models.resnet50`` (default: {False})

    Returns:
        nn.Module -- the modified ResNet-50
    """
    model = models.resnet50(pretrained=pretrained)
    model.fc = Identity()
    model.avgpool = Identity()
    return model
44
45
class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its input untouched.

    Used to neutralise layers of an existing network (e.g. the pooling and
    classification layers of a ResNet) without changing its structure.
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        """Return ``x`` unchanged (the very same object)."""
        return x
51
52
53
class Backbone(nn.Module):
    """Thin wrapper selecting one of the ResNet feature extractors by name.

    Arguments:
        name {str} -- one of 'resnet18', 'resnet34', 'resnet50'
                      (default: {'resnet18'})
        pretrained {bool} -- forwarded to the ResNet builder (default: {True})
        freeze_all {bool} -- when True, every parameter of the backbone has
                             ``requires_grad`` switched off (default: {False})

    Raises:
        ValueError -- if ``name`` is not a supported architecture
    """

    def __init__(self, name='resnet18', pretrained=True, freeze_all=False):
        super(Backbone, self).__init__()
        self.name = name
        self.freeze_all = freeze_all
        self.pretrained = pretrained

        # A plain if/elif chain (rather than a lookup table) keeps the builder
        # names lazily resolved, and the final branch makes an unsupported
        # name fail here instead of leaving ``self.backbone`` undefined.
        if name == 'resnet18':
            self.backbone = resnet18(pretrained=self.pretrained)
        elif name == 'resnet34':
            self.backbone = resnet34(pretrained=self.pretrained)
        elif name == 'resnet50':
            self.backbone = resnet50(pretrained=self.pretrained)
        else:
            raise ValueError(
                "Unknown backbone '{}': expected 'resnet18', 'resnet34' "
                "or 'resnet50'".format(name))

        if self.freeze_all:
            # parameters() already walks nested containers, so every weight
            # of every layer is reached.
            for param in self.backbone.parameters():
                param.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the selected backbone and return its output."""
        return self.backbone(x)
76
77
78
class SimpleDecoder(nn.Module):
    """Convolutional decoder: five (conv, ReLU, x2 bicubic upsampling) stages
    followed by a final convolution and a sigmoid.

    Each stage halves the channel count (``hidden_dimension`` down to
    ``hidden_dimension // 32``) and doubles the spatial size, so the output
    is 32 times larger than the input along each spatial axis and has a
    single channel.

    Arguments:
        hidden_dimension {int} -- channels of the input feature map
                                  (default: {512})
    """

    def __init__(self, hidden_dimension=512):
        super(SimpleDecoder, self).__init__()

        self.conv_up_5 = nn.Conv2d(hidden_dimension, hidden_dimension // 2, 3, padding=1)
        self.conv_up_4 = nn.Conv2d(hidden_dimension // 2, hidden_dimension // 4, 3, padding=1)
        self.conv_up_3 = nn.Conv2d(hidden_dimension // 4, hidden_dimension // 8, 3, padding=1)
        self.conv_up_2 = nn.Conv2d(hidden_dimension // 8, hidden_dimension // 16, 3, padding=1)
        self.conv_up_1 = nn.Conv2d(hidden_dimension // 16, hidden_dimension // 32, 5, padding=2)
        self.decoder = nn.Conv2d(hidden_dimension // 32, 1, 5, padding=2)
        # Parameter-free, so creating it once here (instead of on every
        # forward call) leaves the state_dict layout untouched.
        self.upsample = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=False)

    def forward(self, z):
        """Decode the feature map ``z`` into a single-channel map in [0, 1].

        Arguments:
            z {Tensor} -- feature map with ``hidden_dimension`` channels
                          (4-D, as required by ``nn.Conv2d``)

        Returns:
            Tensor -- sigmoid-activated reconstruction
        """
        h = z
        stages = (self.conv_up_5, self.conv_up_4, self.conv_up_3,
                  self.conv_up_2, self.conv_up_1)
        for conv in stages:
            h = self.upsample(F.relu(conv(h)))
        return torch.sigmoid(self.decoder(h))
105
106
107
class CAE_DNR(nn.Module):
    """Convolutional auto-encoder: ResNet encoder, linear projection head and
    a ``SimpleDecoder``.

    Arguments:
        pretrained {bool} -- forwarded to the backbone (default: {True})
        n_channels {int} -- channels of the input images; when it differs
                            from the backbone's first convolution, that layer
                            is replaced by one with ``n_channels`` inputs
                            (default: {3})
        hidden_dimension {int} -- size of the pooled encoder feature fed to
                                  the projection head and channel count given
                                  to the decoder; presumably has to match the
                                  backbone (512 for the default) -- TODO
                                  confirm for resnet50 (default: {512})
        name {str} -- backbone name (default: {'resnet18'})
        npc_dimension {int} -- output size of the projection head
                               (default: {256})
    """

    def __init__(self, pretrained=True, n_channels=3, hidden_dimension=512, name='resnet18', npc_dimension=256):
        super(CAE_DNR, self).__init__()

        self.n_channels = n_channels
        self.name = name
        self.encoder = Backbone(name=self.name, pretrained=pretrained, freeze_all=False)

        first_conv = self.encoder.backbone.conv1
        if self.n_channels != first_conv.in_channels:
            self.encoder.backbone.conv1 = self._adapt_first_conv(first_conv, self.n_channels)

        # The projection head is needed by forward() whatever the number of
        # input channels, so it must not live inside the branch above.
        self.fc = nn.Linear(hidden_dimension, npc_dimension)
        self.relu = nn.ReLU(inplace=True)
        self.hidden_dimension = hidden_dimension
        self.decoder = SimpleDecoder(hidden_dimension=hidden_dimension)

    @staticmethod
    def _adapt_first_conv(conv, n_channels):
        """Return a copy of ``conv`` that accepts ``n_channels`` inputs.

        The filters of the channels both layers have in common are copied
        from ``conv`` (better than nothing); extra channels, if any, keep
        their default initialisation.
        """
        adapted = nn.Conv2d(n_channels, conv.out_channels,
                            kernel_size=conv.kernel_size, stride=conv.stride,
                            padding=conv.padding, bias=conv.bias is not None)
        shared = min(n_channels, conv.in_channels)
        with torch.no_grad():
            adapted.weight[:, :shared].copy_(conv.weight[:, :shared])
            if conv.bias is not None:
                adapted.bias.copy_(conv.bias)
        return adapted

    @staticmethod
    def _pool(h):
        """Average-pool a (B, C, 2, 2) map with a 2x2 window into (B, C)."""
        return F.avg_pool2d(h, 2, 2).squeeze(dim=3).squeeze(dim=2)

    def restore_model(self, paths):
        """Load several sub-module checkpoints.

        Arguments:
            paths {dict} -- maps an attribute name of this model to the path
                            of the state_dict file to load into it

        Returns:
            CAE_DNR -- self
        """
        for attr, path in paths.items():
            self._load(attr=attr, path=path)
        return self

    def _load(self, attr, path):
        """Load the state_dict stored at ``path`` into ``self.<attr>``.

        Raises:
            FileNotFoundError -- if ``path`` does not exist
            AttributeError -- if the model has no attribute ``attr``
        """
        if not os.path.exists(path):
            raise FileNotFoundError('Unknown path: {}'.format(path))
        if not hasattr(self, attr):
            raise AttributeError('No attribute: {}'.format(attr))

        # Deserialise on CPU so checkpoints saved from a GPU also load on a
        # CPU-only machine; load_state_dict copies the values onto whatever
        # device the existing parameters live on.
        state = torch.load(path, map_location='cpu')
        getattr(self, attr).load_state_dict(state, strict=True)

        return self

    def forward(self, x, decode=False):
        """Encode ``x`` and optionally reconstruct it.

        Arguments:
            x {Tensor} -- batch of input images
            decode {bool} -- run the decoder as well (default: {False})

        Returns:
            tuple -- ``(x_hat, zp, zb)``: the reconstruction (``None`` when
                     ``decode`` is False), the ReLU-activated projection of
                     the pooled features, and the pooled features themselves
        """
        z = self.encode(x, pool=False)
        zb = self._pool(z)
        zp = self.relu(self.fc(zb))
        x_hat = self.decoder(z) if decode else None

        return x_hat, zp, zb

    def encode(self, x, pool=False):
        """Return the encoder features of ``x``.

        The encoder output is reshaped to (batch, -1, 2, 2); with
        ``pool=True`` it is additionally averaged over the 2x2 grid and
        returned as (batch, channels).
        """
        h = self.encoder(x)
        h = h.view((x.shape[0], -1, 2, 2))
        if pool:
            return self._pool(h)
        return h

    def latent_variable(self, x_in, projectionHead):
        """Return the latent code of ``x_in``.

        Arguments:
            x_in {Tensor} -- batch of input images
            projectionHead {bool} -- True for the projected code ``zp``,
                                     False for the pooled features ``zb``
        """
        # The reconstruction is not used here, so the decoder is skipped.
        _, zp, zb = self.forward(x_in, decode=False)
        return zp if projectionHead else zb
170
171
172
class NonParametricClassifierOP(Function):
    """Autograd function behind the non-parametric classifier.

    Forward computes temperature-scaled inner products between the input
    features and every row of the memory bank. Backward returns the gradient
    w.r.t. the features and, as a side effect, refreshes the memory rows of
    the samples in the batch with a momentum update.
    """

    @staticmethod
    def forward(ctx, x, y, memory, params):
        """Return ``x @ memory.T / T``.

        Arguments:
            x {Tensor} -- features, one row per sample
            y {Tensor} -- memory row index of each sample (only used in
                          backward; may be None when no backward is run)
            memory {Tensor} -- memory bank, one row per stored sample
            params {Tensor} -- ``[T, momentum]``

        Returns:
            Tensor -- similarities, batchSize * N
        """
        temperature = params[0].item()

        # inner product
        out = torch.mm(x.detach(), memory.t())
        out.div_(temperature)  # batchSize * N

        ctx.save_for_backward(x, memory, y, params)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient for ``x`` and update ``memory`` in place."""
        x, memory, y, params = ctx.saved_tensors
        temperature = params[0].item()
        momentum = params[1].item()

        # Gradient of the scaled linear map. The incoming gradient is left
        # untouched: dividing it in place would corrupt it for anyone else
        # holding a reference to the same tensor.
        grad_input = torch.mm(grad_output.detach(), memory).div_(temperature)
        grad_input = grad_input.view_as(x)

        # update the memory: momentum blend of old entry and new feature,
        # re-normalised to unit L2 norm
        targets = y.detach().view(-1)
        blended = memory.index_select(0, targets).view_as(x)
        blended.mul_(momentum)
        blended.add_(torch.mul(x.detach(), 1 - momentum))
        norms = blended.pow(2).sum(1, keepdim=True).pow(0.5)
        memory.index_copy_(0, targets, blended.div(norms))

        # one gradient slot per forward input: x, y, memory, params
        return grad_input, None, None, None
208
209
210
class NonParametricClassifier(nn.Module):
    """Non-parametric Classifier

    Non-parametric Classifier from
    "Unsupervised Feature Learning via Non-Parametric Instance Discrimination"

    Extends:
        nn.Module
    """

    def __init__(self, inputSize, outputSize, T=0.05, momentum=0.5):
        """Non-parametric Classifier initial function

        Registers two buffers: ``params`` holding ``[T, momentum]`` and
        ``memory``, an ``outputSize x inputSize`` bank initialised uniformly
        in ``[-stdv, stdv]`` with ``stdv = 1 / sqrt(inputSize / 3)``.

        Arguments:
            inputSize {int} -- in-channels dims
            outputSize {int} -- out-channels dims

        Keyword Arguments:
            T {float} -- distribution temperature (default: {0.05})
            momentum {float} -- memory update momentum (default: {0.5})
        """
        super(NonParametricClassifier, self).__init__()
        self.nLem = outputSize
        self.register_buffer('params', torch.tensor([T, momentum]))
        stdv = 1. / math.sqrt(inputSize / 3)
        self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))

    def forward(self, x, y):
        """Return the temperature-scaled similarities of ``x`` to the memory.

        Arguments:
            x {Tensor} -- features, one row per sample
            y {Tensor} -- memory row index of each sample
        """
        out = NonParametricClassifierOP.apply(x, y, self.memory, self.params)
        return out
242
243
class ANsDiscovery(nn.Module):
    """Discovery ANs
    Discovery ANs according to current round, select_rate and most importantly,
    all sample's corresponding entropy
    """

    def __init__(self, nsamples):
        """Object used to discovery ANs
        Discovery ANs according to the total amount of samples; the ANs
        selection rate (0.25) and the ANs size (1) are fixed here.
        Arguments:
            nsamples {int} -- total number of samples
        """
        super(ANsDiscovery, self).__init__()

        # not going to use ``register_buffer'' as
        # they are determined by configs
        self.select_rate = 0.25
        self.ANs_size = 1
        # number of samples
        self.register_buffer('samples_num', torch.tensor(nsamples))
        # indexes list of anchor samples (zero-filled until the first update
        # instead of holding uninitialised memory)
        self.register_buffer('anchor_indexes', torch.zeros(nsamples // 2, dtype=torch.long))
        # indexes list of instance samples
        self.register_buffer('instance_indexes', torch.arange(nsamples // 2).long())
        # anchor samples' and instance samples' position; all negative at
        # start, i.e. every sample is an instance sample
        self.register_buffer('position', -1 * torch.arange(nsamples).long() - 1)
        # anchor samples' neighbours
        self.register_buffer('neighbours', torch.zeros(nsamples // 2, 1, dtype=torch.long))
        # each sample's entropy
        self.register_buffer('entropy', torch.zeros(nsamples, dtype=torch.float))
        # consistency
        self.register_buffer('consistency', torch.tensor(0.))

    def get_ANs_num(self, round):
        """Get number of ANs
        Get number of ANs at target round according to the select rate,
        never more than the number of samples
        Arguments:
            round {int} -- target round
        Returns:
            int -- number of ANs
        """
        wanted = int(self.samples_num.float() * self.select_rate * round)
        # more anchors than samples would make the top-k selection fail
        return min(wanted, int(self.samples_num))

    def update(self, round, npc, cheat_labels=None):
        """Update ANs
        Discovery new ANs and update `anchor_indexes`, `instance_indexes`,
        `position` and `neighbours`
        Arguments:
            round {int} -- target round
            npc {Module} -- non-parametric classifier; its `memory` holds the
                            features and calling it returns the logits
            cheat_labels {Tensor} -- used to compute consistency of chosen
                                     ANs only
        Returns:
            number -- [updated consistency], or 0. without `cheat_labels`
        """
        with torch.no_grad():
            batch_size = 100
            n_samples = int(self.samples_num)
            ANs_num = self.get_ANs_num(round)
            features = npc.memory

            for start in range(0, n_samples, batch_size):
                end = min(start + batch_size, n_samples)

                logits = npc(features[start:end], None)
                # log_softmax keeps the entropy finite even when a
                # probability underflows to zero (0 * log(0) would be NaN)
                self.entropy[start:end] = -(
                    F.softmax(logits, 1) * F.log_softmax(logits, 1)).sum(1)

            # get the anchor list and instance list according to the computed
            # entropy: the lowest-entropy samples become anchors
            self.anchor_indexes = self.entropy.topk(ANs_num, largest=False)[1]
            self.instance_indexes = (torch.ones_like(self.position)
                                     .scatter_(0, self.anchor_indexes, 0)
                                     .nonzero().view(-1))

            # get position
            # if the anchor sample x whose index is i while position is j, then
            # sample x_i is the j-th anchor sample at current round
            # if the instance sample x whose index is i while position is j, then
            # sample x_i is the (-j-1)-th instance sample at current round
            # (instance_indexes is ascending, so this numbers the instance
            # samples in index order)
            where = self.position.device
            self.position[self.anchor_indexes] = torch.arange(ANs_num, device=where)
            self.position[self.instance_indexes] = -torch.arange(
                1, self.instance_indexes.numel() + 1, device=where)

            anchor_features = features.index_select(0, self.anchor_indexes)
            # allocate next to the features rather than on a hard-coded GPU
            self.neighbours = torch.empty(ANs_num, self.ANs_size,
                                          dtype=torch.long,
                                          device=features.device)
            for start in range(0, ANs_num, batch_size):
                end = min(start + batch_size, ANs_num)

                sims = torch.mm(anchor_features[start:end], features.t())
                # an anchor must not be picked as its own neighbour
                sims.scatter_(1, self.anchor_indexes[start:end].view(-1, 1), -1.)
                self.neighbours[start:end] = sims.topk(
                    self.ANs_size, largest=True, dim=1)[1]

            # if cheat labels is provided, then compute consistency
            if cheat_labels is None:
                return 0.
            anchor_label = cheat_labels.index_select(0, self.anchor_indexes)
            neighbour_label = cheat_labels.index_select(
                0, self.neighbours.view(-1)).view_as(self.neighbours)
            self.consistency = ((anchor_label.view(-1, 1) == neighbour_label)
                                .float().mean())

            return self.consistency
364
365
366
class Criterion(nn.Module):
    """Training loss: instance and anchor-neighbourhood terms computed from
    the non-parametric classifier, plus an optional MSE reconstruction term.
    """

    def __init__(self):
        super(Criterion, self).__init__()

    def calculate_loss(self, x, y, ANs):
        """Compute the instance loss and the anchor-neighbourhood loss.

        Row ``i`` of the first half of the batch is paired with row
        ``i + batch_size // 2`` (presumably a second view of the same
        sample -- TODO confirm against the data loader).

        Arguments:
            x {Tensor} -- classifier logits, batch_size * N
            y {Tensor} -- sample index of every row of ``x``
            ANs {ANsDiscovery} -- provides ``position`` and ``neighbours``

        Returns:
            tuple -- ``(l_inst, l_ans)``, both divided by ``batch_size``;
                     ``l_inst`` has shape (1,) when instance samples are
                     present, every other case yields a 0-dim tensor
        """
        batch_size, _ = x.shape
        half = batch_size // 2

        # split anchor and instance list
        anchor_indexes, instance_indexes = self._split(y[:half], ANs)
        preds = F.softmax(x, 1)

        # zero on the logits' device (no hard-coded GPU)
        l_ans = x.new_zeros(())
        if anchor_indexes.size(0) > 0:
            # compute loss for anchor samples
            y_ans = y.index_select(0, anchor_indexes)
            y_ans_p = y.index_select(0, anchor_indexes + half)
            # position of an anchor is its row in ``ANs.neighbours``
            y_ans_neighbour = ANs.position.index_select(0, y_ans)
            neighbours = ANs.neighbours.index_select(0, y_ans_neighbour)
            # p_i = \sum_{j \in \Omega_i} p_{i,j}
            x_ans = preds.index_select(0, anchor_indexes)
            x_ans_p = preds.index_select(0, anchor_indexes + half)

            x_ans_neighbour = x_ans.gather(1, neighbours).sum(1)
            x_ans_p = x_ans_p.gather(1, y_ans_p.view(-1, 1)).view(-1)
            x_ans = x_ans.gather(1, y_ans.view(-1, 1)).view(-1)
            # sum all terms : self + sim + neighbors
            # NLL: l = -log(p_i)
            l_ans = -1 * torch.log(x_ans + x_ans_p + x_ans_neighbour).sum(0)

        l_inst = x.new_zeros(())
        if instance_indexes.size(0) > 0:
            # compute loss for instance samples
            y_inst = y.index_select(0, instance_indexes)
            y_inst_p = y.index_select(0, instance_indexes + half)
            x_inst = preds.index_select(0, instance_indexes)
            x_inst_p = preds.index_select(0, instance_indexes + half)
            # p_i = p_{i, i}
            x_inst = x_inst.gather(1, y_inst.view(-1, 1))
            x_inst_p = x_inst_p.gather(1, y_inst_p.view(-1, 1))
            # NLL: l = -log(p_i)
            l_inst = -1 * torch.log(x_inst + x_inst_p).sum(0)

        return l_inst / batch_size, l_ans / batch_size

    def _split(self, y, ANs):
        """Split batch rows into anchors (position >= 0) and instances
        (position < 0); returns the two lists of row indexes."""
        pos = ANs.position.index_select(0, y.view(-1))
        return (pos >= 0).nonzero().view(-1), (pos < 0).nonzero().view(-1)

    def forward(self, x_out, index, npc, ANs_discovery, x_hat, zp):
        """Compute all loss terms.

        Arguments:
            x_out {Tensor} -- reconstruction target
            index {Tensor} -- sample index of every row of the batch
            npc {Module} -- non-parametric classifier
            ANs_discovery {ANsDiscovery} -- current anchor neighbourhoods
            x_hat {Tensor} -- reconstruction, or None to skip the MSE term
            zp {Tensor} -- projected latent codes, one row per sample

        Returns:
            dict -- 'loss', 'loss_inst', 'loss_ans' and, when ``x_hat`` is
                    given, 'loss_mse' (already included in 'loss')
        """
        # L2-normalise; the small constant keeps an all-zero code from
        # producing a division by zero
        z_n = torch.div(zp, torch.norm(zp + 1e-12, p=2, dim=1, keepdim=True))
        outputs = npc(z_n, index)  # For each image get similarity with neighbour
        loss_inst, loss_ans = self.calculate_loss(outputs, index, ANs_discovery)
        loss = loss_inst + loss_ans
        l_loss = {'loss': loss, 'loss_inst': loss_inst, 'loss_ans': loss_ans}

        if x_hat is not None:
            loss_mse = nn.MSELoss()(x_hat, x_out)
            loss = loss + loss_mse
            l_loss['loss_mse'] = loss_mse
            l_loss['loss'] = loss

        return l_loss