# src/LRP.py
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np


class ClustDistLayer(nn.Module):
    # Scores a latent code x against the bisecting hyperplane between the
    # current cluster's centroid and each other cluster's centroid.
    def __init__(self, centroids, n_clusters, clust_list, device):
        super(ClustDistLayer, self).__init__()
        self.centroids = centroids.detach().to(device)
        self.n_clusters = n_clusters
        self.clust_list = clust_list

    def forward(self, x, curr_clust_id):
        curr_centroid = self.centroids[self.clust_list.index(curr_clust_id)]
        output = []
        for i in self.clust_list:
            if i == curr_clust_id:
                continue
            other_centroid = self.centroids[self.clust_list.index(i)]
            # h > 0 iff x lies on the current cluster's side of the boundary:
            # ||x - c_curr||^2 < ||x - c_i||^2  <=>
            # 2(c_curr - c_i).x + ||c_i||^2 - ||c_curr||^2 > 0.
            weight = 2 * (curr_centroid - other_centroid)
            bias = torch.norm(other_centroid, p=2) ** 2 - torch.norm(curr_centroid, p=2) ** 2
            h = torch.matmul(x, weight) + bias
            output.append(h.unsqueeze(1))

        return torch.cat(output, dim=1)
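
# Illustrative check of the hyperplane score above (toy values, 2-D latents):
# with centroids c0 = (0, 0), c1 = (2, 0) and x = (1.5, 0.5), scoring x for
# curr_clust_id = 1 against cluster 0 gives
#   2(c1 - c0).x + ||c0||^2 - ||c1||^2 = 6 + 0 - 4 = 2 > 0,
# consistent with x being closer to c1 (distance 0.71) than to c0 (1.58).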


class ClustMinPoolLayer(nn.Module):
    # Smooth minimum over the hyperplane scores: -log(sum_i exp(-beta * h_i))
    # approaches beta * min_i h_i as beta grows, so the activation measures
    # how close a sample is to the nearest cluster boundary.
    def __init__(self, beta):
        super(ClustMinPoolLayer, self).__init__()
        self.beta = beta
        self.eps = 1e-10  # keeps the log finite when all scores are large

    def forward(self, inputs):
        return -torch.log(torch.sum(torch.exp(inputs * -self.beta), dim=1) + self.eps)
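
# Quick numeric check (beta = 1, scores assumed): for h = [2.0, 5.0],
# -log(exp(-2) + exp(-5)) ~= 1.95 = min(h) - log(1 + exp(-3)), i.e. the
# soft minimum undershoots the hard minimum by about 0.05 here.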


class LRP(nn.Module):
    def __init__(self, model, X1, X2, Z, clust_ids, n_clusters, beta=1., device="cuda"):
        super(LRP, self).__init__()
        #model.freeze_model()
        self.model = model
        self.clust_ids = clust_ids
        self.n_clusters = n_clusters
        self.clust_list = np.unique(clust_ids).astype(int).tolist()
        self.centroids_ = torch.tensor(self.set_centroids(Z), dtype=torch.float32)
        self.X1_ = torch.tensor(X1, dtype=torch.float32)
        self.X2_ = torch.tensor(X2, dtype=torch.float32)
        self.Z_ = torch.tensor(Z, dtype=torch.float32)
        self.distLayer = ClustDistLayer(self.centroids_, n_clusters, self.clust_list, device).to(device)
        self.clustMinPool = ClustMinPoolLayer(beta).to(device)
        self.device = device

    def set_centroids(self, Z):
        # Centroid of each cluster is the mean latent code of its members.
        centroids = []
        for i in self.clust_list:
            clust_Z = Z[self.clust_ids == i]
            centroids.append(np.mean(clust_Z, axis=0))

        return np.stack(centroids, axis=0)

    def clust_minpoolAct(self, X1, X2, curr_clust_id):
        # The latent code z is the first of the autoencoder's outputs.
        z, *_ = self.model.forwardAE(X1, X2)
        return self.clustMinPool(self.distLayer(z, curr_clust_id))

    def calc_carlini_wagner_one_vs_one(self, clust_c_id, clust_k_id, margin=1., lamda=1e2, max_iter=5000, lr=2e-3, use_abs=True):
        # Carlini-Wagner-style attribution: perturb the samples of cluster c
        # until they read as cluster k, under an L1 sparsity penalty; the mean
        # per-feature perturbation is the relevance score.
        X1_0 = self.X1_[self.clust_ids == clust_c_id].to(self.device)
        curr_X1 = (X1_0 + 1e-6).requires_grad_(True)
        X2_0 = self.X2_[self.clust_ids == clust_c_id].to(self.device)
        curr_X2 = (X2_0 + 1e-6).requires_grad_(True)
        optimizer = optim.SGD([curr_X1, curr_X2], lr=lr)

        for it in range(max_iter):
            clust_c_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_c_id)
            clust_k_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_k_id)
            # Hinge loss: positive while a sample still activates cluster c
            # more strongly than cluster k (plus the margin).
            clust_loss_tensor = margin + clust_c_minpoolAct_tensor - clust_k_minpoolAct_tensor
            clust_loss_tensor = torch.maximum(clust_loss_tensor, torch.zeros_like(clust_loss_tensor))
            clust_loss = torch.sum(clust_loss_tensor)

            norm_loss = torch.norm(curr_X1 - X1_0, p=1) + torch.norm(curr_X2 - X2_0, p=1)

            loss = clust_loss * lamda + norm_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (it + 1) % 50 == 0:
                print('Iteration {}, Total loss:{:.8f}, clust loss:{:.8f}, L1 penalty:{:.8f}'.format(it + 1, loss.item(), clust_loss.item(), norm_loss.item()))

        if use_abs:
            rel_score1 = torch.mean(torch.abs(curr_X1 - X1_0), dim=0)
            rel_score2 = torch.mean(torch.abs(curr_X2 - X2_0), dim=0)
        else:
            rel_score1 = torch.mean(curr_X1 - X1_0, dim=0)
            rel_score2 = torch.mean(curr_X2 - X2_0, dim=0)
        return rel_score1.detach().cpu().numpy(), rel_score2.detach().cpu().numpy()

    def calc_carlini_wagner_one_vs_rest(self, clust_c_id, margin=1., lamda=1e2, max_iter=5000, lr=2e-3, use_abs=True):
        # Same attack as above, but cluster c is pushed toward whichever other
        # cluster is easiest to reach: the target activation is the
        # element-wise maximum over all remaining clusters.
        X1_0 = self.X1_[self.clust_ids == clust_c_id].to(self.device)
        curr_X1 = (X1_0 + 1e-6).requires_grad_(True)
        X2_0 = self.X2_[self.clust_ids == clust_c_id].to(self.device)
        curr_X2 = (X2_0 + 1e-6).requires_grad_(True)
        optimizer = optim.SGD([curr_X1, curr_X2], lr=lr)

        for it in range(max_iter):
            clust_rest_minpoolAct_tensor_list = []
            for clust_k_id in self.clust_list:
                if clust_k_id == clust_c_id:
                    continue
                clust_rest_minpoolAct_tensor_list.append(self.clust_minpoolAct(curr_X1, curr_X2, clust_k_id))
            # Element-wise max over the remaining clusters' activations.
            clust_rest_minpoolAct_tensor = clust_rest_minpoolAct_tensor_list[0]
            for k in range(1, len(clust_rest_minpoolAct_tensor_list)):
                clust_rest_minpoolAct_tensor = torch.maximum(clust_rest_minpoolAct_tensor, clust_rest_minpoolAct_tensor_list[k])

            clust_c_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_c_id)

            clust_loss_tensor = margin + clust_c_minpoolAct_tensor - clust_rest_minpoolAct_tensor
            clust_loss_tensor = torch.maximum(clust_loss_tensor, torch.zeros_like(clust_loss_tensor))
            clust_loss = torch.sum(clust_loss_tensor)

            norm_loss = torch.norm(curr_X1 - X1_0, p=1) + torch.norm(curr_X2 - X2_0, p=1)

            loss = clust_loss * lamda + norm_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (it + 1) % 50 == 0:
                print('Iteration {}, Total loss:{:.8f}, clust loss:{:.8f}, L1 penalty:{:.8f}'.format(it + 1, loss.item(), clust_loss.item(), norm_loss.item()))

        if use_abs:
            rel_score1 = torch.mean(torch.abs(curr_X1 - X1_0), dim=0)
            rel_score2 = torch.mean(torch.abs(curr_X2 - X2_0), dim=0)
        else:
            rel_score1 = torch.mean(curr_X1 - X1_0, dim=0)
            rel_score2 = torch.mean(curr_X2 - X2_0, dim=0)
        return rel_score1.detach().cpu().numpy(), rel_score2.detach().cpu().numpy()
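

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original pipeline).
# `DummyAE` is a hypothetical stand-in for the real autoencoder: all that
# matters here is that `forwardAE` returns the latent code first, followed by
# the model's other outputs. Shapes, seeds, and hyperparameters are arbitrary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class DummyAE(nn.Module):
        def __init__(self, d1, d2, dz):
            super().__init__()
            self.enc = nn.Linear(d1 + d2, dz)

        def forwardAE(self, X1, X2):
            z = self.enc(torch.cat([X1, X2], dim=1))
            return (z,) + (None,) * 8  # placeholders for the other outputs

    rng = np.random.default_rng(0)
    X1 = rng.normal(size=(60, 10)).astype(np.float32)
    X2 = rng.normal(size=(60, 8)).astype(np.float32)
    clust_ids = rng.integers(0, 3, size=60)

    model = DummyAE(10, 8, 4)
    with torch.no_grad():
        Z, *_ = model.forwardAE(torch.tensor(X1), torch.tensor(X2))

    lrp = LRP(model, X1, X2, Z.numpy(), clust_ids, n_clusters=3, device="cpu")
    rel1, rel2 = lrp.calc_carlini_wagner_one_vs_rest(0, max_iter=100)
    print("relevance shapes:", rel1.shape, rel2.shape)  # (10,), (8,)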