# coderpp/test/generate_faiss_index.py
1
import os
2
import sys
3
import torch
4
import numpy as np
5
from transformers import AutoTokenizer, AutoModel, AutoConfig
6
from tqdm import tqdm
7
import faiss
8
import random
9
import string
10
import time
11
import pickle
12
import gc
13
import argparse
14
15
# Number of phrases per forward pass in get_bert_embed.
batch_size = 64
# Device the encoder runs on; falls back to CPU when CUDA is unavailable so
# the script does not crash on CPU-only machines (unchanged on GPU machines).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def get_bert_embed(phrase_list, m, tok, normalize=True, summary_method="CLS", tqdm_bar=False):
    """Encode phrases with a BERT-style model and return one vector per phrase.

    Phrases are tokenized in batches of the module-level ``batch_size``
    (truncated / padded to exactly 32 tokens), run through ``m`` on the
    module-level ``device`` under ``torch.no_grad()``, and pooled into a
    single vector each.

    Args:
        phrase_list: sequence of phrase strings to encode.
        m: encoder model. ``m(input_ids)[1]`` is used for "CLS" pooling and
            the mean over ``m(input_ids)[0]`` along dim 1 for "MEAN" pooling.
            Assumed to already live on ``device``.
        tok: tokenizer callable on a list of strings (Hugging Face style).
        normalize: if True, L2-normalize every vector (norm clamped at 1e-12).
        summary_method: "CLS" or "MEAN".
        tqdm_bar: if True, show a progress bar over phrases.

    Returns:
        numpy array with one row per phrase, in ``phrase_list`` order. For an
        empty ``phrase_list`` an empty 2-D array is returned.

    Raises:
        ValueError: if ``summary_method`` is not "CLS" or "MEAN".
    """
    # Validate up front: an unknown method previously left ``embed`` unbound.
    if summary_method not in ("CLS", "MEAN"):
        raise ValueError(
            f"summary_method must be 'CLS' or 'MEAN', got {summary_method!r}")

    m.eval()
    count = len(phrase_list)
    chunks = []
    with torch.no_grad(), tqdm(total=count, disable=not tqdm_bar) as pbar:
        for start in range(0, count, batch_size):
            batch = list(phrase_list[start:start + batch_size])
            # Tokenize per batch so token ids for all phrases are never held
            # in memory at once.
            # NOTE(review): no attention_mask is passed, so pad positions take
            # part in attention and "MEAN" averages over all 32 slots. Kept as
            # is, presumably so vectors match other code that embeds the same
            # way. Confirm before adding a mask.
            input_ids = tok(
                batch, max_length=32, add_special_tokens=True,
                truncation=True, padding="max_length",
                return_tensors="pt")["input_ids"].to(device)
            outputs = m(input_ids)
            if summary_method == "CLS":
                embed = outputs[1]
            else:
                embed = torch.mean(outputs[0], dim=1)
            if normalize:
                embed_norm = torch.norm(
                    embed, p=2, dim=1, keepdim=True).clamp(min=1e-12)
                embed = embed / embed_norm
            # Move each batch off the device right away; concatenating once
            # on the host avoids repeatedly re-copying a growing GPU tensor.
            chunks.append(embed.cpu().numpy())
            pbar.update(len(batch))
    torch.cuda.empty_cache()

    if not chunks:
        hidden = getattr(getattr(m, "config", None), "hidden_size", 0)
        return np.zeros((0, hidden), dtype=np.float32)
    return np.concatenate(chunks, axis=0)


def get_KNN(embeddings, k):
    """Exact inner-product k-nearest-neighbour search of a matrix against itself.

    Every row of ``embeddings`` is added to a flat (brute-force)
    ``faiss.IndexFlatIP`` and every row is then used as a query, so each
    row's own vector is part of the search space. The index is moved to
    GPU 0 when GPU faiss and a GPU are available; otherwise it runs on CPU.

    Args:
        embeddings: 2-D array with one vector per row.
        k: number of neighbours to return per row.

    Returns:
        ``(similarity, indices)`` exactly as returned by faiss ``search``.
    """
    # faiss expects C-contiguous float32. Convert once so add() and search()
    # operate on the same data.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    d = vectors.shape[1]
    index = faiss.IndexFlatIP(d)
    if hasattr(faiss, "StandardGpuResources") and faiss.get_num_gpus() > 0:
        # ``res`` must outlive the GPU index; both are dropped when this
        # function returns.
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)
    index.add(vectors)
    print(index.ntotal)
    similarity, indices = index.search(vectors, k)
    del index
    gc.collect()
    return similarity, indices


