|
a |
|
b/coderpp/train/generate_use_data.py |
|
|
1 |
# Simply read umls and generate idx2phrase, phrase2idx |
|
|
2 |
import pickle |
|
|
3 |
import argparse |
|
|
4 |
import os |
|
|
5 |
from tqdm import tqdm |
|
|
6 |
import re |
|
|
7 |
from random import shuffle |
|
|
8 |
#import ipdb |
|
|
9 |
|
|
|
10 |
def byLineReader(filename):
    """Lazily yield the lines of a UTF-8 text file, one at a time.

    Args:
        filename: Path of the file to read.

    Yields:
        Each line of the file in order, with its line terminator kept
        (the final line has none if the file does not end with a newline).
    """
    # Iterating the file object streams line by line, so large files are
    # never held in memory; the ``with`` block closes the handle when the
    # generator is exhausted or garbage-collected.
    with open(filename, "r", encoding="utf-8") as f:
        yield from f
|
|
17 |
|
|
|
18 |
class AllUMLS(object):
    """In-memory lookup tables built from a UMLS release directory.

    On construction the MRCONSO file is read and, unless ``only_load_dict``
    is true, MRREL and MRSTY as well. The following attributes are filled:

    * ``cui2str``  -- dict mapping a CUI to the set of its cleaned strings.
    * ``str2cui``  -- dict mapping a cleaned string to a CUI (last one wins).
    * ``code2cui`` -- dict mapping a source code to a CUI (last one wins).
    * ``lui_set``  -- set of LUIs already indexed (used for de-duplication).
    * ``cui`` / ``cui_count`` -- list of indexed CUIs and its length.
    * ``rel``      -- list of tab-joined relation records (``load_rel``).
    * ``cui2sty``  -- dict mapping a CUI to a semantic-type field (``load_sty``).
    """

    # NOTE: the list default for ``lang_range`` is kept for interface
    # compatibility; it is only ever read (membership tests), never mutated.
    def __init__(self, umls_path, source_range=None, lang_range=['ENG'], only_load_dict=True):
        """Load the UMLS tables found under ``umls_path``.

        Args:
            umls_path: Directory containing MRCONSO (and optionally MRREL,
                MRSTY) as ``.RRF`` or ``.txt`` files.
            source_range: Container of accepted source values (column 11 of
                MRCONSO), or ``None`` to accept every source.
            lang_range: Container of accepted language values (column 1 of
                MRCONSO), or ``None`` to accept every language.
            only_load_dict: When true, only MRCONSO is read; otherwise the
                relation and semantic-type files are loaded too.
        """
        self.umls_path = umls_path
        self.source_range = source_range
        self.lang_range = lang_range
        self.detect_type()
        self.load()
        if not only_load_dict:
            self.load_rel()
            self.load_sty()

    def detect_type(self):
        """Set ``self.type`` to "RRF" if MRCONSO.RRF exists, else to "txt"."""
        if os.path.exists(os.path.join(self.umls_path, "MRCONSO.RRF")):
            self.type = "RRF"
        else:
            self.type = "txt"

    def _split_line(self, line):
        """Split one raw file line into its fields according to ``self.type``.

        "txt" files are split on commas with double quotes removed (the line
        terminator stays attached to the last field); every other type is
        stripped and split on ``|``.
        """
        if self.type == "txt":
            return [t.replace("\"", "") for t in line.split(",")]
        return line.strip().split("|")

    def load(self):
        """Read MRCONSO and build ``cui2str``, ``str2cui`` and ``code2cui``.

        A row is indexed only if it passes the source/language filters and
        its LUI (column 3) has not been indexed before, so at most one row
        per LUI contributes. Summary counts are printed at the end.
        """
        reader = byLineReader(os.path.join(self.umls_path, "MRCONSO." + self.type))
        self.lui_set = set()
        self.cui2str = {}
        self.str2cui = {}
        self.code2cui = {}
        read_count = 0
        for line in tqdm(reader, ascii=True):
            fields = self._split_line(line)
            cui = fields[0]
            lang = fields[1]
            lui = fields[3]
            source = fields[11]
            code = fields[13]
            string = fields[14]

            if self.source_range is not None and source not in self.source_range:
                continue
            if self.lang_range is not None and lang not in self.lang_range:
                continue
            if lui in self.lui_set:
                continue

            read_count += 1
            clean_string = self.clean(string)
            self.str2cui[clean_string] = cui
            self.cui2str.setdefault(cui, set()).add(clean_string)
            self.code2cui[code] = cui
            self.lui_set.add(lui)

        self.cui = list(self.cui2str.keys())
        self.cui_count = len(self.cui)

        print("cui count:", self.cui_count)
        print("str2cui count:", len(self.str2cui))
        print("MRCONSO count:", read_count)

    def load_rel(self):
        """Read MRREL and store the unique relations between indexed CUIs.

        Each kept record is the tab-joined string
        ``cui0 \\t cui1 \\t column3 \\t column7``; self-relations and
        relations touching a CUI absent from ``cui2str`` are dropped.
        ``self.rel`` ends up as a list (in arbitrary set order).
        """
        reader = byLineReader(os.path.join(self.umls_path, "MRREL." + self.type))
        unique_rel = set()
        for line in tqdm(reader, ascii=True):
            fields = self._split_line(line)
            cui0 = fields[0]
            # Presumably the REL / RELA columns of MRREL -- verify against the
            # UMLS column layout. Named so they do not shadow the ``re`` module.
            rel_type = fields[3]
            cui1 = fields[4]
            rel_attr = fields[7]
            if cui0 in self.cui2str and cui1 in self.cui2str and cui0 != cui1:
                unique_rel.add("\t".join([cui0, cui1, rel_type, rel_attr]))

        self.rel = list(unique_rel)

        print("rel count:", len(self.rel))

    def load_sty(self):
        """Read MRSTY and map every indexed CUI to column 3 of its row.

        If a CUI appears on several rows, the last row read wins.
        """
        reader = byLineReader(os.path.join(self.umls_path, "MRSTY." + self.type))
        self.cui2sty = {}
        for line in tqdm(reader, ascii=True):
            fields = self._split_line(line)
            cui = fields[0]
            sty = fields[3]
            if cui in self.cui2str:
                self.cui2sty[cui] = sty

        print("sty count:", len(self.cui2sty))

    def clean(self, term, lower=True, clean_NOS=True, clean_bracket=True, clean_dash=True):
        """Normalise a term string.

        Args:
            term: The raw string.
            lower: Lower-case the term.
            clean_NOS: Remove the stand-alone tokens " NOS " / " nos ".
            clean_bracket: Remove every ``(...)`` group (non-greedy).
            clean_dash: Replace ``-`` with a space.

        Returns:
            The term with the selected clean-ups applied and all runs of
            whitespace collapsed to single spaces, without leading or
            trailing whitespace.
        """
        # Pad with spaces so a leading/trailing "NOS" token is matched too.
        term = " " + term + " "
        if lower:
            term = term.lower()
        if clean_NOS:
            term = term.replace(" NOS ", " ").replace(" nos ", " ")
        if clean_bracket:
            term = re.sub(r"\(.*?\)", "", term)
        if clean_dash:
            term = term.replace("-", " ")
        # str.split() with no argument never yields empty tokens.
        return " ".join(term.split())

    def search_by_code(self, code):
        """Return the strings of ``code``, treated as a CUI or a source code.

        Returns:
            A list of the cleaned strings of the matching concept (a CUI
            match is tried first, then ``code2cui``), or ``None`` if neither
            table contains ``code``.
        """
        if code in self.cui2str:
            return list(self.cui2str[code])
        if code in self.code2cui:
            return list(self.cui2str[self.code2cui[code]])
        return None

    def search_by_string_list(self, string_list):
        """Return synonyms of the first string in ``string_list`` that is known.

        Each string is looked up in ``str2cui`` as given and then lower-cased.
        For the first hit, the strings of that concept which are not
        themselves in ``string_list`` are returned.

        Returns:
            A list of strings (possibly empty), or ``None`` if no string of
            ``string_list`` is found.
        """
        string_list = list(string_list)
        # Build the exclusion set once instead of scanning the list per synonym.
        excluded = set(string_list)
        for string in string_list:
            for key in (string, string.lower()):
                if key in self.str2cui:
                    synonyms = self.cui2str[self.str2cui[key]]
                    return [s for s in synonyms if s not in excluded]
        return None

    def search(self, code=None, string_list=None, max_number=-1):
        """Look a concept up by code first, then by string.

        The string lookup is only attempted when the code lookup finds
        nothing and ``string_list`` is non-empty. (Previously an early
        ``return None`` made the string lookup unreachable, so
        ``string_list`` was silently ignored.)

        Args:
            code: A CUI or source code, or ``None``.
            string_list: Candidate strings, or ``None``.
            max_number: If positive, the maximum number of strings returned.

        Returns:
            A list of strings, or ``None`` when nothing matches.
        """
        result = self.search_by_code(code)
        if result is None and string_list:
            result = self.search_by_string_list(string_list)
        if result is None:
            return None
        if max_number > 0:
            return result[:max_number]
        return result
