Switch to unified view

a b/test/embeddings_reimplement/codes_analysis.py
1
import os
2
import ipdb
3
from nltk.tokenize import word_tokenize
4
from icd9 import ICD9
5
from transformers import AutoConfig, AutoModel, AutoTokenizer
6
import torch
7
from tqdm import tqdm
8
import numpy as np
9
import sys
10
sys.path.append("../../pretrain")
11
from load_umls import UMLS
12
13
14
# ICD9 lookup object built from 'codes.json'; mrm_ccs() uses
# tree.find(code).description as a fallback source of code descriptions.
tree = ICD9('codes.json')
15
# Device used for all tensor work in this module. GPU 1 is preferred when it
# exists; on single-GPU machines "cuda:1" is not a valid ordinal, so GPU 0 is
# used there, and the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device(
        "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Rank discount table: log_list[i] == 1 / log2(i + 2) for i in 0..998.
# calculate_mrm_ccs() reads it to weight hits by their rank.
log_list = 1 / np.log2(np.arange(2, 1001))

# Number of descriptions encoded per forward pass in mrm_ccs_bert().
batch_size = 512
# Token length each description is padded/truncated to in mrm_ccs_bert().
max_seq_length = 32
20
21
22
def get_icd9_pairs(icd9_set):
    """Build ICD-9 neighbour sets from the groups in 'icd9_grp_file.txt'.

    Each usable line has the form ``<whitespace-separated codes> # <name>``.
    Blank lines and lines without a '#' separator are skipped. A listed code
    is kept when it is in ``icd9_set``; otherwise it is replaced by its first
    five characters (codes longer than five characters) or by its first three
    characters (codes longer than four characters) when that prefix is in
    ``icd9_set``. Codes that match none of these are dropped.

    Args:
        icd9_set: Collection of ICD-9 codes allowed to appear in the result.

    Returns:
        dict mapping each code that shares a group with at least one other
        matched code to the set of the other codes of its groups.
    """
    icd9_pairs = {}
    with open('icd9_grp_file.txt', 'r', encoding="utf-8") as infile:
        for row in infile:
            # Only the part before the first '#' holds codes; the group name
            # after it is not needed.
            codes_field, separator, _ = row.strip().partition('#')
            if not separator:
                continue

            matched = set()
            for code in codes_field.split():
                if code in icd9_set:
                    matched.add(code)
                elif len(code) > 5 and code[:5] in icd9_set:
                    matched.add(code[:5])
                elif len(code) > 4 and code[:3] in icd9_set:
                    matched.add(code[:3])

            # A group with a single surviving code defines no neighbours.
            if len(matched) > 1:
                for code in matched:
                    icd9_pairs.setdefault(code, set()).update(matched - {code})
    return icd9_pairs
47
48
49
def get_coarse_icd9_pairs(icd9_set):
    """Build ICD-9 neighbour sets from coarse CCS categories.

    Reads 'ccs_coarsest.txt'. A row whose first ten characters are not blank
    starts a new CCS category (those characters, stripped, are its id); any
    other non-blank row lists whitespace-separated codes for the current
    category. Code rows that appear before the first category row are
    ignored. Categories are then merged on the first two dot-separated
    components of their id; ids with fewer than two components are dropped.

    Codes longer than three characters get a '.' inserted after the third
    character. Each code is then kept if it is in ``icd9_set``, or replaced
    by its five-character prefix (codes longer than five characters) or its
    three-character prefix (codes longer than four characters) when that
    prefix is in ``icd9_set``.

    Args:
        icd9_set: Collection of ICD-9 codes allowed to appear in the result.

    Returns:
        dict mapping each code that shares a merged category with at least
        one other matched code to the set of those other codes.
    """
    ccs_to_icd9 = {}
    with open('ccs_coarsest.txt', 'r', encoding="utf-8") as infile:
        current_ccs = None
        for row in infile:
            header = row[:10].strip()
            if header:
                current_ccs = header
                ccs_to_icd9[current_ccs] = set()
            elif row.strip() and current_ccs is not None:
                ccs_to_icd9[current_ccs].update(row.split())

    # Merge categories that share their first two id components.
    ccs_coarse = {}
    for ccs, members in ccs_to_icd9.items():
        parts = ccs.split('.')
        if len(parts) >= 2:
            ccs_coarse.setdefault(parts[0] + '.' + parts[1], set()).update(members)

    icd9_pairs = {}
    for members in ccs_coarse.values():
        matched = set()
        for code in members:
            # Only codes longer than three characters receive the dot;
            # shorter codes are matched unchanged.
            if len(code) > 3:
                code = code[:3] + '.' + code[3:]
            if code in icd9_set:
                matched.add(code)
            elif len(code) > 5 and code[:5] in icd9_set:
                matched.add(code[:5])
            elif len(code) > 4 and code[:3] in icd9_set:
                matched.add(code[:3])

        if len(matched) > 1:
            for code in matched:
                icd9_pairs.setdefault(code, set()).update(matched - {code})
    return icd9_pairs
