# test/embeddings_reimplement/ndfrt_analysis.py
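"""Reimplementation of the NDF-RT may_treat / may_prevent retrieval benchmark
for medical concept embeddings: for each drug CUI, retrieve the k nearest
concepts (optionally after shifting the query by a relation vector) and score
a hit if any retrieved CUI is a known target disease. Supports CUI-level,
word-level, and BERT-based embeddings.
"""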
from transformers import AutoConfig, AutoModel, AutoTokenizer
import torch
from tqdm import tqdm
import numpy as np
import sys
sys.path.append("../../pretrain")
from load_umls import UMLS
from nltk.tokenize import word_tokenize
import os
import ipdb

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

batch_size = 512
max_seq_length = 32

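# ``concept_filename`` is expected to hold one NDF-RT relation per line in the
# form ``<drug CUI>:<disease CUI>,<disease CUI>,...,`` with a trailing comma,
# e.g. (hypothetical CUIs, for illustration only):
#   C0004057:C0007272,C0242510,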
def get_drug_diseases_to_check(concept_filename):
    query_to_targets = {}
    with open(concept_filename, 'r') as infile:
        data = infile.readlines()
        for row in data:
            drug, diseases = row.strip().split(':')
            # The trailing comma leaves an empty final element; drop it.
            diseases = diseases.split(',')[:-1]
            disease_cui_set = set(diseases)
            if len(disease_cui_set) > 0:
                query_to_targets[drug] = disease_cui_set

    cui_list = set()
    for query, targets in query_to_targets.items():
        cui_list.add(query)
        cui_list.update(targets)
    return query_to_targets, list(cui_list)

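# Row-wise L2 normalization; with unit-norm rows, a matmul against a query
# vector ranks all concepts by cosine-style similarity.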
def normalize(tensor):
    norm = torch.norm(tensor, p=2, dim=1, keepdim=True).clamp(min=1e-12)
    return torch.div(tensor, norm)

def calculate_mrm_ndfrt_origin(term_embedding, cui_list, query_to_targets, k):
    # No relation offset: plain nearest-neighbor retrieval.
    return calculate_mrm_ndfrt_delta(term_embedding, cui_list, query_to_targets, None, k)

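# Try every (query - target) difference vector as a candidate relation offset
# and report the average and best hit@k over all candidates.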
def calculate_mrm_ndfrt_q2t(term_embedding, cui_list, query_to_targets, k):
    delta_list = []

    term_embedding = torch.FloatTensor(term_embedding).to(device)
    norm_embedding = normalize(term_embedding)

    id2cui = {i: cui for i, cui in enumerate(cui_list)}
    cui2id = {cui: index for index, cui in id2cui.items()}

    # One candidate offset per known (query, target) pair.
    for query, targets in query_to_targets.items():
        if query in cui2id:
            for target in targets:
                if target in cui2id:
                    delta = term_embedding[cui2id[query]] - term_embedding[cui2id[target]]
                    delta_list.append(delta)

    overall_output = []
    for delta in tqdm(delta_list):
        output = []
        for query, targets in query_to_targets.items():
            if query in cui2id:
                find_embedding = term_embedding[cui2id[query]] - delta
                similarity = torch.matmul(norm_embedding, find_embedding)
                # Retrieve k + 1 and skip the top hit, assumed to be the query itself.
                _, indices = torch.topk(similarity, k=k + 1)
                find_cui = [cui_list[index] for index in indices[1:]]
                score = 0.
                for cui in find_cui:
                    if cui in targets:
                        score = 1.
                        break
                output.append(score)
        if len(output) > 0:
            score = sum(output) / len(output)
        else:
            score = 0.
        overall_output.append(score)

    if len(overall_output) > 0:
        overall_score = sum(overall_output) / len(overall_output)
        overall_max = max(overall_output)
    else:
        overall_score = 0
        overall_max = 0
    return overall_score, overall_max

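# Hit@k with an optional fixed relation offset: with ``delta=None`` this is
# plain nearest-neighbor search; otherwise the query embedding is shifted by
# the relation vector before retrieval.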
def calculate_mrm_ndfrt_delta(term_embedding, cui_list, query_to_targets, delta=None, k=40):
    term_embedding = torch.FloatTensor(term_embedding).to(device)
    norm_embedding = normalize(term_embedding)

    id2cui = {i: cui for i, cui in enumerate(cui_list)}
    cui2id = {cui: index for index, cui in id2cui.items()}

    output = []
    check_count = 0
    for query, targets in query_to_targets.items():
        if query in cui2id:
            query_embedding = term_embedding[cui2id[query]]
            if delta is None:
                find_embedding = query_embedding
            else:
                find_embedding = query_embedding - torch.FloatTensor(delta).to(device)
            similarity = torch.matmul(norm_embedding, find_embedding)
            _, indices = torch.topk(similarity, k=k + 1)
            find_cui = [cui_list[index] for index in indices[1:]]
            score = 0.
            for cui in find_cui:
                if cui in targets:
                    score = 1.
                    break
            output.append(score)
            check_count += 1
    del term_embedding

    if len(output) > 0:
        score = sum(output) / len(output)
    else:
        score = 0.

    """
    print(f"Check count: {check_count}")
    print(score)
    """

    return score

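# CUI-level embedding file: keep only CUIs that have a vector, then run the
# requested scoring method.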
def mrm_ndfrt_cui(cui_embedding, umls, cui_list, query_to_targets, k, method):
    w, _ = load_embedding(cui_embedding)

    new_cui_list = [cui for cui in cui_list if cui in w]
    term_embedding = np.array([w[cui] for cui in new_cui_list])

    print(f"CUI count: {len(new_cui_list)}")

    if method == "origin":
        score = calculate_mrm_ndfrt_origin(term_embedding, new_cui_list, query_to_targets, k)
        print(f"Origin: {score}")
    if method == "all":
        score = calculate_mrm_ndfrt_q2t(term_embedding, new_cui_list, query_to_targets, k)
        average_score, max_score = score
        print(f"Average: {average_score}")
        print(f"Max: {max_score}")
    return score

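# Word-level embeddings: embed each concept by summing the vectors of the
# in-vocabulary tokens of its first UMLS description (equivalent to the mean
# once the rows are L2-normalized).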
def mrm_ndfrt_word(word_embedding, umls, cui_list, query_to_targets, k, method):
    w, dim = load_embedding(word_embedding)

    print("Tokenize and calculate avg embedding.")
    # Filter cui_list the same way cui_str is built so indices stay aligned.
    filtered_cui_list = [cui for cui in cui_list if cui in umls.cui2str]
    cui_str = [[word for word in word_tokenize(
        list(umls.cui2str[cui])[0]) if word in w] for cui in filtered_cui_list]

    new_cui_list = []
    embeddings = []
    for index, des in enumerate(cui_str):
        if len(des) > 0:
            tmp_emb = np.zeros(dim)
            for word in des:
                tmp_emb += w[word]
            embeddings.append(tmp_emb)
            new_cui_list.append(filtered_cui_list[index])
    term_embedding = np.stack(embeddings)

    print(f"CUI count: {len(new_cui_list)}")

    if method == "origin":
        score = calculate_mrm_ndfrt_origin(term_embedding, new_cui_list, query_to_targets, k)
        print(f"Origin: {score}")
    if method == "all":
        score = calculate_mrm_ndfrt_q2t(term_embedding, new_cui_list, query_to_targets, k)
        average_score, max_score = score
        print(f"Average: {average_score}")
        print(f"Max: {max_score}")
    return score

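# BERT-style models: encode each concept's first description in batches and
# summarize either with the pooled [CLS] output ("CLS") or the mean of the
# last hidden states ("MEAN"). The "may_treat"/"may_prevent" branch loads a
# pre-trained relation vector from the model's "rel embedding" export (the
# UMLSBert setup referenced below) and scores retrieval after shifting
# queries by it.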
def mrm_ndfrt_bert(bert_embedding, umls, cui_list, query_to_targets, k, method, summary_method):
    print(summary_method)
    model, tokenizer = load_bert(bert_embedding)
    model.eval()

    input_ids = []
    new_cui_list = []
    for cui in cui_list:
        if cui in umls.cui2str:
            input_ids.append(tokenizer.encode_plus(
                list(umls.cui2str[cui])[
                    0], max_length=max_seq_length, add_special_tokens=True,
                truncation=True, padding="max_length")['input_ids'])
            new_cui_list.append(cui)

    count = len(input_ids)
    now_count = 0
    # with tqdm(total=count) as pbar:
    with torch.no_grad():
        while now_count < count:
            input_gpu_0 = torch.LongTensor(input_ids[now_count:min(
                now_count + batch_size, count)]).to(device)
            if summary_method == "CLS":
                embed = model(input_gpu_0)[1]
            if summary_method == "MEAN":
                embed = torch.mean(model(input_gpu_0)[0], dim=1)
            embed_np = embed.cpu().detach().numpy()
            if now_count == 0:
                term_embedding = embed_np
            else:
                term_embedding = np.concatenate((term_embedding, embed_np), axis=0)
            update = min(now_count + batch_size, count) - now_count
            now_count = now_count + update
            # pbar.update(update)

    print(f"CUI count: {len(new_cui_list)}")

    if method == "origin":
        score = calculate_mrm_ndfrt_origin(term_embedding, new_cui_list, query_to_targets, k)
        print(f"Origin: {score}")
    if method == "all":
        score = calculate_mrm_ndfrt_q2t(term_embedding, new_cui_list, query_to_targets, k)
        average_score, max_score = score
        print(f"Average: {average_score}")
        print(f"Max: {max_score}")
    if method in ["may_treat", "may_prevent"]:
        beta_path = os.path.join(bert_embedding, "run", "1000000", "rel embedding")
        with open(os.path.join(beta_path, "metadata.tsv"), "r", encoding="utf-8") as f:
            metadata = f.readlines()
        metadata = [line.strip() for line in metadata]
        with open(os.path.join(beta_path, "tensors.tsv"), "r", encoding="utf-8") as f:
            tensor = f.readlines()

        tensor = [[float(num) for num in line.split("\t")] for line in tensor]
        for index, title in enumerate(metadata):
            if title == method:
                delta = tensor[index]

        score = calculate_mrm_ndfrt_delta(term_embedding, new_cui_list, query_to_targets, delta, k)
        print(f"{method}: {score}")
    return score

def mrm_ndfrt(embedding_list, embedding_type_list, concept_filename, k=40, check_intersection=True):
    if check_intersection:
        if not os.path.exists("intersection.txt"):
            intersection_cui = get_intersection(
                embedding_list, embedding_type_list)
            with open("intersection.txt", "w", encoding="utf-8") as f:
                for cui in intersection_cui:
                    f.write(cui.strip() + "\n")
        else:
            with open("intersection.txt", "r", encoding="utf-8") as f:
                lines = f.readlines()
            intersection_cui = [line.strip() for line in lines]

    query_to_targets, cui_list = get_drug_diseases_to_check(concept_filename)
    umls = UMLS("../../umls", only_load_dict=True)  # alternatively: source_range='SNOMEDCT_US'

    if check_intersection:
        # Hoist the set so the membership test below is O(1) per CUI.
        intersection_set = set(intersection_cui)
        cui_list = [cui for cui in cui_list if cui in intersection_set]

    #cui_list = [cui for cui in umls.cui2str if umls.cui2sty[cui] in sty_list]
    #cui_list = [cui for cui in cui_list if cui in umls.sty_list]

    """
    for cui in cui_list:
        if not cui in umls.cui2str:
            print(cui)

    ipdb.set_trace()
    """

    opt = []
    """
    # Origin
    print("ORIGIN")
    for index, embedding in enumerate(embedding_list):
        if embedding_type_list[index].lower() == "cui":
            opt.append(mrm_ndfrt_cui(embedding, umls, cui_list, query_to_targets, k, "origin"))
        if embedding_type_list[index].lower() == "word":
            opt.append(mrm_ndfrt_word(embedding, umls, cui_list, query_to_targets, k, "origin"))
        if embedding_type_list[index].lower() == "bert":
            #opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
            #                     query_to_targets, k, "origin", summary_method="MEAN"))
            opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
                                 query_to_targets, k, "origin", summary_method="CLS"))

    # For UMLSBert
    for index, embedding in enumerate(embedding_list):
        if embedding_type_list[index].lower() == "bert":
            print("BETA")
            beta_path = os.path.join(embedding, "run", "1000000", "rel embedding")
            if os.path.exists(beta_path):
                if concept_filename.find('treat') >= 0:
                    method = "may_treat"
                else:
                    method = "may_prevent"
                #opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
                #                 query_to_targets, k, method, summary_method="MEAN"))
                opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
                                 query_to_targets, k, method, summary_method="CLS"))

    # For average and max

    print("ALL")
    for index, embedding in enumerate(embedding_list):
        if embedding_type_list[index].lower() == "cui":
            opt.append(mrm_ndfrt_cui(embedding, umls, cui_list, query_to_targets, k, "all"))
        if embedding_type_list[index].lower() == "word":
            opt.append(mrm_ndfrt_word(embedding, umls, cui_list, query_to_targets, k, "all"))
        if embedding_type_list[index].lower() == "bert":
            #opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
            #                     query_to_targets, k, "all", summary_method="MEAN"))
            opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
                                 query_to_targets, k, "all", summary_method="CLS"))
    """
    # Run "origin" and "all" for every embedding; for BERT models that ship a
    # relation embedding, also run the relation-shifted evaluation.
    for index, embedding in enumerate(embedding_list):
        if embedding_type_list[index].lower() == "cui":
            opt.append(mrm_ndfrt_cui(embedding, umls, cui_list, query_to_targets, k, "origin"))
            opt.append(mrm_ndfrt_cui(embedding, umls, cui_list, query_to_targets, k, "all"))
        if embedding_type_list[index].lower() == "word":
            opt.append(mrm_ndfrt_word(embedding, umls, cui_list, query_to_targets, k, "origin"))
            opt.append(mrm_ndfrt_word(embedding, umls, cui_list, query_to_targets, k, "all"))
        if embedding_type_list[index].lower() == "bert":
            opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
                                      query_to_targets, k, "origin", summary_method="CLS"))
            beta_path = os.path.join(embedding, "run", "1000000", "rel embedding")
            if os.path.exists(beta_path):
                if concept_filename.find('treat') >= 0:
                    method = "may_treat"
                else:
                    method = "may_prevent"
                opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
                                          query_to_targets, k, method, summary_method="CLS"))
            opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
                                      query_to_targets, k, "all", summary_method="CLS"))

    return opt

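# Supported embedding formats: word2vec binary (*.bin, via gensim), a pickled
# dict of string-serialized vectors (*.pkl), or word2vec-style text with a
# header line. Returns (dict-like of vectors, dimension).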
def load_embedding(filename):
    print(filename)
    if filename.find('bin') >= 0:
        from gensim import models
        W = models.KeyedVectors.load_word2vec_format(filename, binary=True)
        dim = W.vector_size
        return W, dim

    if filename.find('pkl') >= 0:
        import pickle
        with open(filename, 'rb') as f:
            W = pickle.load(f)
        # Vectors are stored as strings like "[0.1,0.2,...]"; parse them.
        for key, value in W.items():
            W[key] = np.array(list(map(float, value[1:-1].split(","))))
        dim = len(list(W.values())[0])
        return W, dim

    W = {}
    with open(filename, 'r') as f:
        for i, line in enumerate(f.readlines()):
            # Skip the word2vec text header (vocab size and dimension).
            if i == 0:
                continue
            toks = line.strip().split()
            w = toks[0]
            vec = np.array(list(map(float, toks[1:])))
            W[w] = vec
    dim = len(list(W.values())[0])
    return W, dim

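# Load a Hugging Face model and tokenizer; fall back to a raw torch.load of
# pytorch_model.bin (and the parent directory's tokenizer) for checkpoints
# saved outside the transformers API.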
def load_bert(model_name_or_path):
    print(model_name_or_path)
    try:
        config = AutoConfig.from_pretrained(model_name_or_path)
        model = AutoModel.from_pretrained(
            model_name_or_path, config=config).to(device)
    # Catch Exception rather than BaseException so KeyboardInterrupt still works.
    except Exception:
        model = torch.load(os.path.join(
            model_name_or_path, 'pytorch_model.bin')).to(device)

    try:
        model.output_hidden_states = False
    except Exception:
        pass

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(model_name_or_path, "../"))
    return model, tokenizer

def get_intersection(embedding_list, embedding_type_list):
    intersection_cui = set()
    checker = True
    for index, embed in enumerate(embedding_list):
        if embedding_type_list[index] == "cui":
            w, _ = load_embedding(embed)
            if checker:
                intersection_cui = set(w.keys())
                checker = False
            else:
                intersection_cui = set(w.keys()).intersection(intersection_cui)
    print(f"Intersection count: {len(intersection_cui)}")
    return list(intersection_cui)

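# Example configuration: the embedding files and checkpoints below are local
# paths from the original experiments; adjust them to your environment. Each
# entry in embedding_list is paired positionally with embedding_type_list.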
if __name__ == "__main__":
    """
    embedding_list = ["../../embeddings/claims_codes_hs_300.txt",
                      "../../embeddings/GoogleNews-vectors-negative300.bin",
                      "../../models/2020_eng"]
    embedding_type_list = ["cui", "word", "bert"]
    mrm_ndfrt(embedding_list, embedding_type_list, "may_prevent_cui.txt", check_intersection=False)
    """

    embedding_list = ["../../embeddings/wikipedia-pubmed-and-PMC-w2v.bin",
                      "../../embeddings/bio_nlp_vec/PubMed-shuffle-win-2.bin",
                      "../../embeddings/bio_nlp_vec/PubMed-shuffle-win-30.bin"]
    embedding_type_list = ["word", "word", "word"]
    embedding_list += ["../../embeddings/DeVine_etal_200.txt",
                       "/home/yz/pretraining_models/cui2vec.pkl"]
    embedding_type_list += ["cui", "cui"]
    embedding_list += ["../../models/2020_all",
                       "/home/yz/pretraining_models/bert-base-cased",
                       "/home/yz/pretraining_models/biobert_v1.1",
                       "/home/yz/pretraining_models/BiomedNLP-PubMedBERT-base-uncased-abstract",
                       "/home/yz/pretraining_models/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
                       "/home/yz/pretraining_models/kexinghuang_clinical",
                       "emilyalsentzer/Bio_ClinicalBERT"]
    embedding_type_list += ["bert"] * 7

    #mrm_ndfrt(embedding_list, embedding_type_list, "may_treat_cui.txt", check_intersection=True)
    mrm_ndfrt(embedding_list, embedding_type_list, "may_treat_cui.txt", check_intersection=False)
    #mrm_ndfrt(embedding_list[-6:], embedding_type_list[-6:], "may_prevent_cui.txt", check_intersection=True)
    #mrm_ndfrt(embedding_list, embedding_type_list, "may_prevent_cui.txt", check_intersection=False)