b/src/LRP.py
|
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np


class ClustDistLayer(nn.Module):
    """Fixed linear layer measuring, for each cluster i != c, the margin
    h_i(x) = ||x - c_i||^2 - ||x - c_c||^2, i.e. how much closer x sits to
    centroid c_c than to centroid c_i."""

    def __init__(self, centroids, n_clusters, clust_list, device):
        super(ClustDistLayer, self).__init__()
        # The centroids are constants of the explanation, not trainable weights.
        self.centroids = centroids.to(device)
        self.n_clusters = n_clusters
        self.clust_list = clust_list

    def forward(self, x, curr_clust_id):
        # Expanding the squared distances makes each margin linear in x:
        #   ||x - c_i||^2 - ||x - c_c||^2 = 2 (c_c - c_i) . x + ||c_i||^2 - ||c_c||^2
        output = []
        c_c = self.centroids[self.clust_list.index(curr_clust_id)]
        for i in self.clust_list:
            if i == curr_clust_id:
                continue
            c_i = self.centroids[self.clust_list.index(i)]
            weight = 2 * (c_c - c_i)
            # The bias is the difference of *squared* norms required by the
            # expansion above.
            bias = torch.sum(c_i ** 2) - torch.sum(c_c ** 2)
            h = torch.matmul(x, weight) + bias
            output.append(h.unsqueeze(1))

        return torch.cat(output, dim=1)
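
# Illustrative check (not part of the original module; names and toy values are
# assumptions): for a point near its own centroid, every margin produced by
# ClustDistLayer should be positive.
def _demo_clust_dist_layer():
    centroids = torch.tensor([[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]])
    layer = ClustDistLayer(centroids, n_clusters=3, clust_list=[0, 1, 2], device="cpu")
    x = torch.tensor([[0.5, 0.5]])       # one point, close to centroid 0
    margins = layer(x, curr_clust_id=0)  # shape (1, 2): margins vs. clusters 1 and 2
    assert (margins > 0).all()           # x is closer to centroid 0 than to 1 or 2
    return margins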
|
|
|
class ClustMinPoolLayer(nn.Module):
    """Smooth minimum over the per-cluster margins via log-sum-exp pooling:
    -log(sum_i exp(-beta * h_i)), which approaches beta * min_i h_i for large beta."""

    def __init__(self, beta):
        super(ClustMinPoolLayer, self).__init__()
        self.beta = beta
        self.eps = 1e-10  # keeps the log finite if the sum underflows to zero

    def forward(self, inputs):
        return -torch.log(torch.sum(torch.exp(inputs * -self.beta), dim=1) + self.eps)
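
# Illustrative check (not part of the original module): as beta grows, the
# pooled activation divided by beta approaches the smallest margin.
def _demo_clust_min_pool():
    margins = torch.tensor([[1.0, 5.0, 9.0]])
    for beta in (1.0, 10.0, 100.0):
        pooled = ClustMinPoolLayer(beta)(margins)
        print(beta, pooled.item() / beta)  # tends toward min(margins) = 1.0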
|
|
|
class LRP(nn.Module):
    def __init__(self, model, X1, X2, Z, clust_ids, n_clusters, beta=1., device="cuda"):
        super(LRP, self).__init__()
        #model.freeze_model()
        self.model = model
        self.clust_ids = clust_ids
        self.n_clusters = n_clusters
        self.clust_list = np.unique(clust_ids).astype(int).tolist()
        self.centroids_ = torch.tensor(self.set_centroids(Z), dtype=torch.float32)
        self.X1_ = torch.tensor(X1, dtype=torch.float32)
        self.X2_ = torch.tensor(X2, dtype=torch.float32)
        self.Z_ = torch.tensor(Z, dtype=torch.float32)
        self.distLayer = ClustDistLayer(self.centroids_, n_clusters, self.clust_list, device).to(device)
        self.clustMinPool = ClustMinPoolLayer(beta).to(device)
        self.device = device

    def set_centroids(self, Z):
        # Centroid of each cluster: mean of the embeddings assigned to it.
        centroids = []
        for i in self.clust_list:
            clust_Z = Z[self.clust_ids == i]
            centroids.append(np.mean(clust_Z, axis=0))

        return np.stack(centroids, axis=0)

    def clust_minpoolAct(self, X1, X2, curr_clust_id):
        # forwardAE returns the latent embedding first; the remaining outputs
        # are unused here.
        z, *_ = self.model.forwardAE(X1, X2)
        return self.clustMinPool(self.distLayer(z, curr_clust_id))

    def calc_carlini_wagner_one_vs_one(self, clust_c_id, clust_k_id, margin=1., lamda=1e2, max_iter=5000, lr=2e-3, use_abs=True):
        # Carlini-Wagner-style perturbation: push the samples of cluster c over
        # the decision boundary toward cluster k while an L1 penalty keeps the
        # perturbation sparse.
        X1_0 = self.X1_[self.clust_ids == clust_c_id].to(self.device)
        X2_0 = self.X2_[self.clust_ids == clust_c_id].to(self.device)
        # Build leaf copies on the target device before enabling gradients, so
        # the optimizer receives proper leaf tensors.
        curr_X1 = (X1_0 + 1e-6).clone().detach().requires_grad_(True)
        curr_X2 = (X2_0 + 1e-6).clone().detach().requires_grad_(True)
        optimizer = optim.SGD([curr_X1, curr_X2], lr=lr)

        for it in range(max_iter):
            clust_c_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_c_id)
            clust_k_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_k_id)
            # Hinge loss: active until cluster c's activation falls at least
            # `margin` below cluster k's.
            clust_loss_tensor = margin + clust_c_minpoolAct_tensor - clust_k_minpoolAct_tensor
            clust_loss_tensor = torch.maximum(clust_loss_tensor, torch.zeros_like(clust_loss_tensor))
            clust_loss = torch.sum(clust_loss_tensor)

            norm_loss = torch.norm(curr_X1 - X1_0, p=1) + torch.norm(curr_X2 - X2_0, p=1)

            loss = clust_loss * lamda + norm_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (it + 1) % 50 == 0:
                print('Iteration {}, Total loss:{:.8f}, clust loss:{:.8f}, L1 penalty:{:.8f}'.format(it + 1, loss.item(), clust_loss.item(), norm_loss.item()))

        # Feature relevance: mean perturbation per feature across the cluster.
        if use_abs:
            rel_score1 = torch.mean(torch.abs(curr_X1 - X1_0), dim=0)
            rel_score2 = torch.mean(torch.abs(curr_X2 - X2_0), dim=0)
        else:
            rel_score1 = torch.mean(curr_X1 - X1_0, dim=0)
            rel_score2 = torch.mean(curr_X2 - X2_0, dim=0)
        return rel_score1.detach().cpu().numpy(), rel_score2.detach().cpu().numpy()

    def calc_carlini_wagner_one_vs_rest(self, clust_c_id, margin=1., lamda=1e2, max_iter=5000, lr=2e-3, use_abs=True):
        # One-vs-rest variant: cluster c competes against the elementwise maximum
        # of all other clusters' activations, i.e. the most reachable alternative.
        X1_0 = self.X1_[self.clust_ids == clust_c_id].to(self.device)
        X2_0 = self.X2_[self.clust_ids == clust_c_id].to(self.device)
        curr_X1 = (X1_0 + 1e-6).clone().detach().requires_grad_(True)
        curr_X2 = (X2_0 + 1e-6).clone().detach().requires_grad_(True)
        optimizer = optim.SGD([curr_X1, curr_X2], lr=lr)

        for it in range(max_iter):
            clust_rest_minpoolAct_tensor_list = []
            for clust_k_id in self.clust_list:
                if clust_k_id == clust_c_id:
                    continue
                clust_rest_minpoolAct_tensor_list.append(self.clust_minpoolAct(curr_X1, curr_X2, clust_k_id))
            # Elementwise maximum over the remaining clusters' activations.
            clust_rest_minpoolAct_tensor = clust_rest_minpoolAct_tensor_list[0]
            for k in range(1, len(clust_rest_minpoolAct_tensor_list)):
                clust_rest_minpoolAct_tensor = torch.maximum(clust_rest_minpoolAct_tensor, clust_rest_minpoolAct_tensor_list[k])

            clust_c_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_c_id)

            clust_loss_tensor = margin + clust_c_minpoolAct_tensor - clust_rest_minpoolAct_tensor
            clust_loss_tensor = torch.maximum(clust_loss_tensor, torch.zeros_like(clust_loss_tensor))
            clust_loss = torch.sum(clust_loss_tensor)

            norm_loss = torch.norm(curr_X1 - X1_0, p=1) + torch.norm(curr_X2 - X2_0, p=1)

            loss = clust_loss * lamda + norm_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (it + 1) % 50 == 0:
                print('Iteration {}, Total loss:{:.8f}, clust loss:{:.8f}, L1 penalty:{:.8f}'.format(it + 1, loss.item(), clust_loss.item(), norm_loss.item()))

        if use_abs:
            rel_score1 = torch.mean(torch.abs(curr_X1 - X1_0), dim=0)
            rel_score2 = torch.mean(torch.abs(curr_X2 - X2_0), dim=0)
        else:
            rel_score1 = torch.mean(curr_X1 - X1_0, dim=0)
            rel_score2 = torch.mean(curr_X2 - X2_0, dim=0)
        return rel_score1.detach().cpu().numpy(), rel_score2.detach().cpu().numpy()
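
# Illustrative usage sketch (an assumption, not part of the original module):
# `model` stands for a trained two-view autoencoder whose forwardAE(X1, X2)
# returns the latent embedding as its first output; X1, X2, Z, and clust_ids
# come from that model's training pipeline.
#
#   lrp = LRP(model, X1, X2, Z, clust_ids,
#             n_clusters=len(np.unique(clust_ids)), beta=1., device="cuda")
#   # Features that push cluster 0 specifically toward cluster 1:
#   rel1, rel2 = lrp.calc_carlini_wagner_one_vs_one(0, 1, max_iter=500)
#   # Features that push cluster 0 toward its most reachable neighbor:
#   rel1, rel2 = lrp.calc_carlini_wagner_one_vs_rest(0, max_iter=500)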