Diff of /test/mantra/test.py [000000] .. [c3444c]

import os
import sys
sys.path.append("../../")
from pretrain.load_umls import UMLS
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from data_util import load
import tqdm

batch_size = 128
device = "cuda:0"


def get_umls():
    """Collect UMLS concept descriptions as two parallel lists:
    umls_label[i] is the CUI whose description is umls_des[i]."""
    umls_label = []
    umls_label_set = set()
    umls_des = []
    umls = UMLS("../../umls", source_range=["MSH", "SNOMEDCT_US", "MDR"], only_load_dict=True)
    for cui in tqdm.tqdm(umls.cui2str):
        if cui not in umls_label_set:
            tmp_str = list(umls.cui2str[cui])
            umls_label.extend([cui] * len(tmp_str))
            umls_des.extend(tmp_str)
            umls_label_set.update([cui])
    print(len(umls_des))
    return umls_label, umls_des
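
# Example of the parallel-list layout get_umls() returns (values are
# illustrative, not real UMLS rows): each description keeps its own label
# entry, so index i of umls_des maps back to the CUI at index i of umls_label.
#   umls_label = ["C0011849", "C0011849", "C0027051"]
#   umls_des   = ["diabetes mellitus", "dm", "myocardial infarction"]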


def main(filename, summary_method, umls_label, umls_des):
    # Load either a Hugging Face checkpoint directory/name or, failing that,
    # a raw pytorch_model.bin saved with torch.save(model).
    try:
        config = AutoConfig.from_pretrained(filename)
        model = AutoModel.from_pretrained(
            filename, config=config).to(device)
    except BaseException:
        model = torch.load(os.path.join(
            filename, 'pytorch_model.bin')).to(device)

    try:
        # The flag lives on the config; setting it on the module has no effect.
        model.config.output_hidden_states = False
    except BaseException:
        pass

    try:
        tokenizer = AutoTokenizer.from_pretrained(filename)
    except BaseException:
        tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(filename, "../"))

    corpus_list = [("Medline", "es"), ("Medline", "fr"), ("Medline", "nl"), ("Medline", "de"),
                   ("EMEA", "es"), ("EMEA", "fr"), ("EMEA", "nl"), ("EMEA", "de"),
                   ("Patent", "fr"), ("Patent", "de")]
    """
    sty_list = ["Geographic Area",
                "Drug Delivery Device", "Medical Device", "Research Device",
                "Anatomical Abnormality", "Anatomical Structure", "Fully Formed Anatomical Structure",
                "Chemical", "Chemical Viewed Functionally", "Chemical Viewed Structurally", "Inorganic Chemical", "Organic Chemical", "Clinical Drug"]
    """
    result_dict = {}
    umls_label_lookup = set(umls_label)  # set membership; the list itself can be very large
    umls_embedding = get_bert_embed(umls_des, model, tokenizer, summary_method=summary_method, tqdm_bar=True)

    for corpus in corpus_list:
        output_text, output_label, label_set = load(dataset=corpus[0], lang=corpus[1])
        not_umls_label = [label for label in label_set if label not in umls_label_lookup]
        print(f"Count of labels not appearing in the UMLS subset: {len(not_umls_label)}")
        text_embedding = get_bert_embed(output_text, model, tokenizer, summary_method=summary_method)
        predict_label = predict(text_embedding, umls_embedding, umls_label)
        p, r, f1 = metric(output_label, predict_label)
        result_dict[corpus[0] + "|" + corpus[1]] = (p, r, f1)
        print(p, r, f1)

    return result_dict
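
# result_dict maps "<corpus>|<lang>" to micro-averaged (precision, recall, F1),
# e.g. (values illustrative): {"Medline|es": (0.41, 0.38, 0.39), ...}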


def predict(text_embedding, umls_embedding, umls_label):
    # Nearest neighbour over UMLS description embeddings: rows of both
    # matrices are L2-normalized, so the dot products are cosine similarities.
    sim = torch.matmul(text_embedding, umls_embedding.t())
    most_similar = torch.max(sim, dim=1)[1]
    return [umls_label[idx] for idx in most_similar]
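
# Single-query sketch of the same lookup (illustrative only):
#   sim_row = torch.mv(umls_embedding, query_vec)      # (num_descriptions,)
#   best_cui = umls_label[int(torch.argmax(sim_row))]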


def metric(output_label, predict_label):
    # Micro-averaged precision/recall/F1: counts are pooled over all mentions
    # before dividing, rather than averaging per-mention scores.
    predict_count = 0
    true_count = 0
    correct_count = 0
    for idx in range(len(output_label)):
        # Normalize both sides to lists so single labels and multi-label
        # annotations are handled uniformly.
        if isinstance(predict_label[idx], str):
            predict_label[idx] = [predict_label[idx]]
        if isinstance(output_label[idx], str):
            output_label[idx] = [output_label[idx]]
        predict_count += len(predict_label[idx])
        true_count += len(output_label[idx])
        for pred in predict_label[idx]:
            if pred in output_label[idx]:
                correct_count += 1

    if predict_count == 0 or true_count == 0:  # avoid division by zero on empty input
        return 0., 0., 0.
    p = correct_count / predict_count
    r = correct_count / true_count
    if p == 0. or r == 0.:
        f1 = 0.
    else:
        f1 = 2 * p * r / (p + r)
    return p, r, f1
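
# Worked example (hypothetical counts): predict_count=10, true_count=8,
# correct_count=6 gives p=0.6, r=0.75, f1=2*0.6*0.75/(0.6+0.75)≈0.667.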


def get_bert_embed(phrase_list, m, tok, normalize=True, summary_method="CLS", tqdm_bar=False):
    input_ids = []
    for phrase in phrase_list:
        # pad_to_max_length is deprecated in newer transformers releases;
        # padding="max_length" is the modern equivalent.
        input_ids.append(tok.encode_plus(
            phrase, max_length=32, add_special_tokens=True,
            truncation=True, pad_to_max_length=True)['input_ids'])
    m.eval()

    count = len(input_ids)
    now_count = 0
    output_chunks = []
    with torch.no_grad():
        if tqdm_bar:
            pbar = tqdm.tqdm(total=count)
        while now_count < count:
            input_gpu_0 = torch.LongTensor(input_ids[now_count:min(
                now_count + batch_size, count)]).to(device)
            if summary_method == "CLS":
                embed = m(input_gpu_0)[1]      # pooler output of the [CLS] token
            elif summary_method == "MEAN":
                embed = torch.mean(m(input_gpu_0)[0], dim=1)  # mean over token states
            else:
                raise ValueError(f"Unknown summary_method: {summary_method}")
            if normalize:
                embed_norm = torch.norm(
                    embed, p=2, dim=1, keepdim=True).clamp(min=1e-12)
                embed = embed / embed_norm
            output_chunks.append(embed)
            if tqdm_bar:
                pbar.update(min(now_count + batch_size, count) - now_count)
            now_count = min(now_count + batch_size, count)
        if tqdm_bar:
            pbar.close()
    # Concatenate once at the end instead of torch.cat inside the loop,
    # which would repeatedly re-copy the accumulated tensor.
    return torch.cat(output_chunks, dim=0)
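
# Note: encode_plus's attention_mask is never passed to the model, so padded
# positions are attended to and, under "MEAN", also averaged into the summary;
# supplying the mask and masking the mean would likely be more faithful for
# short phrases.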


if __name__ == '__main__':
    umls_label, umls_des = get_umls()
    main("bert-base-multilingual-cased", "CLS", umls_label, umls_des)