|
|
163 |
|
|
|
164 |
|
|
|
165 |
def build_phrase_index(phrase_list):
    """Build the two index mappings for a sequence of phrases.

    Args:
        phrase_list: Sequence of phrases; a phrase's position is its index.

    Returns:
        A tuple ``(idx2phrase, phrase2idx)`` of dicts. If a phrase occurs
        more than once, ``phrase2idx`` keeps its last position.
    """
    idx2phrase = dict(enumerate(phrase_list))
    phrase2idx = {phrase: idx for idx, phrase in idx2phrase.items()}
    return idx2phrase, phrase2idx


def main():
    """Read UMLS and pickle ``idx2phrase`` / ``phrase2idx`` into the output dir."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--umls_dir",
        default="../umls",
        type=str,
        help="UMLS dir",
    )
    parser.add_argument(
        "--use_data_dir",
        default="data",
        type=str,
        help="output dir for data like idx2phrase and phrase2idx"
    )
    args = parser.parse_args()

    umls = AllUMLS(
        umls_path=args.umls_dir
    )
    phrase_list = list(umls.str2cui.keys())

    idx2phrase, phrase2idx = build_phrase_index(phrase_list)
    print(type(idx2phrase))

    # Create the output directory up front so the dumps below do not fail
    # with FileNotFoundError; an empty path means the current directory.
    if args.use_data_dir:
        os.makedirs(args.use_data_dir, exist_ok=True)
    with open(os.path.join(args.use_data_dir, 'idx2phrase.pkl'), 'wb') as f:
        pickle.dump(idx2phrase, f)
    with open(os.path.join(args.use_data_dir, 'phrase2idx.pkl'), 'wb') as f:
        pickle.dump(phrase2idx, f)


if __name__ == '__main__':
    main()