--- a +++ b/projection.py @@ -0,0 +1,143 @@ + +import torch +import numpy as np +from sklearn.cluster import KMeans +from scipy.spatial import distance +import consts + +def construct_train_matrix(SNN_model, trainLoader_simple): + """ + Embed the training data using the trained SNN model. + """ + + # Get intermidate layer output. + # Source: https://discuss.pytorch.org/t/how-can-i-extract-intermediate-layer-output-from-loaded-cnn-model/77301/2 + activation = {} + + def get_activation(name): + def hook(model, input, output): + activation[name] = output.detach() + + return hook + + # SNN_model.bert_arch.fc1.register_forward_hook(get_activation('fc1')) + SNN_model.bert_arch.fc.register_forward_hook(get_activation('fc')) + + train_matrix = [] + print("num batches:", len(trainLoader_simple)) + with torch.no_grad(): + for i, batch in enumerate(trainLoader_simple): + print(i, "/", len(trainLoader_simple), " batches") + seq1, mask1, label1 = batch + if consts.device == 'cuda': + seq1, mask1, label1 = seq1.to(consts.device), mask1.to(consts.device), label1.to(consts.device) + + SNN_model.eval() + output = SNN_model(seq1, seq1, mask1, mask1) + train_matrix.append(activation[ + 'fc'].cpu().numpy()) # activation['fc1'] return a tensor in cuda with size (batch_size, embedding_dim), so we move it to cpu, than to numpy array. + + return np.vstack( + train_matrix) # we combine all the batches, so now we return matrix of size (num_samples_train, embedding_dim) + + +def extract_prototypes(k, trainLoader_simple, train_labels, train_matrix): + """ + Compute kc (= k/num_classes_train) prototypes for each class in the trainset. + (if k % num_classes_train != 0 then take the highest k0 <= k which is divisable by num_classes_train) + + :param k: + :param train_matrix: size (num_samples_train, embedding_dim) + :return: + """ + train_labels = list(train_labels) + train_dataset = trainLoader_simple.dataset # contains triples of (seq, mask, label) + + # construct a hash table, each key is a class of diagnosis + # and the value is a list of the indexs of the sentences which belong to this class + hash_table = {} # format: {"diagnosis" : [i1, i2, ...]} + for i in range(len(train_labels)): + lbl = train_labels[i] + if lbl in hash_table: + hash_table[lbl].append(i) + else: + hash_table[lbl] = [i] + + # Create prototypess + prototypes_list = {diagnosis: [] for diagnosis in hash_table.keys()} + num_classes_train = len(hash_table) + assert k >= num_classes_train, "k should be greater than the numbrer of uniqe labels in the train set'" + kc = int(k / num_classes_train) + print("kc:", kc) + + for diagnosis in hash_table.keys(): + + print("diagnosis:", diagnosis) + + if len(hash_table[ + diagnosis]) <= 1: # if there is only a single sentence in some diagnosis sentences list - take it as the prototype of this class + prototypes_list += list(train_matrix[hash_table[diagnosis]]) + + else: + # fit on all sentences which belongs to the same class (diagnosis) + kmeans = KMeans(n_clusters=kc, init='k-means++').fit(train_matrix[hash_table[diagnosis]]) + # extract for each centroid the closest real sample, and add it as a prototype + for centroid in kmeans.cluster_centers_: + # print(train_matrix[hash_table[diagnosis]].shape) + best_match_index = None + best_match_dist = float('inf') + for sentence_index in hash_table[diagnosis]: + # print(sentence_index) + embedded_sent = train_matrix[sentence_index] + dist = distance.euclidean(centroid, embedded_sent) + # print("dist:", best_match_dist) + if dist < best_match_dist: + best_match_dist = dist + best_match_index = sentence_index + # print(best_match_index) + prototypes_list[diagnosis].append(train_dataset[best_match_index]) + + return prototypes_list + +def project_to_dissimilarity_space(dataLoader, SNN_model, prototypes_list): + """ + + Parameters + ---------- + dataLoader : + SNN_model : + + + Returns + ------- + projected_data : numpy array of shape (num_samples_data, projection_dim) + """ + + projected_data = [] + with torch.no_grad(): + + for batch in dataLoader: + print("****new batch***") + + projected_sentence = [] + seq1, mask1, label1 = batch + if consts.device == 'cuda': + seq1, mask1, label1 = seq1.to(consts.device), mask1.to(consts.device), label1.to(consts.device) + + for diagnosis in prototypes_list: + # print("diagnosis:", diagnosis) + for centroid in prototypes_list[diagnosis]: # centroid contains a triple of (seq, mask, label) + seq2, mask2 = centroid[0].repeat(seq1.shape[0], 1), centroid[1].repeat(mask1.shape[0], + 1) # we stack to seq2 and mask2 replications of them such that they will fit the batch size of seq1 and mask2 + if consts.device == 'cuda': + seq2, mask2 = seq2.to(consts.device), mask2.to(consts.device) + + SNN_model.eval() + distance = SNN_model(seq1, seq2, mask1, mask2) + projected_sentence.append(distance.squeeze().cpu().numpy()) + + projected_data.append(np.array(projected_sentence).T) + # break # TODO delete + + projected_data = np.vstack(projected_data) + return projected_data \ No newline at end of file