Diff of /pretrain/load_umls.py [000000] .. [c3444c]

Switch to unified view

a b/pretrain/load_umls.py
1
import os
2
from tqdm import tqdm
3
import re
4
from random import shuffle
5
#import ipdb
6
7
def byLineReader(filename):
8
    with open(filename, "r", encoding="utf-8") as f:
9
        line = f.readline()
10
        while line:
11
            yield line
12
            line = f.readline()
13
    return
14
15
16
class UMLS(object):
17
    def __init__(self, umls_path, source_range=None, lang_range=['ENG'], only_load_dict=False):
18
        self.umls_path = umls_path
19
        self.source_range = source_range
20
        self.lang_range = lang_range
21
        self.detect_type()
22
        self.load()
23
        if not only_load_dict:
24
            self.load_rel()
25
            self.load_sty()
26
27
    def detect_type(self):
28
        if os.path.exists(os.path.join(self.umls_path, "MRCONSO.RRF")):
29
            self.type = "RRF"
30
        else:
31
            self.type = "txt"
32
33
    def load(self):
34
        reader = byLineReader(os.path.join(self.umls_path, "MRCONSO." + self.type))
35
        self.lui_set = set()
36
        self.cui2str = {}
37
        self.str2cui = {}
38
        self.code2cui = {}
39
        #self.lui_status = {}
40
        read_count = 0
41
        for line in tqdm(reader, ascii=True):
42
            if self.type == "txt":
43
                l = [t.replace("\"", "") for t in line.split(",")]
44
            else:
45
                l = line.strip().split("|")
46
            cui = l[0]
47
            lang = l[1]
48
            # lui_status = l[2].lower() # p -> preferred
49
            lui = l[3]
50
            source = l[11]
51
            code = l[13]
52
            string = l[14]
53
54
            if (self.source_range is None or source in self.source_range) and (self.lang_range is None or lang in self.lang_range):
55
                if not lui in self.lui_set:
56
                    read_count += 1
57
                    self.str2cui[string] = cui
58
                    self.str2cui[string.lower()] = cui
59
                    clean_string = self.clean(string)
60
                    self.str2cui[clean_string] = cui
61
62
                    if not cui in self.cui2str:
63
                        self.cui2str[cui] = set()
64
                    self.cui2str[cui].update([clean_string])
65
                    self.code2cui[code] = cui
66
                    self.lui_set.update([lui])
67
68
            # For debug
69
            # if read_count > 1000:
70
            #     break
71
72
        self.cui = list(self.cui2str.keys())
73
        shuffle(self.cui)
74
        self.cui_count = len(self.cui)
75
76
        print("cui count:", self.cui_count)
77
        print("str2cui count:", len(self.str2cui))
78
        print("MRCONSO count:", read_count)
79
80
    def load_rel(self):
81
        reader = byLineReader(os.path.join(self.umls_path, "MRREL." + self.type))
82
        self.rel = set()
83
        for line in tqdm(reader, ascii=True):
84
            if self.type == "txt":
85
                l = [t.replace("\"", "") for t in line.split(",")]
86
            else:
87
                l = line.strip().split("|")
88
            cui0 = l[0]
89
            re = l[3]
90
            cui1 = l[4]
91
            rel = l[7]
92
            if cui0 in self.cui2str and cui1 in self.cui2str:
93
                str_rel = "\t".join([cui0, cui1, re, rel])
94
                if not str_rel in self.rel and cui0 != cui1:
95
                    self.rel.update([str_rel])
96
97
            # For debug
98
            # if len(self.rel) > 1000:
99
            #     break
100
        self.rel = list(self.rel)
101
102
        print("rel count:", len(self.rel))
103
104
    def load_sty(self):
105
        reader = byLineReader(os.path.join(self.umls_path, "MRSTY." + self.type))
106
        self.cui2sty = {}
107
        for line in tqdm(reader, ascii=True):
108
            if self.type == "txt":
109
                l = [t.replace("\"", "") for t in line.split(",")]
110
            else:
111
                l = line.strip().split("|")
112
            cui = l[0]
113
            sty = l[3]
114
            if cui in self.cui2str:
115
                self.cui2sty[cui] = sty
116
117
        print("sty count:", len(self.cui2sty))
118
119
    def clean(self, term, lower=True, clean_NOS=True, clean_bracket=True, clean_dash=True):
120
        term = " " + term + " "
121
        if lower:
122
            term = term.lower()
123
        if clean_NOS:
124
            term = term.replace(" NOS ", " ").replace(" nos ", " ")
125
        if clean_bracket:
126
            term = re.sub(u"\\(.*?\\)", "", term)
127
        if clean_dash:
128
            term = term.replace("-", " ")
129
        term = " ".join([w for w in term.split() if w])
130
        return term
131
132
    def search_by_code(self, code):
133
        if code in self.cui2str:
134
            return list(self.cui2str[code])
135
        if code in self.code2cui:
136
            return list(self.cui2str[self.code2cui[code]])
137
        return None
138
139
    def search_by_string_list(self, string_list):
140
        for string in string_list:
141
            if string in self.str2cui:
142
                find_string = self.cui2str[self.str2cui[string]]
143
                return [string for string in find_string if not string in string_list]
144
            if string.lower() in self.str2cui:
145
                find_string = self.cui2str[self.str2cui[string.lower()]]
146
                return [string for string in find_string if not string in string_list]
147
        return None
148
149
    def search(self, code=None, string_list=None, max_number=-1):
150
        result_by_code = self.search_by_code(code)
151
        if result_by_code is not None:
152
            if max_number > 0:
153
                return result_by_code[0:min(len(result_by_code), max_number)]
154
            return result_by_code
155
        return None
156
        result_by_string = self.search_by_string_list(string_list)
157
        if result_by_string is not None:
158
            if max_number > 0:
159
                return result_by_string[0:min(len(result_by_string), max_number)]
160
            return result_by_string
161
        return None
162
163
164
if __name__ == "__main__":
165
    umls = UMLS("E:\\code\\research\\umls")
166
    # print(umls.search_by_code("282299006"))
167
    #print(umls.search_by_string_list(["Backache", "aching muscles in back"]))
168
    #print(umls.search(code="95891005", max_number=10))
169
    # ipdb.set_trace()
170
171
"""
172
['unable to balance', 'loss of balance']
173
['backache', 'back pain', 'dorsalgi', 'dorsodynia', 'pain over the back', 'back pain [disease/finding]', 'back ache', 'dorsal back pain', 'backach', 'dorsalgia', 'dorsal pain', 'notalgia', 'unspecified back pain', 'backpain', 'backache symptom']
174
['influenza like illness', 'flu-like illness', 'influenza-like illness']
175
"""