coderpp/train/generate_faiss_index.py
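
"""Generate a FAISS nearest-neighbour index over CODER phrase embeddings.

Sketch of the pipeline implemented below: encode every phrase with a
CODER/BERT model, L2-normalize the embeddings, then run an exact
inner-product (cosine) k-NN search on GPU with FAISS.
"""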
import os
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import faiss
import pickle
import gc
import argparse

batch_size = 64  # phrases encoded per forward pass
device = torch.device("cuda:0")  # assumes a single CUDA device is available

def get_bert_embed(phrase_list, m, tok, normalize=True, summary_method="CLS", tqdm_bar=False):
    # Tokenize every phrase to a fixed length of 32 tokens.
    # (pad_to_max_length=True is deprecated; padding="max_length" is the
    # supported equivalent.)
    input_ids = []
    for phrase in phrase_list:
        input_ids.append(tok.encode_plus(
            phrase, max_length=32, add_special_tokens=True,
            truncation=True, padding="max_length")['input_ids'])
    m.eval()

    count = len(input_ids)
    now_count = 0
    output_list = []
    with torch.no_grad():
        if tqdm_bar:
            pbar = tqdm(total=count)
        while now_count < count:
            input_gpu_0 = torch.LongTensor(input_ids[now_count:min(
                now_count + batch_size, count)]).to(device)
            if summary_method == "CLS":
                # Pooled [CLS] representation.
                embed = m(input_gpu_0)[1]
            elif summary_method == "MEAN":
                # Mean over the last hidden states; note that padding tokens
                # are included in the mean since no attention mask is passed.
                embed = torch.mean(m(input_gpu_0)[0], dim=1)
            if normalize:
                # L2-normalize so that inner product equals cosine similarity.
                embed_norm = torch.norm(
                    embed, p=2, dim=1, keepdim=True).clamp(min=1e-12)
                embed = embed / embed_norm
            # Every 1,000,000 phrases, flush the accumulated GPU tensor to CPU
            # to bound GPU memory (batch_size divides 1,000,000 evenly, so the
            # condition is hit exactly at those boundaries).
            if now_count % 1000000 == 0:
                if now_count != 0:
                    output_list.append(output.cpu().numpy())
                    del output
                    torch.cuda.empty_cache()
                output = embed
            else:
                output = torch.cat((output, embed), dim=0)
            if tqdm_bar:
                pbar.update(min(now_count + batch_size, count) - now_count)
            now_count = min(now_count + batch_size, count)
            del input_gpu_0
            torch.cuda.empty_cache()
        if tqdm_bar:
            pbar.close()
    output_list.append(output.cpu().numpy())
    del output
    torch.cuda.empty_cache()
    return np.concatenate(output_list, axis=0)
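
# Usage sketch for get_bert_embed (hypothetical call; assumes the
# UMLSBert_ENG checkpoint, a BERT-base encoder with hidden size 768):
#   tok = AutoTokenizer.from_pretrained("GanjinZero/UMLSBert_ENG")
#   m = AutoModel.from_pretrained("GanjinZero/UMLSBert_ENG").to(device)
#   vecs = get_bert_embed(["heart attack", "myocardial infarction"], m, tok,
#                         summary_method="MEAN")
#   # vecs: np.ndarray of shape (2, 768), rows L2-normalized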

def get_KNN(embeddings, k):
    # Exact k-nearest-neighbour search with FAISS on GPU.
    d = embeddings.shape[1]
    res = faiss.StandardGpuResources()
    # Inner-product index; on L2-normalized vectors this is cosine similarity.
    index = faiss.IndexFlatIP(d)
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    gpu_index.add(embeddings.astype(np.float32))  # FAISS requires float32
    print(gpu_index.ntotal)
    similarity, indices = gpu_index.search(embeddings.astype(np.float32), k)
    del gpu_index
    gc.collect()
    return similarity, indices
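
# Quick self-check sketch (hypothetical data): with exact search over
# normalized rows, each row's top hit should be itself with similarity ~1.0.
#   x = np.random.rand(100, 768).astype(np.float32)
#   x /= np.linalg.norm(x, axis=1, keepdims=True)
#   sim, idx = get_KNN(x, 5)
#   assert (idx[:, 0] == np.arange(100)).all()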

def find_new_index(indices_path, similarity_path, embedding_path, phrase2idx_path, ori_CODER_path='GanjinZero/UMLSBert_ENG'):
    print('start finding new index...')
    tokenizer = AutoTokenizer.from_pretrained(ori_CODER_path)
    model = AutoModel.from_pretrained(ori_CODER_path).to(device)
    print('start loading phrases...')
    with open(phrase2idx_path, 'rb') as f:
        phrase2idx = pickle.load(f)
    phrase_list = list(phrase2idx.keys())
    embeddings = get_bert_embed(phrase_list, model, tokenizer, summary_method="MEAN", tqdm_bar=True)
    # Free the encoder before the GPU k-NN search.
    del model
    torch.cuda.empty_cache()
    with open(embedding_path, 'wb') as f:
        np.save(f, embeddings)
    print('start knn')
    # Retrieve the 30 nearest neighbours of every phrase.
    similarity, indices = get_KNN(embeddings, 30)
    with open(indices_path, 'wb') as f:
        np.save(f, indices)
    with open(similarity_path, 'wb') as f:
        np.save(f, similarity)
    print('done knn')
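
# Side effects: find_new_index writes the phrase embeddings plus the
# [num_phrases, 30] neighbour indices and similarities to the given .npy
# paths; it returns nothing.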


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--CODER_name",
        default="GanjinZero/UMLSBert_ENG",
        type=str,
        help="Path to CODER"
    )
    parser.add_argument(
        "--save_dir",
        default="../use_data/",
        type=str,
        help="output dir"
    )
    parser.add_argument(
        "--phrase2idx_path",
        default="../use_data/phrase2idx.pkl",
        type=str,
        help="Path to phrase2idx file"
    )
    args = parser.parse_args()
    args.indices_path = os.path.join(args.save_dir, 'indices.npy')
    args.similarity_path = os.path.join(args.save_dir, 'similarity.npy')
    args.embedding_path = os.path.join(args.save_dir, 'embedding.npy')

    find_new_index(
        ori_CODER_path=args.CODER_name,
        indices_path=args.indices_path,
        similarity_path=args.similarity_path,
        embedding_path=args.embedding_path,
        phrase2idx_path=args.phrase2idx_path
    )
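
Example invocation (a sketch; the values shown are the argparse defaults
above, and phrase2idx.pkl is assumed to be a pickled dict mapping each
phrase string to its index):

    python generate_faiss_index.py \
        --CODER_name GanjinZero/UMLSBert_ENG \
        --save_dir ../use_data/ \
        --phrase2idx_path ../use_data/phrase2idx.pkl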