a b/coderpp/train/load_umls.py
1
import os
2
from tqdm import tqdm
3
import re
4
from random import shuffle
5
import pickle
6
import ahocorasick
7
#import ipdb
8
9
def byLineReader(filename):
    """Yield the lines of a UTF-8 text file one at a time.

    Lines are yielded exactly as read, including their trailing newline.
    The file is closed once the generator is exhausted (or closed).

    Args:
        filename: path of the file to read.

    Yields:
        str: each line of the file, in order.
    """
    with open(filename, "r", encoding="utf-8") as f:
        # Iterating the file object is buffered and stops at EOF, so no manual
        # readline()/empty-string sentinel loop is needed.
        yield from f
16
17
18
class UMLS(object):
    """Index of UMLS concept strings loaded from an MRCONSO file.

    On construction the phrase vocabularies are unpickled, the MRCONSO file
    format is detected (``.RRF`` or ``.txt``), and :meth:`load` builds the
    lookup tables.

    Attributes set by :meth:`load`:
        cui2stridx: CUI -> set of string indices (values of ``phrase2idx``).
        str2cui: raw, lower-cased and cleaned strings -> CUI; later rows
            overwrite earlier ones.
        code2cui: source code -> CUI; later rows overwrite earlier ones.
        lui_set: LUIs already kept (each LUI contributes at most one row).
        cui: list of CUIs in shuffled (random) order; ``cui_count`` is its
            length.
        stridx_list: list of the distinct string indices that were kept.
    """

    # Column positions read from each MRCONSO row.
    _CUI_COL = 0
    _LANG_COL = 1
    _LUI_COL = 3
    _SOURCE_COL = 11
    _CODE_COL = 13
    _STRING_COL = 14

    # Parenthesised qualifiers such as "(disorder)"; compiled once for clean().
    _BRACKET_RE = re.compile(r"\(.*?\)")

    def __init__(self, umls_path, phrase2idx_path='data/string2idx.pkl',
                 idx2phrase_path='data/idx2string.pkl', source_range=None,
                 lang_range=('ENG',), only_load_dict=False):
        """Load the phrase vocabularies and index the MRCONSO file.

        Args:
            umls_path: directory containing ``MRCONSO.RRF`` or ``MRCONSO.txt``.
            phrase2idx_path: pickle mapping a cleaned phrase to its string
                index (looked up by :meth:`transform`).
            idx2phrase_path: pickle stored as ``self.idx2phrase``; presumably
                the inverse mapping (index -> phrase) — not used by load().
            source_range: container of source vocabularies to keep; ``None``
                keeps every source.
            lang_range: container of language codes to keep; ``None`` keeps
                every language. The default is an immutable tuple so that
                instances never share one mutable default list.
            only_load_dict: accepted for backward compatibility; currently
                unused.
        """
        self.umls_path = umls_path
        self.source_range = source_range
        self.lang_range = lang_range
        self.phrase2idx = self._load_pickle(phrase2idx_path)
        self.idx2phrase = self._load_pickle(idx2phrase_path)
        self.detect_type()
        self.load()

    def _load_pickle(self, path):
        """Unpickle and return the object stored at ``path``."""
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as f:
            return pickle.load(f)

    def transform(self, phrase):
        """Return the string index of ``phrase``, or ``None`` if it is unknown."""
        return self.phrase2idx.get(phrase)

    def detect_type(self):
        """Set ``self.type`` to "RRF" if MRCONSO.RRF exists, otherwise "txt"."""
        if os.path.exists(os.path.join(self.umls_path, "MRCONSO.RRF")):
            self.type = "RRF"
        else:
            self.type = "txt"

    @staticmethod
    def _iter_rows(lines, file_type):
        """Yield the column values of every non-blank MRCONSO line.

        Args:
            lines: iterable of text lines.
            file_type: "txt" for comma-separated (optionally quoted) rows;
                anything else is treated as pipe-separated RRF.
        """
        if file_type == "txt":
            # The csv parser keeps quoted values that contain commas (common in
            # concept names) in one column and removes the surrounding quotes;
            # a plain split(",") would shift every following column.
            import csv
            rows = csv.reader(lines)
        else:
            rows = (line.strip().split("|") for line in lines)
        for fields in rows:
            # Blank lines (e.g. a trailing newline) carry no row.
            if any(fields):
                yield fields

    def load(self):
        """Read MRCONSO and build the concept/string lookup tables.

        A row is kept when its source and language pass the configured
        filters, its LUI has not been kept before, and its cleaned string has
        an index in ``phrase2idx``. Prints summary counts when done.
        """
        path = os.path.join(self.umls_path, "MRCONSO." + self.type)
        self.lui_set = set()
        self.cui2stridx = {}
        self.str2cui = {}
        self.code2cui = {}
        stridx_set = set()
        read_count = 0

        rows = self._iter_rows(byLineReader(path), self.type)
        for fields in tqdm(rows, ascii=True):
            cui = fields[self._CUI_COL]
            lang = fields[self._LANG_COL]
            lui = fields[self._LUI_COL]
            source = fields[self._SOURCE_COL]
            code = fields[self._CODE_COL]
            string = fields[self._STRING_COL]

            if self.source_range is not None and source not in self.source_range:
                continue
            if self.lang_range is not None and lang not in self.lang_range:
                continue
            # Keep at most one string per LUI: the first one with a known index.
            if lui in self.lui_set:
                continue

            clean_string = self.clean(string)
            idx = self.transform(clean_string)
            if idx is None:
                # Strings outside the phrase vocabulary are ignored entirely.
                continue

            read_count += 1
            self.str2cui[string] = cui
            self.str2cui[string.lower()] = cui
            self.str2cui[clean_string] = cui
            self.cui2stridx.setdefault(cui, set()).add(idx)
            stridx_set.add(idx)
            self.code2cui[code] = cui
            self.lui_set.add(lui)

        self.cui = list(self.cui2stridx)
        shuffle(self.cui)
        self.cui_count = len(self.cui)
        self.stridx_list = list(stridx_set)

        print("cui count:", self.cui_count)
        print("str2cui count:", len(self.str2cui))
        print("MRCONSO count:", read_count)
        print("str count:", len(self.stridx_list))

    def clean(self, term, lower=True, clean_NOS=True, clean_bracket=True, clean_dash=True):
        """Normalise a concept string for vocabulary lookup.

        Args:
            term: the raw string.
            lower: lower-case the string.
            clean_NOS: drop the stand-alone word "NOS"/"nos".
            clean_bracket: drop parenthesised parts such as "(disorder)".
            clean_dash: replace "-" with a space.

        Returns:
            The normalised string with runs of whitespace collapsed to single
            spaces and no leading/trailing whitespace.
        """
        # Pad with spaces so " NOS " also matches at the start or end.
        term = " " + term + " "
        if lower:
            term = term.lower()
        if clean_NOS:
            term = term.replace(" NOS ", " ").replace(" nos ", " ")
        if clean_bracket:
            term = self._BRACKET_RE.sub("", term)
        if clean_dash:
            term = term.replace("-", " ")
        return " ".join(term.split())