91
92
93
def get_cui_concept_mappings():
    """Read the concept-id <-> CUI table from '2b_concept_ID_to_CUI.txt'.

    Every usable line holds at least two tab-separated fields: the concept id
    followed by the CUI. Lines with fewer than two fields (for example blank
    lines) are skipped. Anything after a carriage return in the CUI field is
    discarded and the CUI is stripped of surrounding whitespace.

    Returns:
        Tuple ``(concept_to_cui, cui_to_concept)`` of dicts. When a key occurs
        on several lines, the last line wins.
    """
    concept_to_cui = {}
    cui_to_concept = {}
    with open('2b_concept_ID_to_CUI.txt', 'r', encoding="utf-8") as infile:
        for line in infile:
            fields = line.split('\t')
            if len(fields) < 2:
                continue
            concept = fields[0]
            cui = fields[1].split('\r')[0].strip()
            concept_to_cui[concept] = cui
            cui_to_concept[cui] = concept
    return concept_to_cui, cui_to_concept
105
106
107
def get_icd9_reverse_dict(icd9_dict):
    """Invert a ``key -> iterable of members`` mapping into ``member -> key``.

    When a member is listed under several keys, the key visited last in
    ``icd9_dict`` iteration order is the one that is kept.
    """
    return {
        member: owner
        for owner, members in icd9_dict.items()
        for member in members
    }
113
114
115
def get_icd9_cui_mappings():
    """Read CUI <-> ICD-9 code mappings from the pipe-delimited 'cui_icd9.txt'.

    Only rows whose field 11 equals 'ICD9CM' are used; field 0 is taken as the
    CUI and field 10 as the ICD-9 code. Rows with fewer than twelve fields
    (for example blank lines) are skipped. A row is ignored when its code is
    empty, when its code contains '-', or when the CUI has already been
    mapped, so the first accepted code of a CUI wins.

    Returns:
        Tuple ``(cui_to_icd9, icd9_to_cui)`` of dicts. ``icd9_to_cui`` keeps
        the last accepted CUI for a code.
    """
    cui_to_icd9 = {}
    icd9_to_cui = {}
    with open('cui_icd9.txt', 'r', encoding="utf-8") as infile:
        for row in infile:
            ele = row.strip().split('|')
            if len(ele) < 12 or ele[11] != 'ICD9CM':
                continue
            cui = ele[0]
            icd9 = ele[10]
            if cui not in cui_to_icd9 and icd9 != '' and '-' not in icd9:
                cui_to_icd9[cui] = icd9
                icd9_to_cui[icd9] = cui
    return cui_to_icd9, icd9_to_cui
129
130
131
def get_icd9_to_description():
    """Map ICD-9 codes to the descriptions in 'CMS32_DESC_LONG_DX.txt'.

    For each stripped row, the first six characters (stripped again) form the
    code and the remainder (stripped) is the description. Codes longer than
    three characters get a '.' inserted after the third character.

    Returns:
        dict from code to description; later rows overwrite earlier ones.
    """
    code_to_text = {}
    with open('CMS32_DESC_LONG_DX.txt', 'r', encoding='latin-1') as infile:
        for raw in infile:
            text = raw.strip()
            code = text[:6].strip()
            if len(code) > 3:
                code = f"{code[:3]}.{code[3:]}"
            code_to_text[code] = text[6:].strip()
    return code_to_text
