Switch to side-by-side view

--- a
+++ b/coderpp/train/test_faiss.py
@@ -0,0 +1,136 @@
+import os
+import sys
+import torch
+import numpy as np
+from transformers import AutoTokenizer, AutoModel, AutoConfig
+from tqdm import tqdm
+import faiss
+import random
+import string
+import time
+import pickle
+import gc
+
+batch_size = 64
+device = torch.device("cuda:0")
+
+
def get_bert_embed(phrase_list, m, tok, normalize=True, summary_method="CLS", tqdm_bar=False):
    """Encode a list of phrases into fixed-size embeddings with a BERT-style model.

    Args:
        phrase_list: iterable of strings to embed.
        m: HuggingFace-style model; ``m(input_ids)`` returns a tuple whose
            element 0 is the per-token sequence output and element 1 is the
            pooled [CLS] output.
        tok: tokenizer exposing ``encode_plus``.
        normalize: if True, L2-normalize each embedding row.
        summary_method: "CLS" uses the pooled output, "MEAN" averages the
            token embeddings over the sequence dimension.
        tqdm_bar: if True, show a progress bar.

    Returns:
        np.ndarray of shape (len(phrase_list), hidden_dim).

    Raises:
        ValueError: if ``summary_method`` is neither "CLS" nor "MEAN"
            (the original code silently left ``embed`` unbound in that case).
    """
    if summary_method not in ("CLS", "MEAN"):
        raise ValueError(f"unknown summary_method: {summary_method!r}")
    input_ids = []
    for phrase in phrase_list:
        # padding='max_length' replaces the deprecated pad_to_max_length=True;
        # every phrase is truncated/padded to exactly 32 tokens.
        input_ids.append(tok.encode_plus(
            phrase, max_length=32, add_special_tokens=True,
            truncation=True, padding='max_length')['input_ids'])
    m.eval()

    count = len(input_ids)
    if count == 0:
        # Original crashed (unbound `output`) on empty input; return an
        # empty array instead. Hidden dim is unknown here, so use 0.
        return np.empty((0, 0), dtype=np.float32)
    now_count = 0
    output_list = []
    with torch.no_grad():
        if tqdm_bar:
            pbar = tqdm(total=count)
        while now_count < count:
            input_gpu_0 = torch.LongTensor(input_ids[now_count:min(
                now_count + batch_size, count)]).to(device)
            if summary_method == "CLS":
                embed = m(input_gpu_0)[1]
            elif summary_method == "MEAN":
                embed = torch.mean(m(input_gpu_0)[0], dim=1)
            if normalize:
                # clamp avoids division by zero for (near-)zero rows
                embed_norm = torch.norm(
                    embed, p=2, dim=1, keepdim=True).clamp(min=1e-12)
                embed = embed / embed_norm
            # Flush the accumulated GPU tensor to host memory every 1,000,000
            # items to bound GPU memory. NOTE(review): this only triggers when
            # batch_size divides 1,000,000 exactly (true for 64) -- confirm if
            # batch_size changes.
            if now_count % 1000000 == 0:
                if now_count != 0:
                    output_list.append(output.cpu().numpy())
                    del output
                    torch.cuda.empty_cache()
                output = embed
            else:
                output = torch.cat((output, embed), dim=0)
            if tqdm_bar:
                pbar.update(min(now_count + batch_size, count) - now_count)
            now_count = min(now_count + batch_size, count)
            del input_gpu_0
            torch.cuda.empty_cache()
        if tqdm_bar:
            pbar.close()
    output_list.append(output.cpu().numpy())
    del output
    torch.cuda.empty_cache()
    return np.concatenate(output_list, axis=0)
+
def get_KNN(embeddings, k, use_multi_gpu=True):
    """Find the k nearest neighbors of every embedding by inner product.

    Args:
        embeddings: (n, d) array. Rows are assumed L2-normalized upstream,
            so inner product equals cosine similarity -- TODO confirm.
        k: number of neighbors per query (each query matches itself too).
        use_multi_gpu: if True, shard the flat index across all visible
            GPUs; otherwise build it on GPU 0.

    Returns:
        (similarity, indices): two (n, k) arrays from faiss search.
    """
    # faiss requires contiguous float32; the original cast only at search()
    # but fed the raw array to add(), risking a dtype mismatch.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    d = embeddings.shape[1]
    index = faiss.IndexFlatIP(d)
    if use_multi_gpu:
        gpu_index = faiss.index_cpu_to_all_gpus(index)
    else:
        res = faiss.StandardGpuResources()
        gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    gpu_index.add(embeddings)
    print(gpu_index.ntotal)
    similarity, indices = gpu_index.search(embeddings, k)
    # Release the GPU-resident index before returning the (large) results.
    del gpu_index
    gc.collect()
    return similarity, indices
+
def find_new_index(new_CODER_path, output_path, similarity_path, idx2string_path='data/idx2string.pkl', ori_CODER_path='GanjinZero/UMLSBert_ENG', use_multi_gpu=True):
    """Embed all phrases with a fine-tuned CODER checkpoint and save kNN results.

    Args:
        new_CODER_path: path to a whole-model checkpoint loadable by torch.load.
        output_path: file to np.save the (n, 30) neighbor-index array to.
        similarity_path: file to np.save the (n, 30) similarity array to.
        idx2string_path: pickle of an idx -> phrase dict whose values are embedded.
        ori_CODER_path: HuggingFace name/path used only for the tokenizer.
        use_multi_gpu: forwarded to get_KNN.

    Returns:
        None (results are written to disk).
    """
    print('start finding new index...')
    tokenizer = AutoTokenizer.from_pretrained(ori_CODER_path)
    print('start loading phrases...')
    # Bug fix: the original ignored idx2string_path and hardcoded
    # 'data/idx2string.pkl' here.
    with open(idx2string_path, 'rb') as f:
        phrase_list = list(pickle.load(f).values())
    print('done loading phrases')
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # trusted checkpoints.
    model = torch.load(new_CODER_path).to(device)
    embeddings = get_bert_embed(phrase_list, model, tokenizer, summary_method="MEAN", tqdm_bar=True)
    # Free the model before faiss grabs GPU memory for the index.
    del model
    torch.cuda.empty_cache()
    print('start knn')
    similarity, indices = get_KNN(embeddings, 30, use_multi_gpu)
    with open(output_path, 'wb') as f:
        np.save(f, indices)
    with open(similarity_path, 'wb') as f:
        np.save(f, similarity)
    print('done knn')
    return None
+
+
if __name__ == "__main__":
    # Smoke test: embed the first 100 phrases with a fine-tuned checkpoint
    # and report the resulting embedding shape.
    model_name = "GanjinZero/UMLSBert_ENG"
    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print('start loading phrases...')
    with open('data/idx2string.pkl', 'rb') as f:
        idx2string = pickle.load(f)
    phrase_list = list(idx2string.values())
    print('done loading phrases')
    # Loads a whole-model checkpoint (not AutoModel.from_pretrained).
    model = torch.load('output_testttt/last_model.pth').to(device)
    start = time.time()
    print('start testing...')
    embeddings = get_bert_embed(
        phrase_list[:100], model, tokenizer,
        summary_method="MEAN", tqdm_bar=True)
    print(embeddings.shape)
    # Follow-up steps (saving embeddings, running get_KNN over the full set
    # and persisting similarity/indices, timing) are currently disabled.
\ No newline at end of file