--- a
+++ b/pretrain/load_umls.py
@@ -0,0 +1,175 @@
+import os
+from tqdm import tqdm
+import re
+from random import shuffle
+#import ipdb
+
+def byLineReader(filename):
+    """Lazily yield the lines of the UTF-8 text file ``filename``, in order.
+
+    Each yielded string keeps its line terminator, exactly as read.
+    """
+    with open(filename, "r", encoding="utf-8") as handle:
+        yield from handle
+
+
+class UMLS(object):
+    """In-memory index over the UMLS MRCONSO, MRREL and MRSTY tables.
+
+    ``load`` builds ``cui2str`` (CUI -> set of cleaned strings), ``str2cui``
+    (raw, lowercased and cleaned string -> CUI) and ``code2cui`` (field 13 ->
+    CUI); ``load_rel`` builds ``rel`` and ``load_sty`` builds ``cui2sty``.
+    """
+
+    def __init__(self, umls_path, source_range=None, lang_range=("ENG",), only_load_dict=False):
+        """Load the tables stored under ``umls_path``.
+
+        ``source_range`` / ``lang_range`` hold the accepted values of MRCONSO
+        fields 11 / 1 (None accepts everything).  ``only_load_dict`` skips the
+        MRREL and MRSTY tables.
+        """
+        self.umls_path = umls_path
+        self.source_range = source_range
+        self.lang_range = lang_range
+        self.detect_type()
+        self.load()
+        if not only_load_dict:
+            self.load_rel()
+            self.load_sty()
+
+    def detect_type(self):
+        """Set ``self.type`` to "RRF" when MRCONSO.RRF exists, else "txt"."""
+        has_rrf = os.path.exists(os.path.join(self.umls_path, "MRCONSO.RRF"))
+        self.type = "RRF" if has_rrf else "txt"
+
+    def _rows(self, table):
+        """Yield the field list of every line of ``<table>.<self.type>``."""
+        path = os.path.join(self.umls_path, table + "." + self.type)
+        for line in tqdm(byLineReader(path), ascii=True):
+            if self.type == "txt":
+                yield [t.replace("\"", "") for t in line.split(",")]
+            else:
+                yield line.strip().split("|")
+
+    def load(self):
+        """Index MRCONSO into ``cui2str``, ``str2cui`` and ``code2cui``.
+
+        Rows must pass the source and language filters, and only the first
+        accepted row of each LUI (field 3) is indexed.
+        """
+        self.lui_set = set()
+        self.cui2str = {}
+        self.str2cui = {}
+        self.code2cui = {}
+        read_count = 0
+        for fields in self._rows("MRCONSO"):
+            cui, lang, lui = fields[0], fields[1], fields[3]
+            source, code, string = fields[11], fields[13], fields[14]
+            if self.source_range is not None and source not in self.source_range:
+                continue
+            if self.lang_range is not None and lang not in self.lang_range:
+                continue
+            if lui in self.lui_set:
+                continue
+            read_count += 1
+            clean_string = self.clean(string)
+            # Make the term findable by its raw, lowercased and cleaned forms.
+            self.str2cui[string] = cui
+            self.str2cui[string.lower()] = cui
+            self.str2cui[clean_string] = cui
+            self.cui2str.setdefault(cui, set()).add(clean_string)
+            self.code2cui[code] = cui
+            self.lui_set.add(lui)
+
+        self.cui = list(self.cui2str)
+        shuffle(self.cui)
+        self.cui_count = len(self.cui)
+        print("cui count:", self.cui_count)
+        print("str2cui count:", len(self.str2cui))
+        print("MRCONSO count:", read_count)
+
+    def load_rel(self):
+        """Load MRREL into ``self.rel``, a list of unique tab-joined strings.
+
+        Each entry is "cui0<TAB>cui1<TAB>field 3<TAB>field 7"; a row is kept
+        only when both CUIs are in ``cui2str`` and differ from each other.
+        """
+        self.rel = set()
+        for fields in self._rows("MRREL"):
+            cui0, rel_type, cui1, rel_attr = fields[0], fields[3], fields[4], fields[7]
+            if cui0 != cui1 and cui0 in self.cui2str and cui1 in self.cui2str:
+                self.rel.add("\t".join([cui0, cui1, rel_type, rel_attr]))
+        self.rel = list(self.rel)
+        print("rel count:", len(self.rel))
+
+    def load_sty(self):
+        """Map each indexed CUI to field 3 of its last MRSTY row (``cui2sty``)."""
+        self.cui2sty = {}
+        for fields in self._rows("MRSTY"):
+            cui, sty = fields[0], fields[3]
+            if cui in self.cui2str:
+                self.cui2sty[cui] = sty
+        print("sty count:", len(self.cui2sty))
+
+    def clean(self, term, lower=True, clean_NOS=True, clean_bracket=True, clean_dash=True):
+        """Return ``term`` normalised, with runs of whitespace collapsed.
+
+        Optional steps, in order: lowercase, drop the standalone word
+        "NOS"/"nos", remove "(...)" groups, replace "-" with a space.
+        """
+        # Pad so that a leading or trailing "NOS" is space-delimited too.
+        term = " " + term + " "
+        if lower:
+            term = term.lower()
+        if clean_NOS:
+            term = term.replace(" NOS ", " ").replace(" nos ", " ")
+        if clean_bracket:
+            term = re.sub(r"\(.*?\)", "", term)
+        if clean_dash:
+            term = term.replace("-", " ")
+        return " ".join(term.split())
+
+    def search_by_code(self, code):
+        """Return the strings of ``code`` (a CUI or a known code), else None."""
+        cui = code if code in self.cui2str else self.code2cui.get(code)
+        return None if cui is None else list(self.cui2str[cui])
+
+    def search_by_string_list(self, string_list):
+        """Return the synonyms of the first indexed string, else None.
+
+        Strings are tried as given, then lowercased; inputs are excluded.
+        """
+        for string in string_list:
+            for key in (string, string.lower()):
+                if key in self.str2cui:
+                    synonyms = self.cui2str[self.str2cui[key]]
+                    return [s for s in synonyms if s not in string_list]
+        return None
+
+    def search(self, code=None, string_list=None, max_number=-1):
+        """Look up synonyms by ``code`` first, then by ``string_list``.
+
+        ``string_list`` is consulted only when the code lookup finds nothing;
+        None skips that fallback.  A positive ``max_number`` caps the length of
+        the result.  Returns a list of strings, or None when nothing matches.
+        """
+        result = self.search_by_code(code)
+        if result is None and string_list is not None:
+            result = self.search_by_string_list(string_list)
+        if result is not None and max_number > 0:
+            result = result[:max_number]
+        return result
+
+
+if __name__ == "__main__":
+    import sys  # local import: only this demo entry point reads argv
+    # UMLS directory: first CLI argument, else the original hard-coded default.
+    umls = UMLS(sys.argv[1] if len(sys.argv) > 1 else "E:\\code\\research\\umls")
+    # Sample queries: umls.search_by_code("282299006"); umls.search_by_string_list(
+    # ["Backache", "aching muscles in back"]); umls.search(code="95891005", max_number=10)
+
+"""
+['unable to balance', 'loss of balance']
+['backache', 'back pain', 'dorsalgi', 'dorsodynia', 'pain over the back', 'back pain [disease/finding]', 'back ache', 'dorsal back pain', 'backach', 'dorsalgia', 'dorsal pain', 'notalgia', 'unspecified back pain', 'backpain', 'backache symptom']
+['influenza like illness', 'flu-like illness', 'influenza-like illness']
+"""