a b/coderpp/test/load_umls.py
1
import os
2
from tqdm import tqdm
3
import re
4
from random import shuffle
5
import pickle
6
import ahocorasick
7
#import ipdb
8
9
def byLineReader(filename):
    """Lazily yield the lines of a UTF-8 encoded text file.

    Args:
        filename: Path of the file to read.

    Yields:
        Each line of the file in order, including its trailing newline
        character when one is present. Nothing is yielded for an empty file.

    The file is opened when the first line is requested and closed when the
    generator is exhausted (or garbage-collected / closed early).
    """
    with open(filename, "r", encoding="utf-8") as f:
        # Iterating the file object streams it line by line, exactly like
        # a readline() loop that stops at the empty string.
        yield from f
16
17
18
class UMLS(object):
    """MRCONSO concept table restricted to a pickled phrase vocabulary.

    On construction the two pickled vocab mappings are loaded, the MRCONSO
    file format is detected and the file is read in full. Only rows whose
    cleaned string is present in ``phrase2idx`` (and longer than three
    characters) are kept.

    Attributes populated by ``load``:
        lui_set: set of LUIs (column 3) of the rows that were kept.
        cui2str: dict mapping a CUI to the set of vocab indices of its strings.
        str2cui: dict mapping a string (raw, lower-cased and cleaned variants)
            to the set of CUIs it was seen with.
        code2cui: dict mapping a source code (column 13) to the CUI of the
            last kept row carrying that code.
        stridx_list: list of all vocab indices that were kept.
        cui: list of kept CUIs, in shuffled order.
        cui_count: number of kept CUIs.
    """

    # Matches a parenthesised span (non-greedy); compiled once because
    # ``clean`` is called for every MRCONSO row that passes the filters.
    _BRACKET_RE = re.compile(r"\(.*?\)")

    def __init__(self, umls_path, phrase2idx_path, idx2phrase_path, source_range=None, lang_range=['ENG'], only_load_dict=False):
        """Load the vocab pickles and read MRCONSO from ``umls_path``.

        Args:
            umls_path: Directory containing ``MRCONSO.RRF`` or ``MRCONSO.txt``.
            phrase2idx_path: Pickle file with the phrase -> index mapping of
                the NER vocab; used to exclude phrases that are in MRCONSO
                but not in that vocab.
            idx2phrase_path: Pickle file with the reverse mapping.
            source_range: Container of accepted source values (column 11),
                or None to accept every source.
            lang_range: Container of accepted language values (column 1),
                or None to accept every language. The default list is only
                read, never mutated, by this class.
            only_load_dict: Accepted for interface compatibility; it is not
                used anywhere in this class.
        """
        self.umls_path = umls_path
        self.source_range = source_range
        self.lang_range = lang_range
        self.phrase2idx = self._load_pickle(phrase2idx_path)
        self.idx2phrase = self._load_pickle(idx2phrase_path)
        self.detect_type()
        self.load()

    def _load_pickle(self, path):
        """Return the object stored in the pickle file at ``path``.

        SECURITY: unpickling executes arbitrary code embedded in the file,
        so ``path`` must only ever point at trusted, locally produced data.
        """
        with open(path, 'rb') as f:
            return pickle.load(f)

    def transform(self, phrase):
        """Map a cleaned phrase to its vocab index.

        Returns:
            ``self.phrase2idx[phrase]`` when the phrase is in the vocab and
            is longer than three characters, otherwise None.
        """
        if len(phrase) > 3:
            return self.phrase2idx.get(phrase)
        return None

    def detect_type(self):
        """Set ``self.type`` to "RRF" if ``MRCONSO.RRF`` exists, else "txt"."""
        if os.path.exists(os.path.join(self.umls_path, "MRCONSO.RRF")):
            self.type = "RRF"
        else:
            self.type = "txt"

    def load(self):
        """Read MRCONSO and build the lookup tables described on the class.

        A row is kept when its source and language pass ``source_range`` /
        ``lang_range``, its LUI has not been kept before, and its cleaned
        string maps to a vocab index via ``transform``. Blank lines are
        skipped; a non-blank row with fewer than 15 fields still raises
        IndexError so that a corrupt file fails loudly.
        """
        reader = byLineReader(os.path.join(self.umls_path, "MRCONSO." + self.type))
        self.lui_set = set()
        self.cui2str = {}
        self.str2cui = {}
        self.code2cui = {}
        self.stridx_list = set()
        read_count = 0
        for line in tqdm(reader, ascii=True):
            if not line.strip():
                # A blank line carries no record and would otherwise crash
                # on the column lookups below.
                continue
            if self.type == "txt":
                # NOTE(review): plain comma split; a quoted field that itself
                # contains a comma would shift the columns - confirm the txt
                # export never has such fields.
                fields = [t.replace("\"", "") for t in line.split(",")]
            else:
                fields = line.strip().split("|")
            cui = fields[0]
            lang = fields[1]
            lui = fields[3]
            source = fields[11]
            code = fields[13]
            string = fields[14]

            if self.source_range is not None and source not in self.source_range:
                continue
            if self.lang_range is not None and lang not in self.lang_range:
                continue
            if lui in self.lui_set:
                # Keep only the first accepted row per LUI.
                continue

            clean_string = self.clean(string)
            idx = self.transform(clean_string)
            if idx is None:
                continue
            read_count += 1

            # Index the CUI under the raw, lower-cased and cleaned spellings.
            for key in (string, string.lower(), clean_string):
                self.str2cui.setdefault(key, set()).add(cui)

            self.cui2str.setdefault(cui, set()).add(idx)
            self.stridx_list.add(idx)
            self.code2cui[code] = cui
            self.lui_set.add(lui)

        self.cui = list(self.cui2str.keys())
        shuffle(self.cui)
        self.cui_count = len(self.cui)
        self.stridx_list = list(self.stridx_list)

        print("cui count:", self.cui_count)
        print("str2cui count:", len(self.str2cui))
        print("MRCONSO count:", read_count)
        print("str count:", len(self.stridx_list))

    def clean(self, term, lower=True, clean_NOS=True, clean_bracket=True, clean_dash=True):
        """Normalise a term string.

        Args:
            term: The string to normalise.
            lower: Lower-case the term.
            clean_NOS: Remove the stand-alone token "NOS" / "nos".
            clean_bracket: Remove every parenthesised span, brackets included.
            clean_dash: Replace each "-" with a space.

        Returns:
            The normalised term with runs of whitespace collapsed to single
            spaces and no leading or trailing whitespace.
        """
        # Pad with spaces so a leading/trailing "NOS" is matched as a token.
        term = " " + term + " "
        if lower:
            term = term.lower()
        if clean_NOS:
            term = term.replace(" NOS ", " ").replace(" nos ", " ")
        if clean_bracket:
            term = self._BRACKET_RE.sub("", term)
        if clean_dash:
            term = term.replace("-", " ")
        # split() without arguments never yields empty tokens.
        return " ".join(term.split())
120