"""DnR model components (encoder/decoder, non-parametric classifier, ANs).

We build our architecture on top of the ANs proposed in

@InProceedings{huang2018and,
  title={Unsupervised Deep Learning by Neighbourhood Discovery},
  author={Jiabo Huang, Qi Dong, Shaogang Gong and Xiatian Zhu},
  booktitle={Proceedings of the International Conference on machine learning (ICML)},
  year={2019},
}

The code is available online under https://github.com/Raymond-sci/AND
"""

import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Function


def resnet18(pretrained=False):
    """Return a torchvision resnet18 with its head (fc, avgpool) removed."""
    model = models.resnet18(pretrained=pretrained)
    model.fc = Identity()
    model.avgpool = Identity()
    return model


def resnet34(pretrained=False):
    """Return a torchvision resnet34 with its head (fc, avgpool) removed."""
    model = models.resnet34(pretrained=pretrained)
    model.fc = Identity()
    model.avgpool = Identity()
    return model


def resnet50(pretrained=False):
    """Return a torchvision resnet50 with its head (fc, avgpool) removed."""
    model = models.resnet50(pretrained=pretrained)
    model.fc = Identity()
    model.avgpool = Identity()
    return model


class Identity(nn.Module):
    """Pass-through module, used to strip layers off a torchvision resnet."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Backbone(nn.Module):
    """Thin wrapper selecting a headless resnet feature extractor.

    Arguments:
        name {str} -- one of 'resnet18', 'resnet34', 'resnet50'
        pretrained {bool} -- load ImageNet weights
        freeze_all {bool} -- disable gradients for every backbone parameter

    Raises:
        ValueError -- if `name` is not a supported backbone
    """

    def __init__(self, name='resnet18', pretrained=True, freeze_all=False):
        super(Backbone, self).__init__()
        self.name = name
        self.freeze_all = freeze_all
        self.pretrained = pretrained
        # FIX: the original used `if / if / elif` with no else, so an unknown
        # name left `self.backbone` undefined until a late AttributeError.
        if name == 'resnet18':
            self.backbone = resnet18(pretrained=self.pretrained)
        elif name == 'resnet34':
            self.backbone = resnet34(pretrained=self.pretrained)
        elif name == 'resnet50':
            self.backbone = resnet50(pretrained=self.pretrained)
        else:
            raise ValueError('Unknown backbone: {}'.format(name))

        if self.freeze_all:
            # Disable gradients for every parameter of the backbone.
            self.backbone.requires_grad_(False)

    def forward(self, x):
        return self.backbone(x)


class SimpleDecoder(nn.Module):
    """Convolutional decoder: (B, hidden, 2, 2) -> (B, 1, 64, 64) in [0, 1].

    Five conv + x2 bicubic-upsample stages halve the channel count each time
    (hidden -> hidden/32), followed by a final conv to a single channel and a
    sigmoid.  NOTE(review): the output has 1 channel regardless of the
    encoder's input channel count — confirm the reconstruction target matches.
    """

    def __init__(self, hidden_dimension=512):
        super(SimpleDecoder, self).__init__()

        self.conv_up_5 = nn.Conv2d(hidden_dimension, hidden_dimension // 2, 3, padding=1)
        self.conv_up_4 = nn.Conv2d(hidden_dimension // 2, hidden_dimension // 4, 3, padding=1)
        self.conv_up_3 = nn.Conv2d(hidden_dimension // 4, hidden_dimension // 8, 3, padding=1)
        self.conv_up_2 = nn.Conv2d(hidden_dimension // 8, hidden_dimension // 16, 3, padding=1)
        self.conv_up_1 = nn.Conv2d(hidden_dimension // 16, hidden_dimension // 32, 5, padding=2)
        self.decoder = nn.Conv2d(hidden_dimension // 32, 1, 5, padding=2)

    def forward(self, z):
        """Decode latent feature map `z` into a single-channel image."""

        def up(t):
            # x2 bicubic upsampling, identical to nn.Upsample(scale_factor=2,
            # mode='bicubic', align_corners=False).
            return F.interpolate(t, scale_factor=2, mode='bicubic',
                                 align_corners=False)

        h = up(F.relu(self.conv_up_5(z)))
        h = up(F.relu(self.conv_up_4(h)))
        h = up(F.relu(self.conv_up_3(h)))
        h = up(F.relu(self.conv_up_2(h)))
        h = up(F.relu(self.conv_up_1(h)))
        x_hat = torch.sigmoid(self.decoder(h))

        return x_hat


class CAE_DNR(nn.Module):
    """Convolutional autoencoder with a projection head for the NPC.

    Encoder is a headless resnet; `fc` projects the pooled feature to the
    non-parametric-classifier dimension; `decoder` reconstructs the input.
    """

    def __init__(self, pretrained=True, n_channels=3, hidden_dimension=512,
                 name='resnet18', npc_dimension=256):
        super(CAE_DNR, self).__init__()

        self.n_channels = n_channels
        self.name = name
        self.encoder = Backbone(name=self.name, pretrained=pretrained,
                                freeze_all=False)

        if self.n_channels != self.encoder.backbone.conv1.in_channels:
            # FIX: the original hard-coded a 2-channel conv (and sliced the
            # pretrained weight with [:, :2]) regardless of `n_channels`.
            conv1 = nn.Conv2d(self.n_channels, 64, kernel_size=(7, 7),
                              stride=(2, 2), padding=(3, 3), bias=False)
            pretrained_w = self.encoder.backbone.conv1.weight.data
            if self.n_channels <= pretrained_w.shape[1]:
                # Reuse the first n_channels pretrained input filters
                # ("better than nothing"); otherwise keep default init.
                conv1.weight.data = pretrained_w[:, :self.n_channels, :, :].clone()
            self.encoder.backbone.conv1 = conv1
        self.fc = nn.Linear(hidden_dimension, npc_dimension)
        self.relu = nn.ReLU(inplace=True)
        self.hidden_dimension = hidden_dimension
        self.decoder = SimpleDecoder(hidden_dimension=hidden_dimension)

    def restore_model(self, paths):
        """Load state dicts for several sub-modules.

        Arguments:
            paths {dict} -- maps attribute name (e.g. 'encoder') to a
                checkpoint path
        """
        for attr, path in paths.items():
            self._load(attr=attr, path=path)
        return self

    def _load(self, attr, path):
        """Load a state dict into `self.<attr>` from `path`.

        FIX: the original only printed a warning on a missing path/attribute
        and then crashed inside `torch.load` / `__getattr__` anyway; fail
        early with an explicit error instead.
        """
        if not os.path.exists(path):
            raise FileNotFoundError('Unknown path: {}'.format(path))
        if not hasattr(self, attr):
            raise AttributeError('No attribute: {}'.format(attr))

        getattr(self, attr).load_state_dict(torch.load(path), strict=True)

        return self

    def forward(self, x, decode=False):
        """Return (x_hat, zp, zb): reconstruction (or None), projection, pooled code."""
        z = self.encode(x, pool=False)                      # (B, hidden, 2, 2)
        zb = F.avg_pool2d(z, 2).squeeze(dim=3).squeeze(dim=2)  # (B, hidden)
        zp = self.fc(zb)
        zp = self.relu(zp)
        if decode:
            x_hat = self.decoder(z)
        else:
            x_hat = None

        return x_hat, zp, zb

    def encode(self, x, pool=False):
        """Encode `x`; reshape the flat backbone output to (B, -1, 2, 2).

        NOTE(review): assumes the backbone's flattened feature has spatial
        size 2x2 (e.g. 64x64 input on a resnet) — confirm against the loader.
        """
        h = self.encoder(x)
        h = h.view((x.shape[0], -1, 2, 2))
        if pool:
            return F.avg_pool2d(h, 2).squeeze(dim=3).squeeze(dim=2)
        else:
            return h

    def latent_variable(self, x_in, projectionHead):
        """Return the projected (zp) or pooled (zb) latent for `x_in`."""
        _, zp, zb = self.forward(x_in, decode=True)
        if projectionHead:
            return zp
        else:
            return zb


class NonParametricClassifierOP(Function):
    """Autograd op: temperature-scaled similarity against a memory bank,
    with a momentum update of the bank rows in the backward pass."""

    @staticmethod
    def forward(ctx, x, y, memory, params):
        # params = [temperature T, momentum]
        T = params[0].item()

        # Inner product against every memory row, scaled by temperature.
        out = torch.mm(x.data, memory.t())
        out.div_(T)  # batchSize * N

        ctx.save_for_backward(x, memory, y, params)

        return out

    @staticmethod
    def backward(ctx, gradOutput):
        x, memory, y, params = ctx.saved_tensors
        T = params[0].item()
        momentum = params[1].item()

        # Undo the temperature scaling.
        gradOutput.data.div_(T)

        # Gradient of the linear similarity w.r.t. x.
        gradInput = torch.mm(gradOutput.data, memory)
        gradInput.resize_as_(x)

        # Momentum update of the memory rows addressed by y, then re-normalize
        # each updated row to unit L2 norm.
        weight_pos = memory.index_select(0, y.data.view(-1)).resize_as_(x)
        weight_pos.mul_(momentum)
        weight_pos.add_(torch.mul(x.data, 1 - momentum))
        w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
        updated_weight = weight_pos.div(w_norm)
        memory.index_copy_(0, y, updated_weight)

        # FIX: forward has 4 inputs (x, y, memory, params), so return exactly
        # 4 gradients — the original returned a 5th trailing None.
        return gradInput, None, None, None


class NonParametricClassifier(nn.Module):
    """Non-parametric Classifier

    Non-parametric Classifier from
    "Unsupervised Feature Learning via Non-Parametric Instance Discrimination"

    Extends:
        nn.Module
    """

    def __init__(self, inputSize, outputSize, T=0.05, momentum=0.5):
        """Non-parametric Classifier initial function

        Arguments:
            inputSize {int} -- in-channels dims
            outputSize {int} -- out-channels dims

        Keyword Arguments:
            T {float} -- distribution temperature (default: {0.05})
            momentum {float} -- memory update momentum (default: {0.5})
        """
        super(NonParametricClassifier, self).__init__()
        self.nLem = outputSize
        self.register_buffer('params', torch.tensor([T, momentum]))
        # Memory bank initialized uniformly in [-stdv, stdv].
        stdv = 1. / math.sqrt(inputSize / 3)
        self.register_buffer(
            'memory',
            torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))

    def forward(self, x, y):
        out = NonParametricClassifierOP.apply(x, y, self.memory, self.params)
        return out


class ANsDiscovery(nn.Module):
    """Discovery ANs

    Discovery ANs according to current round, select_rate and most
    importantly, all samples' corresponding entropy.
    """

    def __init__(self, nsamples):
        """Object used to discover ANs.

        Arguments:
            nsamples {int} -- total number of samples
        """
        super(ANsDiscovery, self).__init__()

        # Not registered as buffers as they are determined by configs.
        self.select_rate = 0.25
        self.ANs_size = 1
        # Number of samples.
        self.register_buffer('samples_num', torch.tensor(nsamples))
        # Indexes of anchor samples.
        self.register_buffer('anchor_indexes', torch.LongTensor(nsamples // 2))
        # Indexes of instance samples.
        self.register_buffer('instance_indexes',
                             torch.arange(nsamples // 2).long())
        # Position code per sample: j >= 0 means "j-th anchor", negative j
        # means "(-j-1)-th instance".
        self.register_buffer('position', -1 * torch.arange(nsamples).long() - 1)
        # Each anchor's neighbour indexes.
        self.register_buffer('neighbours',
                             torch.LongTensor(nsamples // 2, 1))
        # Each sample's entropy.
        self.register_buffer('entropy', torch.FloatTensor(nsamples))
        # Label consistency of the chosen ANs (only if cheat labels given).
        self.register_buffer('consistency', torch.tensor(0.))

    def get_ANs_num(self, round):
        """Get the number of ANs at `round` according to the select rate."""
        return int(self.samples_num.float() * self.select_rate * round)

    def update(self, round, npc, cheat_labels=None):
        """Discover new ANs and update `anchor_indexes`, `instance_indexes`,
        `position` and `neighbours`.

        Arguments:
            round {int} -- target round
            npc {Module} -- non-parametric classifier (provides `memory`)
            cheat_labels {Tensor} -- optional ground-truth labels, used only
                to compute the consistency of the chosen ANs

        Returns:
            Tensor/float -- updated consistency (0. if no cheat labels)
        """
        with torch.no_grad():
            batch_size = 100
            ANs_num = self.get_ANs_num(round)
            features = npc.memory
            device = features.device

            # Entropy of each sample's softmaxed similarity distribution.
            for start in range(0, self.samples_num, batch_size):
                end = min(start + batch_size, self.samples_num)
                preds = F.softmax(npc(features[start:end], None), 1)
                self.entropy[start:end] = -(preds * preds.log()).sum(1)

            # Lowest-entropy samples become anchors; the rest are instances.
            self.anchor_indexes = self.entropy.topk(ANs_num, largest=False)[1]
            self.instance_indexes = (torch.ones_like(self.position)
                                     .scatter_(0, self.anchor_indexes, 0)
                                     .nonzero().view(-1))

            # Position codes (vectorized; the original looped over all
            # samples): anchor i at rank j gets position j, the k-th instance
            # (in index order) gets position -k-1.
            self.position = torch.empty_like(self.position)
            self.position[self.anchor_indexes] = torch.arange(
                ANs_num, device=self.position.device)
            self.position[self.instance_indexes] = -1 - torch.arange(
                self.instance_indexes.size(0), device=self.position.device)

            # For each anchor, its most similar non-self samples.
            anchor_features = features.index_select(0, self.anchor_indexes)
            # FIX: allocate on the features' device instead of hard-coded 'cuda'.
            self.neighbours = torch.empty(ANs_num, self.ANs_size,
                                          dtype=torch.long, device=device)
            for start in range(0, ANs_num, batch_size):
                end = min(start + batch_size, ANs_num)
                sims = torch.mm(anchor_features[start:end], features.t())
                # Exclude the anchor itself from its own neighbour list.
                sims.scatter_(1, self.anchor_indexes[start:end].view(-1, 1), -1.)
                _, self.neighbours[start:end] = (
                    sims.topk(self.ANs_size, largest=True, dim=1))

        # If cheat labels are provided, compute the ANs' label consistency.
        if cheat_labels is None:
            return 0.
        anchor_label = cheat_labels.index_select(0, self.anchor_indexes)
        neighbour_label = cheat_labels.index_select(
            0, self.neighbours.view(-1)).view_as(self.neighbours)
        self.consistency = ((anchor_label.view(-1, 1) == neighbour_label)
                            .float().mean())

        return self.consistency


class Criterion(nn.Module):
    """AND-style loss over NPC outputs, plus optional reconstruction MSE."""

    def __init__(self):
        super(Criterion, self).__init__()

    def calculate_loss(self, x, y, ANs):
        """Compute instance and anchor NLL losses.

        NOTE(review): assumes the batch is laid out as [originals; augmented
        counterparts], i.e. sample i pairs with i + batch_size//2 — confirm
        against the data loader.
        """
        batch_size, _ = x.shape

        # Split the first half of the batch into anchor / instance lists.
        anchor_indexes, instance_indexes = self._split(y[:batch_size // 2], ANs)
        preds = F.softmax(x, 1)

        # FIX: zero losses now match x's dtype/device instead of an integer
        # tensor hard-coded to CUDA.
        l_ans = x.new_zeros(())
        if anchor_indexes.size(0) > 0:
            # Loss for anchor samples.
            y_ans = y.index_select(0, anchor_indexes)
            y_ans_p = y.index_select(0, anchor_indexes + batch_size // 2)
            y_ans_neighbour = ANs.position.index_select(0, y_ans)
            neighbours = ANs.neighbours.index_select(0, y_ans_neighbour)
            # p_i = \sum_{j \in \Omega_i} p_{i,j}
            x_ans = preds.index_select(0, anchor_indexes)
            x_ans_p = preds.index_select(0, anchor_indexes + batch_size // 2)

            x_ans_neighbour = x_ans.gather(1, neighbours).sum(1)
            x_ans_p = x_ans_p.gather(1, y_ans_p.view(-1, 1)).view(-1)
            x_ans = x_ans.gather(1, y_ans.view(-1, 1)).view(-1)
            # Sum all terms (self + pair + neighbours); NLL: l = -log(p_i)
            l_ans = -1 * torch.log(x_ans + x_ans_p + x_ans_neighbour).sum(0)

        l_inst = x.new_zeros(())
        if instance_indexes.size(0) > 0:
            # Loss for instance samples.
            y_inst = y.index_select(0, instance_indexes)
            y_inst_p = y.index_select(0, instance_indexes + batch_size // 2)
            x_inst = preds.index_select(0, instance_indexes)
            x_inst_p = preds.index_select(0, instance_indexes + batch_size // 2)
            # p_i = p_{i, i}
            x_inst = x_inst.gather(1, y_inst.view(-1, 1))
            x_inst_p = x_inst_p.gather(1, y_inst_p.view(-1, 1))
            # NLL: l = -log(p_i)
            l_inst = -1 * torch.log(x_inst + x_inst_p).sum(0)

        return l_inst / batch_size, l_ans / batch_size

    def _split(self, y, ANs):
        """Split indexes of `y` into (anchor positions, instance positions)."""
        pos = ANs.position.index_select(0, y.view(-1))
        return (pos >= 0).nonzero().view(-1), (pos < 0).nonzero().view(-1)

    def forward(self, x_out, index, npc, ANs_discovery, x_hat, zp):
        """Return a dict with total loss and its components."""
        # L2-normalize the projection.  NOTE(review): the epsilon is added to
        # the elements before the norm (not to the norm itself) — preserved
        # as-is to keep numerics identical; presumably meant to avoid a
        # zero-norm division.
        z_n = torch.div(zp, torch.norm(zp + 1e-12, p=2, dim=1, keepdim=True))
        # Similarity of each image with memory bank / neighbours.
        outputs = npc(z_n, index)
        loss_inst, loss_ans = self.calculate_loss(outputs, index, ANs_discovery)
        loss = loss_inst + loss_ans
        l_loss = {'loss': loss, 'loss_inst': loss_inst, 'loss_ans': loss_ans}

        if x_hat is not None:
            loss_mse = nn.MSELoss()(x_hat, x_out)
            loss = loss + loss_mse
            l_loss['loss_mse'] = loss_mse
            l_loss['loss'] = loss

        return l_loss