# Source file: test/embeddings_reimplement/mcsm.py
1
from transformers import AutoConfig, AutoModel, AutoTokenizer
2
import torch
3
from tqdm import tqdm
4
import numpy as np
5
import sys
6
sys.path.append("../../pretrain")
7
from load_umls import UMLS
8
from nltk.tokenize import word_tokenize
9
import ipdb
10
import os
11
12
13
# Prefer the second GPU when the machine has one; otherwise fall back to the
# first GPU, and to the CPU when CUDA is unavailable.  Unconditionally using
# "cuda:1" fails with an invalid-device error on single-GPU hosts.
if torch.cuda.is_available():
    device = torch.device(
        "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Rank discounts: log_list[i - 1] == 1 / log2(i + 1) for ranks i = 1..999.
log_list = 1 / np.log2(list(range(2, 1001, 1)))

# Number of terms encoded per forward pass in mcsm_bert.
batch_size = 512
# Token length every term is truncated/padded to in mcsm_bert.
max_seq_length = 32

# umls = UMLS("../../umls", source_range='SNOMEDCT_US')
# Semantic types evaluated by default.
t_list = ['Pharmacologic Substance', 'Disease or Syndrome',
          'Neoplastic Process', 'Clinical Drug', 'Finding', 'Injury or Poisoning']
22
23
24
def mcsm(embedding_list, embedding_type_list, type_list=t_list, k=40, lang_range=['ENG'], check_intersection=False):
    """Compute MCSM scores for several embeddings over UMLS terms.

    Args:
        embedding_list: paths / model names, one per embedding to evaluate.
        embedding_type_list: kind of each embedding, matched by position and
            compared case-insensitively: "cui", "word" or "bert".
        type_list: semantic types a CUI must have to be evaluated.
        k: number of nearest neighbours scored per term.
        lang_range: forwarded to ``UMLS`` as ``lang_range`` (a copy is passed
            so the shared default list can never be modified through it).
        check_intersection: when true, restrict the evaluated CUIs to the
            list stored in ``intersection.txt`` in the working directory; the
            file is created from ``get_intersection`` when it does not exist.

    Returns:
        A list of result dicts as returned by the ``mcsm_*`` helpers: one per
        "cui"/"word" embedding and two per "bert" embedding (MEAN pooling
        first, then CLS), in input order.

    Raises:
        ValueError: if an embedding has no type or its type is not one of
            "cui", "word", "bert".  The original silently skipped unknown
            types, leaving the result list misaligned with the inputs.
    """
    # Fail fast, before any embedding or UMLS data is loaded.
    if len(embedding_type_list) < len(embedding_list):
        raise ValueError(
            f"{len(embedding_list)} embeddings given but only "
            f"{len(embedding_type_list)} embedding types")
    kinds = [kind.lower() for kind in embedding_type_list[:len(embedding_list)]]
    unknown = sorted(set(kinds) - {"cui", "word", "bert"})
    if unknown:
        raise ValueError(f"Unknown embedding type(s): {unknown}")

    intersection_cui = None
    if check_intersection:
        # NOTE: the cache file is reused whatever embedding_list contains;
        # delete intersection.txt when switching to other embeddings.
        if os.path.exists("intersection.txt"):
            with open("intersection.txt", "r", encoding="utf-8") as f:
                intersection_cui = [line.strip() for line in f]
        else:
            intersection_cui = get_intersection(
                embedding_list, embedding_type_list)
            with open("intersection.txt", "w", encoding="utf-8") as f:
                f.writelines(cui.strip() + "\n" for cui in intersection_cui)

    umls = UMLS("../../umls", source_range='SNOMEDCT_US',
                lang_range=list(lang_range))

    wanted_types = set(type_list)
    if check_intersection:
        cui_list = [cui for cui in intersection_cui
                    if cui in umls.cui2sty and umls.cui2sty[cui] in wanted_types]
    else:
        cui_list = [cui for cui, sty in umls.cui2sty.items()
                    if sty in wanted_types]

    opt = []
    for embedding, kind in zip(embedding_list, kinds):
        if kind == "cui":
            opt.append(mcsm_cui(embedding, umls, cui_list, type_list, k))
        elif kind == "word":
            opt.append(mcsm_word(embedding, umls, cui_list, type_list, k))
        else:  # "bert": report both pooling strategies
            opt.append(mcsm_bert(embedding, umls, cui_list,
                                 type_list, k, summary_method="MEAN"))
            opt.append(mcsm_bert(embedding, umls, cui_list,
                                 type_list, k, summary_method="CLS"))
    return opt
57
58
59
def mcsm_cui(cui_embedding, umls, cui_list, type_list, k=40):
    """Score a CUI-keyed embedding file.

    Args:
        cui_embedding: path handed to ``load_embedding``.
        umls: object whose ``cui2sty`` maps a CUI to its semantic type.
        cui_list: CUIs to evaluate, or ``None`` to evaluate every CUI of the
            embedding that has an entry in ``umls.cui2sty``.
        type_list: semantic types to report on.
        k: number of nearest neighbours scored per term.

    Returns:
        The dict produced by ``calculate_mcsm``.
    """
    w, _ = load_embedding(cui_embedding)
    embedded = set(w.keys())
    if cui_list is None:
        # Skip CUIs without a semantic type; looking them up below would
        # raise KeyError.
        cui_list = [cui for cui in w.keys() if cui in umls.cui2sty]
        print(f"Check cui count:{len(cui_list)}")
    else:
        print(f"All cui count:{len(cui_list)}")
        # Keep the caller's order (and drop duplicates) so that runs are
        # reproducible; a plain set intersection has arbitrary order.
        cui_list = list(dict.fromkeys(
            cui for cui in cui_list if cui in embedded))
        print(f"Check cui count:{len(cui_list)}")

    term_embedding = np.array([w[cui] for cui in cui_list])
    term_type = [umls.cui2sty[cui] for cui in cui_list]

    return calculate_mcsm(term_embedding, term_type, type_list, k=k)
73
74
75
def mcsm_word(word_embedding, umls, cui_list, type_list, k=40):
    """Score a word embedding by summing word vectors per term.

    Each CUI is represented by the first string of ``umls.cui2str[cui]``,
    tokenised with ``word_tokenize``; the vectors of the tokens present in
    the embedding are summed.  CUIs with no in-vocabulary token are skipped.

    Args:
        word_embedding: path handed to ``load_embedding``.
        umls: object providing ``cui2str`` and ``cui2sty`` lookups.
        cui_list: CUIs to evaluate.
        type_list: semantic types to report on.
        k: number of nearest neighbours scored per term.

    Returns:
        The dict produced by ``calculate_mcsm``.
    """
    w, dim = load_embedding(word_embedding)

    print(f"All cui count:{len(cui_list)}")

    term_type = []
    term_vectors = []
    for cui in tqdm(cui_list):
        term = list(umls.cui2str[cui])[0]
        known_words = [word for word in word_tokenize(term) if word in w]
        if not known_words:
            continue

        vector = np.zeros((dim))
        for word in known_words:
            vector += w[word]
        term_vectors.append(vector)
        term_type.append(umls.cui2sty[cui])

    # Build the matrix once instead of re-concatenating per term (quadratic),
    # and get a well-formed (0, dim) matrix when nothing matched instead of
    # failing on an unassigned variable.
    term_embedding = np.array(term_vectors, dtype=float).reshape((-1, dim))

    print(f"Check cui count:{len(term_vectors)}")

    return calculate_mcsm(term_embedding, term_type, type_list, k=k)
107
108
109
def mcsm_bert(bert_embedding, umls, cui_list, type_list, k=40, summary_method="MEAN"):
    """Score a transformer encoder loaded through ``load_bert``.

    Each CUI is represented by the first string of ``umls.cui2str[cui]``,
    encoded to ``max_seq_length`` token ids and fed through the model in
    batches of ``batch_size``.

    Args:
        bert_embedding: model name or path handed to ``load_bert``.
        umls: object providing ``cui2str`` and ``cui2sty`` lookups.
        cui_list: CUIs to evaluate; must not be empty.
        type_list: semantic types to report on.
        k: number of nearest neighbours scored per term.
        summary_method: "MEAN" averages the first model output over the
            sequence axis; "CLS" uses the second model output.

    Returns:
        The dict produced by ``calculate_mcsm``.

    Raises:
        ValueError: for an unknown ``summary_method`` or an empty
            ``cui_list`` (both previously surfaced as an unrelated
            unassigned-variable error, or reused a stale batch).
    """
    if summary_method not in ("MEAN", "CLS"):
        raise ValueError(
            f"summary_method must be 'MEAN' or 'CLS', got {summary_method!r}")
    print(f"Check cui count:{len(cui_list)}")
    if len(cui_list) == 0:
        raise ValueError("cui_list is empty; nothing to embed")

    model, tokenizer = load_bert(bert_embedding)
    model.eval()

    # padding="max_length" is the current spelling of the deprecated
    # pad_to_max_length=True.
    input_ids = [
        tokenizer.encode_plus(
            list(umls.cui2str[cui])[0],
            max_length=max_seq_length,
            add_special_tokens=True,
            truncation=True,
            padding="max_length")['input_ids']
        for cui in tqdm(cui_list)
    ]

    count = len(input_ids)
    chunks = []
    with tqdm(total=count) as pbar, torch.no_grad():
        for start in range(0, count, batch_size):
            batch = torch.LongTensor(
                input_ids[start:start + batch_size]).to(device)
            # NOTE(review): no attention mask is passed, so padding positions
            # take part in the forward pass and in the MEAN average.  Left
            # unchanged to keep scores comparable with earlier runs.
            outputs = model(batch)
            if summary_method == "CLS":
                embed = outputs[1]
            else:
                embed = torch.mean(outputs[0], dim=1)
            chunks.append(embed.cpu().detach().numpy())
            pbar.update(batch.shape[0])

    # Concatenate once rather than growing the array batch by batch.
    term_embedding = np.concatenate(chunks, axis=0)

    term_type = [umls.cui2sty[cui] for cui in cui_list]
    return calculate_mcsm(term_embedding, term_type, type_list, k=k)
143
144
145
def summary(opt):
    """Reduce each list of scores to a ``(mean, standard deviation)`` pair.

    Args:
        opt: mapping from a key to a sequence of numeric scores.

    Returns:
        A new dict with the same keys.  A key with no scores maps to
        ``(nan, nan)``, the value NumPy would produce, but without the
        "mean of empty slice" runtime warnings.
    """
    nan = float("nan")
    return {
        key: (np.mean(scores), np.std(scores)) if len(scores) else (nan, nan)
        for key, scores in opt.items()
    }
148
149
150
def calculate_mcsm(term_embedding, term_type, target_type_list, k):
    """Score how well nearest neighbours share each term's semantic type.

    Rows are L2-normalised, so the dot product used for ranking is cosine
    similarity.  For every term whose type is in ``target_type_list`` the
    ``k`` most similar *other* terms are ranked, and ``1 / log2(rank + 1)``
    is added for each neighbour (rank starting at 1) with the same type.

    Args:
        term_embedding: array-like of shape term_count * embedding_dim.
        term_type: sequence of length term_count with each term's type.
        target_type_list: types to report on.
        k: number of neighbours per term; capped at ``term_count - 1`` so
            small inputs no longer make ``topk`` fail.

    Returns:
        ``summary`` of the per-type score lists, i.e. a dict mapping each
        target type to ``(mean, std)``.  The dict is also printed.
    """
    output = {target_type: [] for target_type in target_type_list}
    term_count = len(term_type)
    neighbour_count = max(0, min(k, term_count - 1))

    if neighbour_count > 0:
        term_embedding = torch.FloatTensor(term_embedding).to(device)
        embedding_norm = torch.norm(
            term_embedding, p=2, dim=1, keepdim=True).clamp(min=1e-12)
        term_embedding = torch.div(term_embedding, embedding_norm)
        del embedding_norm
        # Computed locally so k is not limited by the length of a
        # precomputed table.
        discounts = 1.0 / np.log2(np.arange(2, neighbour_count + 2))

    for index, t in tqdm(enumerate(term_type), total=term_count):
        if t not in output:
            continue
        score = 0.0
        if neighbour_count > 0:
            similarity = torch.matmul(term_embedding, term_embedding[index])
            # Exclude the query explicitly: with duplicate or all-zero rows
            # it is not guaranteed to be the top hit.
            similarity[index] = float("-inf")
            _, indices = torch.topk(similarity, k=neighbour_count)
            # One transfer to the host instead of one per neighbour.
            for rank, neighbour in enumerate(indices.tolist()):
                if term_type[neighbour] == t:
                    score += discounts[rank]
        output[t].append(score)

    output = summary(output)
    print(output)
    return output
175
176
177
def load_embedding(filename):
    """Load an embedding table and report its dimensionality.

    The format is chosen from the file name's extension(s):

    * ``.bin``: word2vec binary, loaded with gensim ``KeyedVectors``.
    * ``.pkl``: pickled dict whose values are strings such as
      ``"[0.1,0.2]"``; each is parsed into a float array.
    * anything else: text with one header line (skipped) followed by
      ``key v1 v2 ...`` lines; blank lines are ignored.

    Only the base name is inspected, so a directory called e.g. ``bin`` no
    longer forces the binary loader.

    Args:
        filename: path to the embedding file.

    Returns:
        ``(W, dim)`` where ``W`` maps a key to its vector and ``dim`` is the
        vector length.

    Raises:
        ValueError: if a text or pickle file contains no vectors.
    """
    print(filename)
    suffixes = os.path.basename(filename).lower().split(".")[1:]

    if "bin" in suffixes:
        from gensim import models
        W = models.KeyedVectors.load_word2vec_format(filename, binary=True)
        return W, W.vector_size

    if "pkl" in suffixes:
        import pickle
        # NOTE(security): unpickling runs arbitrary code; only load files
        # from a trusted source.
        with open(filename, 'rb') as f:
            raw = pickle.load(f)
        # Values look like "[v1,v2,...]": drop the brackets, split on commas.
        W = {key: np.array([float(x) for x in value[1:-1].split(",")])
             for key, value in raw.items()}
    else:
        W = {}
        with open(filename, 'r', encoding="utf-8") as f:
            next(f, None)  # first line is a header, not a vector
            for line in f:
                toks = line.split()
                if not toks:
                    continue
                W[toks[0]] = np.array([float(tok) for tok in toks[1:]])

    if not W:
        raise ValueError(f"No embeddings found in {filename}")
    dim = len(next(iter(W.values())))
    return W, dim
205
206
207
def load_bert(model_name_or_path):
    """Load a model and its tokenizer and move the model to ``device``.

    The model is loaded with ``AutoModel.from_pretrained``; if that fails,
    ``<model_name_or_path>/pytorch_model.bin`` is loaded with ``torch.load``
    instead.  The tokenizer is loaded from ``model_name_or_path`` and, if
    that fails, from its parent directory.

    Only ``Exception`` is caught for the fallbacks, so Ctrl-C and
    ``SystemExit`` are no longer swallowed.

    Args:
        model_name_or_path: model identifier or local directory.

    Returns:
        ``(model, tokenizer)``.
    """
    print(model_name_or_path)
    try:
        config = AutoConfig.from_pretrained(model_name_or_path)
        model = AutoModel.from_pretrained(
            model_name_or_path, config=config).to(device)
    except Exception:
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from a trusted source.
        # map_location lets a checkpoint saved on another device (e.g. a GPU
        # absent on this machine) be loaded here.
        model = torch.load(
            os.path.join(model_name_or_path, 'pytorch_model.bin'),
            map_location=device).to(device)

    try:
        model.output_hidden_states = False
    except (AttributeError, TypeError):
        # Best effort only: some objects may refuse the attribute.
        pass

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(model_name_or_path, "../"))
    return model, tokenizer