def find_new_index(indices_path, similarity_path, embedding_path, phrase2idx_path,
                   tokenizer_name, model_name_or_path, k=30, summary_method="MEAN"):
    """Embed all phrases from a phrase2idx pickle and save their self-kNN graph.

    Outputs, each written with ``np.save`` to exactly the path given:
      * ``embedding_path``  - phrase embeddings from ``get_bert_embed``;
      * ``indices_path``    - ``k`` nearest-neighbour row ids per phrase;
      * ``similarity_path`` - the matching inner-product scores.

    Args:
        indices_path, similarity_path, embedding_path: output file paths.
            Missing parent directories are created.
        phrase2idx_path: pickle holding a dict whose keys are the phrases to
            embed. Output row i corresponds to the i-th key in iteration order.
        tokenizer_name: passed to ``AutoTokenizer.from_pretrained``.
        model_name_or_path: a path ending in ``.pth`` is loaded with
            ``torch.load`` as a whole pickled model; anything else is passed
            to ``AutoModel.from_pretrained``.
        k: neighbours per phrase (default 30).
        summary_method: pooling used by ``get_bert_embed`` ("MEAN" or "CLS").
    """
    print('start finding new index...')
    # Create output directories up front so np.save cannot fail after the
    # expensive embedding step.
    for out_path in (embedding_path, indices_path, similarity_path):
        os.makedirs(os.path.dirname(os.path.abspath(out_path)), exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if model_name_or_path.endswith('.pth'):
        # A whole pickled nn.Module. Unpickling can run arbitrary code, so only
        # load trusted checkpoints. Newer torch defaults to weights_only=True,
        # which cannot load full modules; very old torch lacks the argument.
        try:
            model = torch.load(model_name_or_path, map_location=device,
                               weights_only=False)
        except TypeError:
            model = torch.load(model_name_or_path, map_location=device)
        model = model.to(device)
    else:
        model = AutoModel.from_pretrained(model_name_or_path).to(device)

    print('start loading phrases...')
    # NOTE: pickle.load can execute arbitrary code; the file must be trusted.
    with open(phrase2idx_path, 'rb') as f:
        phrase2idx = pickle.load(f)
    # NOTE(review): rows follow dict key order, not the stored idx values.
    # They agree only if phrase2idx was built with idx == insertion order.
    # Confirm against the script that writes phrase2idx.pkl.
    phrase_list = list(phrase2idx.keys())
    embeddings = get_bert_embed(phrase_list, model, tokenizer,
                                summary_method=summary_method, tqdm_bar=True)
    # Drop the encoder and its cached GPU memory before the faiss search.
    del model
    torch.cuda.empty_cache()
    with open(embedding_path, 'wb') as f:
        np.save(f, embeddings)

    print('start knn')
    similarity, indices = get_KNN(embeddings, k)
    with open(indices_path, 'wb') as f:
        np.save(f, indices)
    with open(similarity_path, 'wb') as f:
        np.save(f, similarity)
    print('done knn')


if __name__ == "__main__":
    # Command-line entry point: embed every phrase of --phrase2idx_path and
    # write indices.npy, similarity.npy and embedding.npy into --save_dir.
    cli = argparse.ArgumentParser()
    option_table = (
        ("--tokenizer_name", "GanjinZero/UMLSBert_ENG", "Path to tokenizer"),
        ("--model_name_or_path", "GanjinZero/UMLSBert_ENG", "path to model"),
        ("--save_dir", "../use_data/", "output dir"),
        ("--phrase2idx_path", "../use_data/phrase2idx.pkl", "Path to phrase2idx file"),
    )
    for flag, default_value, help_text in option_table:
        cli.add_argument(flag, default=default_value, type=str, help=help_text)
    opts = cli.parse_args()

    out_dir = opts.save_dir
    find_new_index(
        tokenizer_name=opts.tokenizer_name,
        model_name_or_path=opts.model_name_or_path,
        indices_path=os.path.join(out_dir, 'indices.npy'),
        similarity_path=os.path.join(out_dir, 'similarity.npy'),
        embedding_path=os.path.join(out_dir, 'embedding.npy'),
        phrase2idx_path=opts.phrase2idx_path,
    )