142
143
144
def mrm_ccs(embedding_list, embedding_type_list, k=40, check_intersection=False):
    """Score each embedding on the fine and coarse ICD-9 grouping tasks.

    Args:
        embedding_list: Paths or model names, one per embedding.
        embedding_type_list: Type of each embedding, matched
            case-insensitively against "cui", "word" and "bert". Entries of
            any other type contribute no scores.
        k: Number of nearest neighbours inspected per term.
        check_intersection: When true, restrict the evaluated CUIs to those
            listed in 'intersection.txt'; the file is created from
            get_intersection() if it does not exist yet.

    Returns:
        Flat list of scores in embedding order. A "cui" or "word" embedding
        adds two entries (pairs from get_icd9_pairs, then pairs from
        get_coarse_icd9_pairs); a "bert" embedding adds four (the same two
        with MEAN summarisation, then the same two with CLS).
    """
    cui_to_icd9, _ = get_icd9_cui_mappings()

    if check_intersection:
        if not os.path.exists("intersection.txt"):
            intersection_cui = get_intersection(
                embedding_list, embedding_type_list)
            with open("intersection.txt", "w", encoding="utf-8") as f:
                for cui in intersection_cui:
                    f.write(cui.strip() + "\n")
        else:
            with open("intersection.txt", "r", encoding="utf-8") as f:
                intersection_cui = [line.strip() for line in f]

    umls = UMLS("../../umls", only_load_dict=True)

    if check_intersection:
        # Dict membership is O(1); no need to materialise the key list.
        cui_list = [cui for cui in intersection_cui if cui in cui_to_icd9]
    else:
        cui_list = list(cui_to_icd9)

    icd9_list = [cui_to_icd9[cui] for cui in cui_list]
    icd9_set = set(icd9_list)
    icd9_pair = get_icd9_pairs(icd9_set)
    icd9_coarse_pair = get_coarse_icd9_pairs(icd9_set)
    icd9_to_description = get_icd9_to_description()

    # A term is evaluated as a query ("center") only when its code has at
    # least one neighbour in the corresponding pair table.
    pair_center_label = [
        1 if cui_to_icd9[cui] in icd9_pair else 0 for cui in cui_list]
    coarse_pair_center_label = [
        1 if cui_to_icd9[cui] in icd9_coarse_pair else 0 for cui in cui_list]

    # Textual description per CUI, tried in order: long-description file,
    # ICD9 tree, UMLS strings, and finally the empty string.
    description = []
    for cui in cui_list:
        if cui in cui_to_icd9 and cui_to_icd9[cui] in icd9_to_description:
            description.append(icd9_to_description[cui_to_icd9[cui]])
        elif cui in cui_to_icd9 and tree.find(cui_to_icd9[cui]):
            description.append(tree.find(cui_to_icd9[cui]).description)
        elif cui in umls.cui2str:
            description.append(list(umls.cui2str[cui])[0])
        else:
            description.append("")
            print(f"Can not find description for {cui}")

    tasks = [(pair_center_label, icd9_pair),
             (coarse_pair_center_label, icd9_coarse_pair)]

    opt = []
    for index, embedding in enumerate(embedding_list):
        print("*************************")
        embedding_type = embedding_type_list[index].lower()
        if embedding_type == "cui":
            for center_label, pair in tasks:
                opt.append(mrm_ccs_cui(
                    embedding, icd9_list, cui_list, center_label, pair, k))
        elif embedding_type == "word":
            for center_label, pair in tasks:
                opt.append(mrm_ccs_word(
                    embedding, icd9_list, description, center_label, pair, k))
        elif embedding_type == "bert":
            for summary_method in ("MEAN", "CLS"):
                for center_label, pair in tasks:
                    opt.append(mrm_ccs_bert(
                        embedding, icd9_list, description, center_label,
                        pair, k, summary_method=summary_method))
    return opt
229
230
231
def mrm_ccs_cui(cui_embedding, icd9_list, cui_list, center_label, pair, k=40):
    """Score a CUI-keyed embedding on one grouping task.

    CUIs absent from the embedding are dropped together with their entries in
    ``icd9_list`` and ``center_label``; the remaining vectors are passed to
    calculate_mrm_ccs().

    Args:
        cui_embedding: Path handed to load_embedding().
        icd9_list: ICD-9 code of each CUI, parallel to ``cui_list``.
        cui_list: CUIs to evaluate.
        center_label: 1/0 flags parallel to ``cui_list`` marking query terms.
        pair: Mapping from ICD-9 code to its set of neighbour codes.
        k: Number of nearest neighbours inspected per query.

    Returns:
        The score returned by calculate_mrm_ccs().
    """
    vectors, _ = load_embedding(cui_embedding)
    print(f"All cui count:{len(cui_list)}")

    # Positions of the CUIs that the embedding actually covers.
    kept = [pos for pos, cui in enumerate(cui_list) if cui in vectors]
    kept_labels = [center_label[pos] for pos in kept]
    kept_icd9 = [icd9_list[pos] for pos in kept]
    term_embedding = np.array([vectors[cui_list[pos]] for pos in kept])

    return calculate_mrm_ccs(term_embedding, kept_icd9, kept_labels, pair, k=k)