228
229
230
def get_intersection(embedding_list, embedding_type_list):
    """Return the keys shared by every "cui" embedding in the list.

    Args:
        embedding_list: embedding paths, matched by position with
            ``embedding_type_list``.
        embedding_type_list: kind of each embedding; only entries equal to
            "cui" (case-insensitive, consistent with ``mcsm``) are loaded.

    Returns:
        A sorted list of the common keys (sorted so the result is
        reproducible); empty when no "cui" embedding is given.
    """
    intersection_cui = None
    for embed, embed_type in zip(embedding_list, embedding_type_list):
        if embed_type.lower() != "cui":
            continue
        w, _ = load_embedding(embed)
        keys = set(w.keys())
        if intersection_cui is None:
            intersection_cui = keys
        else:
            intersection_cui = keys.intersection(intersection_cui)

    if intersection_cui is None:
        intersection_cui = set()
    print(f"Intersection count: {len(intersection_cui)}")
    return sorted(intersection_cui)
244
245
246
if __name__ == "__main__":
    # Earlier experiment configurations, kept for reference (not executed).
    #
    # embedding_list = ["../../embeddings/claims_codes_hs_300.txt",
    #                   "../../embeddings/GoogleNews-vectors-negative300.bin",
    #                   "../../models/2020_eng"]
    # #embedding_type_list = ["cui", "word", "bert"]
    # embedding_list = ["../../embeddings/wikipedia-pubmed-and-PMC-w2v.bin",
    #                   "../../embeddings/bio_nlp_vec/PubMed-shuffle-win-2.bin",
    #                   "../../embeddings/bio_nlp_vec/PubMed-shuffle-win-30.bin"]
    # embedding_type_list = ["word", "word", "word"]
    # embedding_list = ["../../embeddings/DeVine_etal_200.txt",
    #                   "/home/yz/pretraining_models/cui2vec.pkl"]
    # embedding_type_list = ["cui", "cui"]
    #
    # mcsm([embedding_list[2], embedding_type_list[2]])
    #
    # embedding_list = ["../../embeddings/claims_codes_hs_300.txt",
    #                   "../../embeddings/DeVine_etal_200.txt",
    #                   "/home/yz/pretraining_models/cui2vec.pkl"]
    # embedding_type_list = ["cui", "cui", "cui"]
    # mcsm(embedding_list, embedding_type_list, check_intersection=True)
    #
    # embedding_list = ["../../models/2020_eng", "../../models/2020_all"]
    # mcsm(embedding_list, ["bert"] * 2, check_intersection=True)
    #
    # embedding_list = ["../../embeddings/wikipedia-pubmed-and-PMC-w2v.bin",
    #                   "../../embeddings/GoogleNews-vectors-negative300.bin",
    #                   "../../embeddings/bio_nlp_vec/PubMed-shuffle-win-2.bin",
    #                   "../../embeddings/bio_nlp_vec/PubMed-shuffle-win-30.bin"]
    # mcsm(embedding_list, ["word"] * 4, check_intersection=True)

    embedding_list = ["/home/yz/pretraining_models/bert-base-cased",
                      "/home/yz/pretraining_models/biobert_v1.1",
                      "/home/yz/pretraining_models/BiomedNLP-PubMedBERT-base-uncased-abstract",
                      "/home/yz/pretraining_models/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
                      "/home/yz/pretraining_models/kexinghuang_clinical",
                      "emilyalsentzer/Bio_ClinicalBERT",
                      "../../models/UMLSBert_nosty"]
    # mcsm(embedding_list, ["bert"] * 6, check_intersection=True)
    # mcsm(embedding_list, ["bert"] * 6)

    # Current run: evaluate only the last model in the list.
    mcsm([embedding_list[-1]], ["bert"])