a b/coderpp/test/generate_use_data.py
1
# Read the UMLS MRCONSO table and write the idx2phrase / phrase2idx lookup pickles.
2
import pickle
3
import argparse
4
import os
5
from tqdm import tqdm
6
import re
7
from random import shuffle
8
#import ipdb
9
10
def byLineReader(filename):
    """Lazily yield the lines of a UTF-8 text file, newline characters included.

    The file is opened on the first ``next()`` call. It is closed when the
    generator is exhausted, or when the generator is closed early, because
    the ``with`` block is exited in both cases.
    """
    with open(filename, "r", encoding="utf-8") as f:
        yield from f
17
18
class AllUMLS(object):
    """In-memory lookup tables built from a local UMLS release.

    The release directory may hold either pipe-delimited ``MR*.RRF`` files or
    comma-separated ``MR*.txt`` exports (chosen by ``detect_type``).

    Args:
        umls_path: directory containing MRCONSO (and MRREL / MRSTY when
            ``only_load_dict`` is False).
        source_range: source names (MRCONSO field 11) to keep, or None to keep
            every source.
        lang_range: language codes (MRCONSO field 1) to keep, or None to keep
            every language. Defaults to English only. The default is an
            immutable tuple so it can never be shared and mutated across
            instances.
        only_load_dict: when False, relations (``load_rel``) and semantic
            types (``load_sty``) are loaded as well.
    """

    def __init__(self, umls_path, source_range=None, lang_range=('ENG',), only_load_dict=True):
        self.umls_path = umls_path
        self.source_range = source_range
        self.lang_range = lang_range
        self.detect_type()
        self.load()
        if not only_load_dict:
            self.load_rel()
            self.load_sty()

    def detect_type(self):
        """Set ``self.type`` to "RRF" if MRCONSO.RRF exists, otherwise "txt"."""
        if os.path.exists(os.path.join(self.umls_path, "MRCONSO.RRF")):
            self.type = "RRF"
        else:
            self.type = "txt"

    def _split_line(self, line):
        """Split one raw MR* line into its fields according to ``self.type``."""
        if self.type == "txt":
            # Naive CSV: double quotes are dropped and embedded commas are NOT
            # supported. The line is not stripped, so the last field keeps
            # its newline.
            return [t.replace("\"", "") for t in line.split(",")]
        return line.strip().split("|")

    def load(self):
        """Read MRCONSO and build the concept/string lookup tables.

        Rows outside ``source_range`` / ``lang_range`` are skipped. Of the
        remaining rows, only the first one seen for each LUI (field 3) is used.

        Populates:
            cui2str:  CUI -> set of cleaned strings.
            str2cui:  cleaned string -> CUI (a later row overwrites an earlier one).
            code2cui: code (field 13) -> CUI (a later row overwrites an earlier one).
            lui_set:  LUIs already consumed.
            cui, cui_count: list of loaded CUIs and its length.
        """
        reader = byLineReader(os.path.join(self.umls_path, "MRCONSO." + self.type))
        self.lui_set = set()
        self.cui2str = {}
        self.str2cui = {}
        self.code2cui = {}
        read_count = 0
        for line in tqdm(reader, ascii=True):
            fields = self._split_line(line)
            cui, lang, lui = fields[0], fields[1], fields[3]
            source, code, string = fields[11], fields[13], fields[14]

            if self.source_range is not None and source not in self.source_range:
                continue
            if self.lang_range is not None and lang not in self.lang_range:
                continue
            if lui in self.lui_set:
                continue

            read_count += 1
            clean_string = self.clean(string)
            self.str2cui[clean_string] = cui
            self.cui2str.setdefault(cui, set()).add(clean_string)
            self.code2cui[code] = cui
            self.lui_set.add(lui)

        self.cui = list(self.cui2str.keys())
        self.cui_count = len(self.cui)

        print("cui count:", self.cui_count)
        print("str2cui count:", len(self.str2cui))
        print("MRCONSO count:", read_count)

    def load_rel(self):
        """Read MRREL and keep relations whose two CUIs were both loaded.

        Each relation is stored once in ``self.rel`` (a list, arbitrary order)
        as the tab-joined string ``cui0, cui1, field 3, field 7``. Rows
        relating a CUI to itself are skipped.
        """
        reader = byLineReader(os.path.join(self.umls_path, "MRREL." + self.type))
        relations = set()
        for line in tqdm(reader, ascii=True):
            fields = self._split_line(line)
            # Named rel_type (not `re`) so the `re` module is not shadowed.
            cui0, rel_type, cui1, rela = fields[0], fields[3], fields[4], fields[7]
            if cui0 == cui1:
                continue
            if cui0 in self.cui2str and cui1 in self.cui2str:
                relations.add("\t".join([cui0, cui1, rel_type, rela]))
        self.rel = list(relations)

        print("rel count:", len(self.rel))

    def load_sty(self):
        """Read MRSTY into ``self.cui2sty`` (CUI -> field 3) for loaded CUIs.

        If a CUI appears on several rows, only the last row's value is kept.
        """
        reader = byLineReader(os.path.join(self.umls_path, "MRSTY." + self.type))
        self.cui2sty = {}
        for line in tqdm(reader, ascii=True):
            fields = self._split_line(line)
            cui, sty = fields[0], fields[3]
            if cui in self.cui2str:
                self.cui2sty[cui] = sty

        print("sty count:", len(self.cui2sty))

    def clean(self, term, lower=True, clean_NOS=True, clean_bracket=True, clean_dash=True):
        """Normalize a term string.

        Args:
            term: raw string.
            lower: lowercase the term.
            clean_NOS: drop space-delimited "NOS" / "nos" tokens.
            clean_bracket: drop "(...)" segments. The match is non-greedy, so
                nested parentheses are not fully removed.
            clean_dash: turn "-" into spaces.

        Returns:
            The normalized term with whitespace runs collapsed to single spaces
            and no leading/trailing whitespace.
        """
        # Pad with spaces so " NOS " also matches at the start/end of the term.
        term = " " + term + " "
        if lower:
            term = term.lower()
        if clean_NOS:
            term = term.replace(" NOS ", " ").replace(" nos ", " ")
        if clean_bracket:
            term = re.sub(r"\(.*?\)", "", term)
        if clean_dash:
            term = term.replace("-", " ")
        return " ".join(term.split())

    def search_by_code(self, code):
        """Return the strings for *code*, tried as a CUI and then as a source code.

        Returns None if *code* is neither.
        """
        if code in self.cui2str:
            return list(self.cui2str[code])
        if code in self.code2cui:
            return list(self.cui2str[self.code2cui[code]])
        return None

    def search_by_string_list(self, string_list):
        """Return synonyms of the first entry of *string_list* found in ``str2cui``.

        Each entry is tried verbatim and then lowercased. Strings that appear
        verbatim in *string_list* are excluded from the result. ``str2cui``
        keys are cleaned strings (see ``clean``), so cleaned inputs match best.
        Returns None if no entry matches.
        """
        for string in string_list:
            for key in (string, string.lower()):
                if key in self.str2cui:
                    synonyms = self.cui2str[self.str2cui[key]]
                    return [s for s in synonyms if s not in string_list]
        return None

    def search(self, code=None, string_list=None, max_number=-1):
        """Find synonyms by code/CUI, falling back to *string_list*.

        Args:
            code: CUI or source code to look up first.
            string_list: list of strings used when *code* finds nothing.
            max_number: if > 0, truncate the result to this many strings.

        Returns:
            A list of strings, or None if neither lookup matched.
        """
        result = self.search_by_code(code)
        # Previously an early `return None` made this fallback unreachable.
        if result is None and string_list is not None:
            result = self.search_by_string_list(string_list)
        if result is None:
            return None
        if max_number > 0:
            return result[:max_number]
        return result