249
250
251
def mrm_ccs_word(word_embedding, icd9_list, description, center_label, pair, k=40):
    """Score a word embedding on one grouping task.

    Each description is tokenised with ``word_tokenize`` and represented by
    the sum of the vectors of its tokens that exist in the embedding.
    Descriptions without any known token are dropped together with their
    entries in ``icd9_list`` and ``center_label``.

    Args:
        word_embedding: Path handed to load_embedding().
        icd9_list: ICD-9 code of each description.
        description: Text of each term, parallel to ``icd9_list``.
        center_label: 1/0 flags parallel to ``description`` marking queries.
        pair: Mapping from ICD-9 code to its set of neighbour codes.
        k: Number of nearest neighbours inspected per query.

    Returns:
        The score returned by calculate_mrm_ccs().
    """
    w, dim = load_embedding(word_embedding)

    print(f"All cui count:{len(description)}")

    term_vectors = []
    new_center_label = []
    new_icd9_list = []
    for index, des in tqdm(enumerate(description)):
        known_words = [word for word in word_tokenize(des) if word in w]
        if not known_words:
            continue
        # Accumulate in token order, starting from a float64 zero vector.
        vector = np.zeros(dim)
        for word in known_words:
            vector += w[word]
        term_vectors.append(vector)
        new_center_label.append(center_label[index])
        new_icd9_list.append(icd9_list[index])

    # Stack once at the end; an empty result still gets a (0, dim) matrix so
    # the scorer receives a well-formed 2-D array.
    if term_vectors:
        term_embedding = np.stack(term_vectors)
    else:
        term_embedding = np.zeros((0, dim))

    return calculate_mrm_ccs(term_embedding, new_icd9_list, new_center_label, pair, k=k)
286
287
288
def mrm_ccs_bert(bert_embedding, icd9_list, description, center_label, pair, k=40, summary_method="MEAN"):
    """Score a transformer encoder on one grouping task.

    Every description is encoded to ``max_seq_length`` token ids, run through
    the model in batches of ``batch_size`` and summarised into one vector.

    Args:
        bert_embedding: Model name or path handed to load_bert().
        icd9_list: ICD-9 code of each description.
        description: Text of each term, parallel to ``icd9_list``.
        center_label: 1/0 flags parallel to ``description`` marking queries.
        pair: Mapping from ICD-9 code to its set of neighbour codes.
        k: Number of nearest neighbours inspected per query.
        summary_method: "MEAN" averages the first model output over the
            sequence axis; "CLS" uses the second model output as is.

    Returns:
        The score returned by calculate_mrm_ccs().

    Raises:
        ValueError: If ``summary_method`` is neither "CLS" nor "MEAN".
    """
    if summary_method not in ("CLS", "MEAN"):
        raise ValueError(
            f"summary_method must be 'CLS' or 'MEAN', got {summary_method!r}")

    model, tokenizer = load_bert(bert_embedding)
    model.eval()

    input_ids = [
        tokenizer.encode_plus(
            des, max_length=max_seq_length, add_special_tokens=True,
            truncation=True, pad_to_max_length=True)['input_ids']
        for des in tqdm(description)
    ]

    # NOTE(review): only input ids are passed to the model (no attention
    # mask), so padding positions also enter the MEAN average -- confirm
    # this is intended before changing it, since it affects the scores.
    chunks = []
    count = len(input_ids)
    with tqdm(total=count) as pbar, torch.no_grad():
        for start in range(0, count, batch_size):
            batch = torch.LongTensor(
                input_ids[start:start + batch_size]).to(device)
            outputs = model(batch)
            if summary_method == "CLS":
                embed = outputs[1]
            else:
                embed = torch.mean(outputs[0], dim=1)
            chunks.append(embed.cpu().detach().numpy())
            pbar.update(batch.shape[0])

    # Concatenate once; with no descriptions fall back to an empty 2-D
    # matrix so the scorer still receives a well-formed array.
    if chunks:
        term_embedding = np.concatenate(chunks, axis=0)
    else:
        term_embedding = np.zeros((0, 1), dtype=np.float32)

    return calculate_mrm_ccs(term_embedding, icd9_list, center_label, pair, k=k)
