b/DnR/dnr.py

"""
We build our architecture on top of the ANs (anchor neighbourhoods) proposed in

@InProceedings{huang2019and,
    title={Unsupervised Deep Learning by Neighbourhood Discovery},
    author={Jiabo Huang and Qi Dong and Shaogang Gong and Xiatian Zhu},
    booktitle={Proceedings of the International Conference on Machine Learning (ICML)},
    year={2019},
}

The code is available online at https://github.com/Raymond-sci/AND
"""

import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Function


def resnet18(pretrained=False):
    """ResNet-18 backbone with its classification head disabled."""
    model = models.resnet18(pretrained=pretrained)
    # Replace the global pooling and classification layers with identity
    # mappings so the backbone returns its (flattened) convolutional features.
    model.fc = Identity()
    model.avgpool = Identity()
    return model


def resnet34(pretrained=False):
    """ResNet-34 backbone with its classification head disabled."""
    model = models.resnet34(pretrained=pretrained)
    model.fc = Identity()
    model.avgpool = Identity()
    return model


def resnet50(pretrained=False):
    """ResNet-50 backbone with its classification head disabled."""
    model = models.resnet50(pretrained=pretrained)
    model.fc = Identity()
    model.avgpool = Identity()
    return model


class Identity(nn.Module):
    """Pass-through module used to disable layers of a pretrained network."""

    def forward(self, x):
        return x


class Backbone(nn.Module):
    """Wraps a torchvision ResNet whose classification head has been stripped."""

    def __init__(self, name='resnet18', pretrained=True, freeze_all=False):
        super(Backbone, self).__init__()
        self.name = name
        self.freeze_all = freeze_all
        self.pretrained = pretrained
        if name == 'resnet18':
            self.backbone = resnet18(pretrained=self.pretrained)
        elif name == 'resnet34':
            self.backbone = resnet34(pretrained=self.pretrained)
        elif name == 'resnet50':
            self.backbone = resnet50(pretrained=self.pretrained)
        else:
            raise ValueError('Unknown backbone: {}'.format(name))

        if self.freeze_all:
            # Freeze every backbone parameter so only modules added on top
            # of it are trained.
            for param in self.backbone.parameters():
                param.requires_grad_(False)

    def forward(self, x):
        return self.backbone(x)
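
# Usage sketch (illustrative, not part of the original file): with the head
# replaced by Identity, torchvision's ResNet still flattens its output, so a
# 64x64 input yields the flattened 2x2 feature map that CAE_DNR.encode expects.
#
#     backbone = Backbone(name='resnet18', pretrained=False)
#     feats = backbone(torch.randn(4, 3, 64, 64))
#     assert feats.shape == (4, 512 * 2 * 2)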


class SimpleDecoder(nn.Module):
    """Upsampling decoder mapping a (B, C, 2, 2) code to a (B, 1, 64, 64) image."""

    def __init__(self, hidden_dimension=512):
        super(SimpleDecoder, self).__init__()

        self.conv_up_5 = nn.Conv2d(hidden_dimension, hidden_dimension // 2, 3, padding=1)
        self.conv_up_4 = nn.Conv2d(hidden_dimension // 2, hidden_dimension // 4, 3, padding=1)
        self.conv_up_3 = nn.Conv2d(hidden_dimension // 4, hidden_dimension // 8, 3, padding=1)
        self.conv_up_2 = nn.Conv2d(hidden_dimension // 8, hidden_dimension // 16, 3, padding=1)
        self.conv_up_1 = nn.Conv2d(hidden_dimension // 16, hidden_dimension // 32, 5, padding=2)
        self.decoder = nn.Conv2d(hidden_dimension // 32, 1, 5, padding=2)

    def forward(self, z):
        # Five conv + x2 bicubic upsampling stages: 2 -> 4 -> 8 -> 16 -> 32 -> 64.
        h = F.relu(self.conv_up_5(z))
        h = F.interpolate(h, scale_factor=2, mode='bicubic', align_corners=False)
        h = F.relu(self.conv_up_4(h))
        h = F.interpolate(h, scale_factor=2, mode='bicubic', align_corners=False)
        h = F.relu(self.conv_up_3(h))
        h = F.interpolate(h, scale_factor=2, mode='bicubic', align_corners=False)
        h = F.relu(self.conv_up_2(h))
        h = F.interpolate(h, scale_factor=2, mode='bicubic', align_corners=False)
        h = F.relu(self.conv_up_1(h))
        h = F.interpolate(h, scale_factor=2, mode='bicubic', align_corners=False)
        x_hat = torch.sigmoid(self.decoder(h))

        return x_hat
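
# Shape walkthrough (added for illustration, default hidden_dimension=512):
# channels halve at each stage (512 -> 256 -> ... -> 16 -> 1) while the
# spatial side doubles (2 -> 4 -> 8 -> 16 -> 32 -> 64).
#
#     decoder = SimpleDecoder(hidden_dimension=512)
#     x_hat = decoder(torch.randn(4, 512, 2, 2))
#     assert x_hat.shape == (4, 1, 64, 64)  # values in (0, 1) from the sigmoid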


class CAE_DNR(nn.Module):
    """Convolutional autoencoder whose encoder also feeds the AND branch."""

    def __init__(self, pretrained=True, n_channels=3, hidden_dimension=512, name='resnet18', npc_dimension=256):
        super(CAE_DNR, self).__init__()

        self.n_channels = n_channels
        self.name = name
        self.encoder = Backbone(name=self.name, pretrained=pretrained, freeze_all=False)

        if self.n_channels != self.encoder.backbone.conv1.in_channels:
            # Rebuild conv1 for the requested number of input channels, reusing
            # the first n_channels slices of the pretrained RGB filters
            # (only meaningful for n_channels <= 3). Better than nothing ... ?
            conv1 = nn.Conv2d(self.n_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            data = self.encoder.backbone.conv1.weight.data[:, :self.n_channels, :, :]
            self.encoder.backbone.conv1 = conv1
            self.encoder.backbone.conv1.weight.data = data
        self.fc = nn.Linear(hidden_dimension, npc_dimension)
        self.relu = nn.ReLU(inplace=True)
        self.hidden_dimension = hidden_dimension
        self.decoder = SimpleDecoder(hidden_dimension=hidden_dimension)

    def restore_model(self, paths):
        for attr, path in paths.items():
            self._load(attr=attr, path=path)
        return self

    def _load(self, attr, path):
        if not os.path.exists(path):
            print('Unknown path: {}'.format(path))
            return self
        if not hasattr(self, attr):
            print('No attribute: {}'.format(attr))
            return self

        getattr(self, attr).load_state_dict(torch.load(path), strict=True)

        return self

    def forward(self, x, decode=False):
        z = self.encode(x, pool=False)  # (B, hidden_dimension, 2, 2)
        zb = nn.AvgPool2d(2, 2)(z).squeeze(dim=3).squeeze(dim=2)  # (B, hidden_dimension)
        zp = self.fc(zb)  # (B, npc_dimension)
        zp = self.relu(zp)
        if decode:
            x_hat = self.decoder(z)
        else:
            x_hat = None

        return x_hat, zp, zb

    def encode(self, x, pool=False):
        h = self.encoder(x)
        # The backbone flattens its output, so restore the 2x2 spatial layout
        # (ResNet downsamples by 32x, i.e. this assumes 64x64 inputs).
        h = h.view((x.shape[0], -1, 2, 2))
        if pool:
            return nn.AvgPool2d(2, 2)(h).squeeze(dim=3).squeeze(dim=2)
        return h

    def latent_variable(self, x_in, projectionHead):
        # The reconstruction is discarded here, so skip the decoder pass.
        _, zp, zb = self.forward(x_in, decode=False)
        if projectionHead:
            return zp
        return zb
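
# Usage sketch (illustrative; pretrained=False avoids a weight download):
#
#     model = CAE_DNR(pretrained=False, n_channels=3)
#     x_hat, zp, zb = model(torch.rand(4, 3, 64, 64), decode=True)
#     # x_hat: (4, 1, 64, 64) reconstruction, zp: (4, 256) projected feature,
#     # zb: (4, 512) pooled backbone feature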


class NonParametricClassifierOP(Function):
    """Autograd op computing similarities against a memory bank of features."""

    @staticmethod
    def forward(ctx, x, y, memory, params):
        T = params[0].item()

        # inner product between the batch features and every memory slot
        out = torch.mm(x.data, memory.t())
        out.div_(T)  # batchSize * N

        ctx.save_for_backward(x, memory, y, params)

        return out

    @staticmethod
    def backward(ctx, gradOutput):
        x, memory, y, params = ctx.saved_tensors
        T = params[0].item()
        momentum = params[1].item()

        # add temperature
        gradOutput.data.div_(T)

        # gradient of the linear similarity w.r.t. the input features
        gradInput = torch.mm(gradOutput.data, memory)
        gradInput.resize_as_(x)

        # update the memory: exponential moving average of each visited slot
        # towards the new feature, re-normalised to unit length
        weight_pos = memory.index_select(0, y.data.view(-1)).resize_as_(x)
        weight_pos.mul_(momentum)
        weight_pos.add_(torch.mul(x.data, 1 - momentum))
        w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
        updated_weight = weight_pos.div(w_norm)
        memory.index_copy_(0, y, updated_weight)

        return gradInput, None, None, None
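
# Note (added for clarity): the memory bank is updated as a side effect inside
# backward(), so the optimiser never sees it; each visited slot follows
# m_y <- normalize(momentum * m_y + (1 - momentum) * x), and gradients flow
# only to the batch features. Illustrative call:
#
#     out = NonParametricClassifierOP.apply(x, y, memory, params)  # (B, N) scaled similarities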


class NonParametricClassifier(nn.Module):
    """Non-parametric Classifier

    Non-parametric Classifier from
    "Unsupervised Feature Learning via Non-Parametric Instance Discrimination"

    Extends:
        nn.Module
    """

    def __init__(self, inputSize, outputSize, T=0.05, momentum=0.5):
        """Initialise the non-parametric classifier

        Arguments:
            inputSize {int} -- dimensionality of the stored features
            outputSize {int} -- number of memory slots (one per sample)

        Keyword Arguments:
            T {float} -- softmax temperature (default: {0.05})
            momentum {float} -- memory update momentum (default: {0.5})
        """
        super(NonParametricClassifier, self).__init__()
        self.nLem = outputSize
        self.register_buffer('params', torch.tensor([T, momentum]))
        # initialise the memory bank uniformly in [-stdv, stdv]
        stdv = 1. / math.sqrt(inputSize / 3)
        self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))

    def forward(self, x, y):
        out = NonParametricClassifierOP.apply(x, y, self.memory, self.params)
        return out
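
# Usage sketch (illustrative): one memory slot per training sample, addressed
# by the sample indices passed alongside the (unit-norm) features.
#
#     npc = NonParametricClassifier(inputSize=256, outputSize=50000)
#     logits = npc(F.normalize(torch.randn(8, 256), dim=1), torch.arange(8))
#     assert logits.shape == (8, 50000)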


class ANsDiscovery(nn.Module):
    """Discover ANs

    Discover ANs according to the current round, the selection rate and, most
    importantly, every sample's entropy.
    """

    def __init__(self, nsamples):
        """Object used to discover ANs

        ANs are discovered from the total number of samples together with the
        selection rate (`select_rate`) and neighbourhood size (`ANs_size`),
        which are fixed attributes here.

        Arguments:
            nsamples {int} -- total number of samples
        """
        super(ANsDiscovery, self).__init__()

        # not going to use ``register_buffer'' for these as
        # they are determined by configs
        self.select_rate = 0.25
        self.ANs_size = 1
        # number of samples
        self.register_buffer('samples_num', torch.tensor(nsamples))
        # indexes of anchor samples
        self.register_buffer('anchor_indexes', torch.LongTensor(nsamples // 2))
        # indexes of instance samples
        self.register_buffer('instance_indexes', torch.arange(nsamples // 2).long())
        # anchor samples' and instance samples' positions
        self.register_buffer('position', -1 * torch.arange(nsamples).long() - 1)
        # anchor samples' neighbours
        self.register_buffer('neighbours', torch.LongTensor(nsamples // 2, 1))
        # each sample's entropy
        self.register_buffer('entropy', torch.FloatTensor(nsamples))
        # label consistency of the chosen neighbourhoods
        self.register_buffer('consistency', torch.tensor(0.))

    def get_ANs_num(self, round):
        """Get the number of ANs

        Get the number of ANs at the target round according to the selection rate.

        Arguments:
            round {int} -- target round

        Returns:
            int -- number of ANs
        """
        return int(self.samples_num.float() * self.select_rate * round)
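
    # Worked example (added for illustration): with nsamples=50000 and the
    # fixed select_rate of 0.25, round 1 yields int(50000 * 0.25 * 1) = 12500
    # ANs, round 2 yields 25000, and so on.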

    def update(self, round, npc, cheat_labels=None):
        """Update ANs

        Discover new ANs and update `anchor_indexes`, `instance_indexes` and
        `neighbours`.

        Arguments:
            round {int} -- target round
            npc {Module} -- non-parametric classifier
            cheat_labels {Tensor} -- used only to compute the consistency of the chosen ANs

        Returns:
            number -- updated consistency
        """
        with torch.no_grad():
            batch_size = 100
            ANs_num = self.get_ANs_num(round)
            features = npc.memory
            samples_num = int(self.samples_num)

            # compute every sample's entropy over the instance distribution
            for start in range(0, samples_num, batch_size):
                end = min(start + batch_size, samples_num)

                preds = F.softmax(npc(features[start:end], None), 1)
                self.entropy[start:end] = -(preds * preds.log()).sum(1)

            # get the anchor list and instance list according to the computed
            # entropy: the ANs_num lowest-entropy samples become anchors
            self.anchor_indexes = self.entropy.topk(ANs_num, largest=False)[1]
            self.instance_indexes = (torch.ones_like(self.position)
                                     .scatter_(0, self.anchor_indexes, 0)
                                     .nonzero().view(-1))
            # entropies of the two groups (kept from the original; unused here)
            anchor_entropy = self.entropy.index_select(0, self.anchor_indexes)
            instance_entropy = self.entropy.index_select(0, self.instance_indexes)

            # get position:
            # if the anchor sample x with index i has position j, then
            # sample x_i is the j-th anchor sample at the current round;
            # if the instance sample x with index i has position j, then
            # sample x_i is the (-j-1)-th instance sample at the current round
            instance_cnt = 0
            for i in range(samples_num):
                # for anchor samples
                if (i == self.anchor_indexes).any():
                    self.position[i] = (self.anchor_indexes == i).max(0)[1]
                    continue
                # for instance samples
                instance_cnt -= 1
                self.position[i] = instance_cnt

            # for every anchor, find its ANs_size most similar samples,
            # masking out the anchor itself
            anchor_features = features.index_select(0, self.anchor_indexes)
            self.neighbours = (torch.LongTensor(ANs_num, self.ANs_size)
                               .to(features.device))  # same device as the memory bank
            for start in range(0, ANs_num, batch_size):
                end = min(start + batch_size, ANs_num)

                sims = torch.mm(anchor_features[start:end], features.t())
                sims.scatter_(1, self.anchor_indexes[start:end].view(-1, 1), -1.)
                _, self.neighbours[start:end] = (
                    sims.topk(self.ANs_size, largest=True, dim=1))

            # if cheat labels are provided, compute the label consistency of
            # the chosen neighbourhoods
            if cheat_labels is None:
                return 0.
            anchor_label = cheat_labels.index_select(0, self.anchor_indexes)
            neighbour_label = cheat_labels.index_select(
                0, self.neighbours.view(-1)).view_as(self.neighbours)
            self.consistency = ((anchor_label.view(-1, 1) == neighbour_label)
                                .float().mean())

            return self.consistency
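
    # Illustration (added): with nsamples=6 and anchors at indexes {1, 4},
    # `position` becomes [-1, 0, -2, -3, 1, -4]; non-negative entries index
    # into `anchor_indexes`/`neighbours`, negative entries enumerate instance
    # samples. Criterion._split relies on exactly this sign convention.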


class Criterion(nn.Module):
    """Neighbourhood-discovery (AND) objective plus an optional reconstruction loss."""

    def calculate_loss(self, x, y, ANs):
        batch_size, _ = x.shape

        # split the batch into anchor and instance samples; the second half of
        # the batch holds the augmented counterparts of the first half
        anchor_indexes, instance_indexes = self._split(y[:batch_size // 2], ANs)
        preds = F.softmax(x, 1)

        l_ans = torch.tensor(0., device=x.device)
        if anchor_indexes.size(0) > 0:
            # compute loss for anchor samples
            y_ans = y.index_select(0, anchor_indexes)
            y_ans_p = y.index_select(0, anchor_indexes + batch_size // 2)
            y_ans_neighbour = ANs.position.index_select(0, y_ans)
            neighbours = ANs.neighbours.index_select(0, y_ans_neighbour)
            # p_i = \sum_{j \in \Omega_i} p_{i,j}
            x_ans = preds.index_select(0, anchor_indexes)
            x_ans_p = preds.index_select(0, anchor_indexes + batch_size // 2)

            x_ans_neighbour = x_ans.gather(1, neighbours).sum(1)
            x_ans_p = x_ans_p.gather(1, y_ans_p.view(-1, 1)).view(-1)
            x_ans = x_ans.gather(1, y_ans.view(-1, 1)).view(-1)
            # sum all terms: self + augmented counterpart + neighbours
            # NLL: l = -log(p_i)
            l_ans = -1 * torch.log(x_ans + x_ans_p + x_ans_neighbour).sum(0)

        l_inst = torch.tensor(0., device=x.device)
        if instance_indexes.size(0) > 0:
            # compute loss for instance samples
            y_inst = y.index_select(0, instance_indexes)
            y_inst_p = y.index_select(0, instance_indexes + batch_size // 2)
            x_inst = preds.index_select(0, instance_indexes)
            x_inst_p = preds.index_select(0, instance_indexes + batch_size // 2)
            # p_i = p_{i, i}
            x_inst = x_inst.gather(1, y_inst.view(-1, 1))
            x_inst_p = x_inst_p.gather(1, y_inst_p.view(-1, 1))
            # NLL: l = -log(p_i)
            l_inst = -1 * torch.log(x_inst + x_inst_p).sum(0)

        return l_inst / batch_size, l_ans / batch_size

    def _split(self, y, ANs):
        # non-negative positions mark anchors, negative ones mark instances
        pos = ANs.position.index_select(0, y.view(-1))
        return (pos >= 0).nonzero().view(-1), (pos < 0).nonzero().view(-1)

    def forward(self, x_out, index, npc, ANs_discovery, x_hat, zp):
        # L2-normalise zp before the memory-bank lookup (eps guards against
        # zero norms, as the epsilon in the original intended)
        z_n = F.normalize(zp, p=2, dim=1, eps=1e-12)
        outputs = npc(z_n, index)  # for each image, similarity with every memory slot
        loss_inst, loss_ans = self.calculate_loss(outputs, index, ANs_discovery)
        loss = loss_inst + loss_ans
        l_loss = {'loss': loss, 'loss_inst': loss_inst, 'loss_ans': loss_ans}

        if x_hat is not None:
            loss_mse = nn.MSELoss()(x_hat, x_out)
            loss = loss + loss_mse
            l_loss['loss_mse'] = loss_mse
            l_loss['loss'] = loss

        return l_loss
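

# End-to-end sketch (illustrative, not from the original file; x1/x2/idx and
# the 50000-sample setup are placeholders). One training step roughly wires
# the pieces together as follows, with 64x64 single-channel input pairs:
#
#     model = CAE_DNR(pretrained=False, n_channels=1)
#     npc = NonParametricClassifier(inputSize=256, outputSize=50000)
#     ans_discovery = ANsDiscovery(nsamples=50000)
#     criterion = Criterion()
#
#     x = torch.cat([x1, x2], dim=0)          # two augmentations of the same images
#     x_hat, zp, zb = model(x, decode=True)
#     losses = criterion(x, torch.cat([idx, idx]), npc, ans_discovery, x_hat, zp)
#     losses['loss'].backward()
#
#     # between rounds, re-select the anchor neighbourhoods:
#     # ans_discovery.update(round, npc)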