163
164
165
def build_phrase_maps(phrases):
    """Number *phrases* consecutively from 0 in iteration order.

    Args:
        phrases: iterable of phrases (e.g. a dict, whose keys are iterated).

    Returns:
        ``(idx2phrase, phrase2idx)``: dicts mapping id -> phrase and
        phrase -> id. If a phrase repeats, ``phrase2idx`` keeps its last id.
    """
    idx2phrase = dict(enumerate(phrases))
    phrase2idx = {phrase: idx for idx, phrase in idx2phrase.items()}
    return idx2phrase, phrase2idx


def main():
    """Parse CLI args, load UMLS, and pickle idx2phrase / phrase2idx."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--umls_dir",
        default="../umls",
        type=str,
        help="UMLS dir",
    )
    parser.add_argument(
        "--use_data_dir",
        default="data",
        type=str,
        help="output dir for data like idx2phrase and phrase2idx"
    )
    args = parser.parse_args()

    # Create the output directory before the slow UMLS load. A bad path then
    # fails fast, and a missing directory no longer crashes the final dump.
    os.makedirs(args.use_data_dir, exist_ok=True)

    umls = AllUMLS(
        umls_path=args.umls_dir
    )
    # str2cui keys are the unique cleaned phrases loaded from MRCONSO.
    idx2phrase, phrase2idx = build_phrase_maps(umls.str2cui)

    with open(os.path.join(args.use_data_dir, 'idx2phrase.pkl'), 'wb') as f:
        pickle.dump(idx2phrase, f)
    with open(os.path.join(args.use_data_dir, 'phrase2idx.pkl'), 'wb') as f:
        pickle.dump(phrase2idx, f)


if __name__ == '__main__':
    main()