321
322
323
def calculate_mrm_ccs(term_embedding, icd9_list, center_label, pair, k, normalize=True):
    """Average a rank-discounted neighbour score over all query terms.

    Rows of ``term_embedding`` are L2-normalised, so the dot product used for
    ranking is cosine similarity. For every term flagged in ``center_label``
    whose code is a key of ``pair``, the top-ranked row is skipped (it is
    assumed to be the query itself) and each of the next ``k`` rows whose
    code belongs to ``pair[code]`` contributes ``log_list[rank]``.

    Args:
        term_embedding: Array-like of shape (term_count, embedding_dim).
        icd9_list: ICD-9 code of each row of ``term_embedding``.
        center_label: 1/0 flag per row; only rows flagged 1 are queries.
        pair: Mapping from ICD-9 code to its set of neighbour codes.
        k: Number of neighbours inspected per query. When fewer than
            ``k + 1`` terms exist, all available neighbours are inspected.
        normalize: When true, a positive score is divided by the best score
            reachable with ``min(k, number of neighbour codes present in
            icd9_list)`` hits.

    Returns:
        Mean score over all queries, or 0.0 when there is no query.
    """
    term_embedding = torch.FloatTensor(term_embedding).to(device)
    embedding_norm = torch.norm(
        term_embedding, p=2, dim=1, keepdim=True).clamp(min=1e-12)
    term_embedding = torch.div(term_embedding, embedding_norm)
    del embedding_norm

    # Number of neighbour codes of each key that are actually present;
    # a set makes every membership test O(1).
    present = set(icd9_list)
    count = {
        icd9: sum(1 for v in group if v in present)
        for icd9, group in pair.items()
    }

    # topk raises when asked for more entries than exist, so cap it.
    top_k = min(k + 1, term_embedding.shape[0])

    output = []
    check_count = 0
    for index, icd9 in tqdm(enumerate(icd9_list)):
        if center_label[index] != 1 or icd9 not in pair:
            continue
        similarity = torch.matmul(term_embedding, term_embedding[index])
        _, indices = torch.topk(similarity, k=top_k)
        group = pair[icd9]
        score = 0.0
        # NOTE(review): position 0 is skipped on the assumption that the
        # most similar row is the query itself; with duplicate vectors the
        # query may sit at a later position -- confirm this is acceptable.
        for rank, neighbour in enumerate(indices.tolist()[1:]):
            if icd9_list[neighbour] in group:
                score += log_list[rank]
        if normalize and score > 0:
            score /= sum(log_list[0:min(k, count[icd9])])
        output.append(score)
        check_count += 1
    del term_embedding

    if len(output) >= 1:
        score = sum(output) / len(output)
    else:
        score = 0.
    print(f"Check count: {check_count}")
    print(score)
    return score
367
368
369
def load_embedding(filename):
    """Load an embedding table, choosing the format from ``filename``.

    * A name containing 'bin' is read with gensim as a binary word2vec file.
    * A name containing 'pkl' is unpickled into a dict whose values are
      strings such as "[0.1,0.2]"; the first and last characters are removed
      and the rest is parsed as comma-separated floats.
    * Anything else is read as UTF-8 text: the first line is skipped, blank
      lines are ignored, and each remaining line is a key followed by
      whitespace-separated floats.

    The substring test looks at the whole path, so a directory name
    containing 'bin' or 'pkl' also selects that loader.

    Args:
        filename: Path of the embedding file.

    Returns:
        Tuple ``(W, dim)``: the key -> vector container and the vector size.

    Raises:
        ValueError: If a pickle or text file contains no vectors.
    """
    print(filename)
    if filename.find('bin') >= 0:
        from gensim import models
        W = models.KeyedVectors.load_word2vec_format(filename, binary=True)
        dim = W.vector_size
        return W, dim

    if filename.find('pkl') >= 0:
        import pickle
        # SECURITY: unpickling executes arbitrary code from the file; only
        # load pickle files from a trusted source.
        with open(filename, 'rb') as f:
            W = pickle.load(f)
        for key, value in W.items():
            W[key] = np.array(list(map(float, value[1:-1].split(","))))
    else:
        W = {}
        with open(filename, 'r', encoding="utf-8") as f:
            next(f, None)  # the first line is a header, not a vector
            for line in f:
                toks = line.strip().split()
                if not toks:
                    continue
                W[toks[0]] = np.array(list(map(float, toks[1:])))

    if not W:
        raise ValueError(f"No embedding vectors found in {filename}")
    dim = len(next(iter(W.values())))
    return W, dim
