--- a +++ b/src/LRP.py @@ -0,0 +1,149 @@ +import re +import torch +from torch.functional import norm +import torch.nn as nn +from torch.autograd import Variable +from torch.nn import Parameter +import torch.nn.functional as F +import torch.optim as optim +from torch.optim.lr_scheduler import * +from torch.utils.data import DataLoader, TensorDataset +from torch.nn.utils import clip_grad_norm_ +import numpy as np +import math, os + + +class ClustDistLayer(nn.Module): + def __init__(self, centroids, n_clusters, clust_list, device): + super(ClustDistLayer, self).__init__() + self.centroids = Variable(centroids).to(device) + self.n_clusters = n_clusters + self.clust_list = clust_list + + def forward(self, x, curr_clust_id): + output = [] + for i in self.clust_list: + if i==curr_clust_id: + continue + weight = 2 * (self.centroids[self.clust_list.index(curr_clust_id)] - self.centroids[self.clust_list.index(i)]) + bias = torch.norm(self.centroids[self.clust_list.index(curr_clust_id)], p=2) - torch.norm(self.centroids[self.clust_list.index(i)], p=2) + h = torch.matmul(x, weight.T) + bias + output.append(h.unsqueeze(1)) + + return torch.cat(output, dim=1) + + +class ClustMinPoolLayer(nn.Module): + def __init__(self, beta): + super(ClustMinPoolLayer, self).__init__() + self.beta = beta + self.eps = 1e-10 + + def forward(self, inputs): + return - torch.log(torch.sum(torch.exp(inputs * -self.beta), dim=1) + self.eps) + + +class LRP(nn.Module): + def __init__(self, model, X1, X2, Z, clust_ids, n_clusters, beta=1., device="cuda"): + super(LRP, self).__init__() + #model.freeze_model() + self.model = model + self.clust_ids = clust_ids + self.n_clusters = n_clusters + self.clust_list = np.unique(clust_ids).astype(int).tolist() + self.centroids_ =torch.tensor(self.set_centroids(Z), dtype=torch.float32) + self.X1_ = torch.tensor(X1, dtype=torch.float32) + self.X2_ = torch.tensor(X2, dtype=torch.float32) + self.Z_ = torch.tensor(Z, dtype=torch.float32) + self.distLayer = ClustDistLayer(self.centroids_, n_clusters, self.clust_list, device).to(device) + self.clustMinPool = ClustMinPoolLayer(beta).to(device) + self.device = device + + def set_centroids(self, Z): + centroids = [] + for i in self.clust_list: + clust_Z = Z[self.clust_ids==i] + curr_centroid = np.mean(clust_Z, axis=0) + centroids.append(curr_centroid) + + return np.stack(centroids, axis=0) + + def clust_minpoolAct(self, X1, X2, curr_clust_id): + z,_,_,_,_,_,_,_,_ = self.model.forwardAE(X1, X2) + return self.clustMinPool(self.distLayer(z, curr_clust_id)) + + def calc_carlini_wagner_one_vs_one(self, clust_c_id, clust_k_id, margin=1., lamda=1e2, max_iter=5000, lr=2e-3, use_abs=True): + X1_0 = Variable(self.X1_[self.clust_ids==clust_c_id], requires_grad=False).to(self.device) + curr_X1 = Variable(X1_0 + 1e-6, requires_grad=True).to(self.device) + X2_0 = Variable(self.X2_[self.clust_ids==clust_c_id], requires_grad=False).to(self.device) + curr_X2 = Variable(X2_0 + 1e-6, requires_grad=True).to(self.device) + optimizer = optim.SGD([curr_X1, curr_X2], lr=lr) + + for iter in range(max_iter): + clust_c_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_c_id) + clust_k_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_k_id) + clust_loss_tensor = margin + clust_c_minpoolAct_tensor - clust_k_minpoolAct_tensor + clust_loss_tensor = torch.maximum(clust_loss_tensor, torch.zeros_like(clust_loss_tensor)) + clust_loss = torch.sum(clust_loss_tensor) + + norm_loss = torch.norm(curr_X1 - X1_0, p=1) + torch.norm(curr_X2 - X2_0, p=1) + + loss = clust_loss * lamda + norm_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if (iter+1) % 50 == 0: + print('Iteration {}, Total loss:{:.8f}, clust loss:{:.8f}, L1 penalty:{:.8f}'.format(iter, loss.item(), clust_loss.item(), norm_loss.item())) + + if use_abs: + rel_score1 = torch.mean(torch.abs(curr_X1 - X1_0), dim=0) + rel_score2 = torch.mean(torch.abs(curr_X2 - X2_0), dim=0) + else: + rel_score1 = torch.mean(curr_X1 - X1_0, dim=0) + rel_score2 = torch.mean(curr_X2 - X2_0, dim=0) + return rel_score1.data.cpu().numpy(), rel_score2.data.cpu().numpy() + + def calc_carlini_wagner_one_vs_rest(self, clust_c_id, margin=1., lamda=1e2, max_iter=5000, lr=2e-3, use_abs=True): + X1_0 = Variable(self.X1_[self.clust_ids==clust_c_id], requires_grad=False).to(self.device) + curr_X1 = Variable(X1_0 + 1e-6, requires_grad=True).to(self.device) + X2_0 = Variable(self.X2_[self.clust_ids==clust_c_id], requires_grad=False).to(self.device) + curr_X2 = Variable(X2_0 + 1e-6, requires_grad=True).to(self.device) + optimizer = optim.SGD([curr_X1, curr_X2], lr=lr) + + for iter in range(max_iter): + clust_rest_minpoolAct_tensor_list = [] + for clust_k_id in self.clust_list: + if clust_k_id == clust_c_id: + continue + clust_k_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_k_id) + clust_rest_minpoolAct_tensor_list.append(clust_k_minpoolAct_tensor) + clust_rest_minpoolAct_tensor = clust_rest_minpoolAct_tensor_list[0] + for clust_k_id in range(1, len(clust_rest_minpoolAct_tensor_list)): + clust_rest_minpoolAct_tensor = torch.maximum(clust_rest_minpoolAct_tensor, clust_rest_minpoolAct_tensor_list[clust_k_id]) + + clust_c_minpoolAct_tensor = self.clust_minpoolAct(curr_X1, curr_X2, clust_c_id) + + clust_loss_tensor = margin + clust_c_minpoolAct_tensor - clust_rest_minpoolAct_tensor + clust_loss_tensor = torch.maximum(clust_loss_tensor, torch.zeros_like(clust_loss_tensor)) + clust_loss = torch.sum(clust_loss_tensor) + + norm_loss = torch.norm(curr_X1 - X1_0, p=1) + torch.norm(curr_X2 - X2_0, p=1) + + loss = clust_loss * lamda + norm_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if (iter+1) % 50 == 0: + print('Iteration {}, Total loss:{:.8f}, clust loss:{:.8f}, L1 penalty:{:.8f}'.format(iter, loss.item(), clust_loss.item(), norm_loss.item())) + + if use_abs: + rel_score1 = torch.mean(torch.abs(curr_X1 - X1_0), dim=0) + rel_score2 = torch.mean(torch.abs(curr_X2 - X2_0), dim=0) + else: + rel_score1 = torch.mean(curr_X1 - X1_0, dim=0) + rel_score2 = torch.mean(curr_X2 - X2_0, dim=0) + return rel_score1.data.cpu().numpy(), rel_score2.data.cpu().numpy()