397
398
399
def load_bert(model_name_or_path):
    """Load a transformer model and its tokenizer, with fallbacks.

    The model is first loaded through ``AutoConfig``/``AutoModel``; if that
    fails, ``<model_name_or_path>/pytorch_model.bin`` is loaded with
    ``torch.load`` instead. The tokenizer is loaded from
    ``model_name_or_path`` and, if that fails, from its parent directory.
    The model is moved to the module-level ``device``.

    Only ``Exception`` is caught, so Ctrl-C and interpreter shutdown are not
    mistaken for a loading failure.

    Args:
        model_name_or_path: Model identifier or local directory.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    print(model_name_or_path)
    try:
        config = AutoConfig.from_pretrained(model_name_or_path)
        model = AutoModel.from_pretrained(
            model_name_or_path, config=config).to(device)
    except Exception:
        # SECURITY: torch.load unpickles the file and can execute arbitrary
        # code; only use this fallback with trusted checkpoints.
        model = torch.load(os.path.join(
            model_name_or_path, 'pytorch_model.bin')).to(device)

    # Best effort: not every loaded object accepts this attribute.
    try:
        model.output_hidden_states = False
    except Exception:
        pass

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(model_name_or_path, "../"))
    return model, tokenizer
420
421
422
def get_intersection(embedding_list, embedding_type_list):
    """Return the keys shared by every CUI-type embedding.

    Only entries whose type is "cui" (matched case-insensitively, as in
    mrm_ccs()) are loaded; other entries are ignored.

    Args:
        embedding_list: Paths handed to load_embedding().
        embedding_type_list: Type of each embedding, parallel to
            ``embedding_list``.

    Returns:
        List of the common keys, in no particular order; empty when no
        CUI-type embedding is given.
    """
    intersection_cui = None
    for index, embed in enumerate(embedding_list):
        if embedding_type_list[index].lower() != "cui":
            continue
        w, _ = load_embedding(embed)
        keys = set(w.keys())
        # The first CUI embedding seeds the set; later ones narrow it.
        if intersection_cui is None:
            intersection_cui = keys
        else:
            intersection_cui = keys.intersection(intersection_cui)
    if intersection_cui is None:
        intersection_cui = set()
    print(f"Intersection count: {len(intersection_cui)}")
    return list(intersection_cui)
436
437
438
if __name__ == "__main__":
    # Default run: one embedding of each supported type.
    embedding_type_list = ["cui", "word", "bert"]
    embedding_list = [
        "../../embeddings/claims_codes_hs_300.txt",
        "../../embeddings/GoogleNews-vectors-negative300.bin",
        "../../models/2020_eng",
    ]
    mrm_ccs(embedding_list, embedding_type_list)

    # Alternative configurations, kept for reference and currently disabled.
    #
    # embedding_list = ["../../embeddings/wikipedia-pubmed-and-PMC-w2v.bin",
    #                   "../../embeddings/bio_nlp_vec/PubMed-shuffle-win-2.bin",
    #                   "../../embeddings/bio_nlp_vec/PubMed-shuffle-win-30.bin",
    #                   "/home/yz/pretraining_models/cui2vec.pkl",
    #                   "../../embeddings/DeVine_etal_200.txt"]
    # embedding_type_list = ["word", "word", "word", "cui", "cui"]
    # mrm_ccs(embedding_list[3:], embedding_type_list[3:])
    #
    # embedding_list = ["../../models/2020_all",
    #                   "/home/yz/pretraining_models/bert-base-cased",
    #                   "/home/yz/pretraining_models/biobert_v1.1",
    #                   "/home/yz/pretraining_models/BiomedNLP-PubMedBERT-base-uncased-abstract",
    #                   "/home/yz/pretraining_models/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    #                   "/home/yz/pretraining_models/kexinghuang_clinical",
    #                   "emilyalsentzer/Bio_ClinicalBERT"]
    # mrm_ccs(embedding_list, ["bert"] * 7)
    # mrm_ccs([embedding_list[6]], ["bert"])