# ehrkit/ehrkit.py
1
from datetime import date
2
import pymysql
3
#from sshtunnel import SSHTunnelForwarder
4
from ehrkit.classes import Patient, Disease, Diagnosis, Prescription, Procedure
5
from ehrkit.solr_lib import *
6
from datetime import datetime
7
from nltk.tokenize import sent_tokenize, word_tokenize
8
from nltk.corpus import stopwords
9
from gensim import corpora, models, similarities
10
from collections import defaultdict
11
import re
12
import sys
13
import os
14
import pprint
15
import string
16
import torch
17
import requests
18
from sklearn.feature_extraction.text import TfidfVectorizer
19
20
dir_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
21
sys.path.append(dir_path)
22
from scripts.train_word2vec import train_word2vec
23
from scripts.abb_extraction import output_abb
24
25
26
# TODO: adding external library
27
import torch
28
from transformers import AutoTokenizer, AutoModelWithLMHead
29
30
class ehr_db:
    """Connection object to Tangra MySQL Server.

    Attributes:
        cnx: pymysql connection object
        cur: pymysql cursor object
        patients: dict mapping SUBJECT_ID -> Patient, filled by get_patients()
        note_event_flag: True once get_note_events() has run
    """

    def __init__(self, sess):
        # sess is the session dict built by start_session():
        # {'cnx': pymysql connection, 'cur': pymysql cursor}
        self.cnx = sess['cnx']
        self.cur = sess['cur']
        self.patients = {}
        self.note_event_flag = False
43
44
45
    def get_patients(self, n):
        """Retrieves n patient rows from the database and stores them in
        self.patients as Patient objects keyed by SUBJECT_ID.

        Note:
            Patients are sorted by ROW_ID in the database.
        Note:
            If n == -1, returns all patients.
        Args:
            n (int): Number of patient objects to load.
        Returns:
            none
        """
        if n == -1:
            self.cur.execute("SELECT SUBJECT_ID, GENDER, DOB, DOD FROM mimic.PATIENTS")
        else:
            # int() guards the %d interpolation against non-integer input.
            self.cur.execute("SELECT SUBJECT_ID, GENDER, DOB, DOD FROM mimic.PATIENTS LIMIT %d" % int(n))
        raw = self.cur.fetchall()

        for p in raw:
            data = {}
            data["id"] = p[0]
            data["sex"] = p[1]
            data["dob"] = p[2]
            data["dod"] = p[3]

            # DOB/DOD may arrive as datetime objects or as strings like
            # "YYYY-MM-DD hh:mm:ss"; keep only the date portion of strings.
            # (%Y is the 4-digit year; the old comment claiming it keeps only
            # two digits confused %Y with %y.)
            if isinstance(data["dob"], str):
                data["dob"] = datetime.strptime(data["dob"][0:10], "%Y-%m-%d")
            if isinstance(data["dod"], str):
                data["dod"] = datetime.strptime(data["dod"][0:10], "%Y-%m-%d")

            # A missing date of death means the patient is alive.
            data["alive"] = data["dod"] is None

            self.patients[data["id"]] = Patient(data)
82
83
    def count_patients(self):
84
        '''Counts and returns the number of patients as an int in the database.'''
85
86
        self.cur.execute("SELECT COUNT(*) FROM mimic.PATIENTS")
87
        raw = self.cur.fetchall()
88
        return int(raw[0][0])
89
90
    def count_docs(self, query, getAll = False, inverted = False):
91
        '''
92
        returns document count of tables
93
        query is a list of table names
94
        setting getAll to true returns count of all rows in all tables
95
        setting inverted = False returns count of rows in tables specified in *args
96
        setting inverted = True returns count of rows in all tables except those specified in *args
97
        '''
98
        table_count = self.cur.execute("SELECT TABLE_NAME, TABLE_ROWS from information_schema.tables where TABLE_SCHEMA = 'mimic' ")
99
        numtup = self.cur.fetchall()
100
        #numtup(nested tuple) structure: ((TABLE_NAME(str), TABLE_ROWS(int)),...)
101
        count = 0
102
        if getAll:
103
            for i in range(table_count):
104
                count = count + numtup[i][1]
105
            return count
106
        if inverted:
107
            for i in range(table_count):
108
                if numtup[i][0] in query:
109
                    continue
110
                count = count+numtup[i][1]
111
            return count
112
        for i in range(table_count):
113
            if numtup[i][0] in query:
114
                count = count+numtup[i][1]
115
        return count
116
117
118
    #is this redundant?
119
120
    def get_note_events(self):
        """
        Adds note_events to the Patient objects in self.patients.

        Depends on get_patients() having been called first to populate
        self.patients; sets self.note_event_flag when done.

        return: None
        """
        #TODO: Currently only adds one NoteEvent
        for patient in self.patients.values():
            # Only fetch notes for patients that have none loaded yet.
            if patient.note_events is None:
                self.cur.execute("select ROW_ID, TEXT from mimic.NOTEEVENTS where SUBJECT_ID = %d" %patient.id)
                rawt = self.cur.fetchall()
                ls = []
                for p in rawt:
                    # Each entry is (ROW_ID, [sentences...]).
                    sent_list = sent_tokenize(p[1])
                    ls.append((p[0],sent_list))
                    # NOTE(review): addNE is called once per row with the still
                    # growing list — if addNE appends rather than replaces, this
                    # duplicates earlier notes; confirm whether this call was
                    # meant to sit outside the loop.
                    patient.addNE(ls)
        self.note_event_flag = True
137
138
    def longest_NE(self):
139
        '''
140
        returns the longest note event in the patient dict
141
        '''
142
        #TODO: Currently only considers one NoteEvent per patient
143
        maxpid, maxlen = None, 0
144
        for patient in self.patients.values():
145
            for doc in patient.note_events:
146
                pid = patient.id
147
                rowid = doc[0]
148
                leng = len(doc[1])
149
                if leng>maxlen:
150
                    maxlen = leng
151
                    maxpid = pid
152
                    maxrowid = rowid
153
        return maxpid, maxrowid, maxlen
154
155
    def get_document(self, id):
156
        """Returns the text of a specific patient record given the ID (row ID in NOTEEVENTS).
157
        """
158
        text = ""
159
        self.cur.execute("select TEXT from mimic.NOTEEVENTS where ROW_ID = %d" % id)
160
        text = self.cur.fetchall()
161
        return text[0][0]
162
163
    def get_all_patient_document_ids(self, patientID):
        """Return the ROW_IDs of every NOTEEVENTS row belonging to patientID."""
        self.cur.execute("select ROW_ID from mimic.NOTEEVENTS where SUBJECT_ID = %d" % patientID)
        return flatten(self.cur.fetchall())
171
172
    def list_all_patient_ids(self):
        """Return every SUBJECT_ID in mimic.PATIENTS as a flat list."""
        self.cur.execute("select SUBJECT_ID from mimic.PATIENTS")
        return flatten(self.cur.fetchall())
179
180
    def list_all_document_ids(self):
        """Return every ROW_ID in mimic.NOTEEVENTS as a flat list."""
        self.cur.execute("select ROW_ID from mimic.NOTEEVENTS")
        return flatten(self.cur.fetchall())
188
189
    def get_document_sents(self, docID):
        """Return the sentences of note docID as a list (nltk sent_tokenize)."""
        self.cur.execute("select TEXT from mimic.NOTEEVENTS where ROW_ID = %d" % docID)
        rows = self.cur.fetchall()
        sentences = sent_tokenize(rows[0][0])
        if not sentences:
            print("No document text found.")
        return sentences
199
200
    def get_abbreviations(self, doc_id):
        ''' Returns a list of the unique abbreviation-like tokens (two leading
            uppercase letters) found in a document.
        '''
        abbreviations = set()
        for sentence in self.get_document_sents(doc_id):
            for token in word_tokenize(sentence):
                if re.match(r'[A-Z]{2}', token):
                    abbreviations.add(token)
        return list(abbreviations)
212
213
    def get_abbreviation_sent_ids(self, doc_id):
        ''' Returns a list of the abbreviations in a document along with the sentence ID they appear in
            in the format [(abbreviation, sent_id), ...]
        '''
        sent_list = self.get_document_sents(doc_id)
        abb_list = []
        # enumerate replaces the original zip(range(0, len(...)), ...) idiom.
        for i, sent in enumerate(sent_list):
            for word in word_tokenize(sent):
                # Two leading uppercase letters mark an abbreviation candidate.
                if re.match(r'[A-Z]{2}', word):
                    abb_list.append((word, i))

        return abb_list
227
228
229
    def get_documents_d(self, date):
        """Returns a list of all document IDs recorded on date. Format of YYYY-MM-DD for date.
        """
        # Parameterized query: lets the driver quote `date` and avoids SQL
        # injection via string interpolation.
        self.cur.execute("select ROW_ID from mimic.NOTEEVENTS where CHARTDATE = %s", (date,))
        ids = self.cur.fetchall()
        if not ids:
            print("No values returned. Note that date must be formatted YYYY-MM-DD.")
        return flatten(ids)
238
239
    def get_documents_q(self, query, n = -1):
        """Returns a list of all document IDs whose TEXT contains `query`,
            e.g. "Service: SURGERY".

            When n == -1, searches all documents; otherwise limits to n rows.
        """
        query = "%" + query + "%"
        # Parameterized queries avoid SQL injection and broken quoting when
        # the search text itself contains quotes or % signs.
        if n == -1:
            self.cur.execute("select ROW_ID from mimic.NOTEEVENTS where TEXT like %s", (query,))
        else:
            self.cur.execute("select ROW_ID from mimic.NOTEEVENTS where TEXT like %s limit %s", (query, int(n)))
        ids = self.cur.fetchall()
        if not ids:
            print("No values returned. Note that the query must be formatted such as Service: Surgery")
        return flatten(ids)
253
254
    def get_documents_icd9_alt(self, query):
        '''
        returns: (ROW_ID, tree node, description) tuples from DIAGNOSES_ICD
        for ICD-9 codes matching query
        dependancy: does not depend on calling get_patients
        '''
        # NOTE(review): `tree` is not defined or imported in this module, so
        # this method raises NameError as written — confirm which ICD-9 tree
        # library was intended.
        query = "%" + str(query) + "%"
        # Parameterized LIKE avoids SQL injection via the code fragment.
        self.cur.execute("select ROW_ID, ICD9_CODE from mimic.DIAGNOSES_ICD where ICD9_CODE like %s", (query,))
        raws = self.cur.fetchall()
        docs = []
        for raw in raws:
            print(raw)#debug
            code = raw[1]
            # Fix: the original tested `!= 'V' or != 'E'`, which is always
            # true. V/E codes take the decimal point after two characters,
            # all other codes after three.
            if code[0] != 'V' and code[0] != 'E':
                modified = code[0:3] + '.' + code[3:]
            else:
                modified = code[0:2] + '.' + code[2:]
            print(modified)#debug
            rt = tree.find(modified).parent
            description = rt.description
            docs.append((raw[0], rt, description))

        if not docs:
            print("No values returned.")
        return docs
279
280
    def get_documents_icd9(self, code):
        '''
        returns: {code: (short titles, document ROW_IDs)} for the exact ICD-9
        code, or None when no documents match
        dependancy: does not depend on calling get_patients
        '''
        code = str(code)
        # Parameterized queries avoid SQL injection via the code string.
        self.cur.execute("select ROW_ID from mimic.DIAGNOSES_ICD where ICD9_CODE = %s", (code,))
        ids = self.cur.fetchall()
        if not ids:
            print("No values returned.")
            return None
        self.cur.execute("select SHORT_TITLE from mimic.D_ICD_DIAGNOSES where ICD9_CODE = %s", (code,))

        return {code: (flatten(self.cur.fetchall()), flatten(ids))}
296
297
    def get_prescription(self):
        """Populate each loaded patient's prescription list from mimic.PRESCRIPTIONS.

        TODO: NEEDS TO BE FIXED. CURRENTLY HAS IDs HARDCODED IN.
        """
        for patient in self.patients.values():
            # NOTE(review): the WHERE clause is hard-coded to two ROW_IDs, so
            # every patient receives the same two drugs — presumably it should
            # filter on patient.id instead; confirm intended query.
            self.cur.execute("select DRUG from mimic.PRESCRIPTIONS where ROW_ID = 2968759 or ROW_ID = 2968760")
            drugtuple = self.cur.fetchall()
            druglist = []
            for drug in drugtuple:
                druglist.append(drug[0])
            patient.addPrescriptions(druglist)
307
308
    def count_all_prescriptions(self):
309
        """ Returns a dictionary with each medicine in PRESCRIPTIONS as keys
310
            and how many times it has been prescribed as values. Takes a long time to run.
311
        """
312
        meds_dict = {}
313
        self.cur.execute("select DRUG from mimic.PRESCRIPTIONS")
314
        raw = self.cur.fetchall()
315
        meds_list = flatten(raw)
316
        for med in meds_list:
317
            if med in meds_dict:
318
                meds_dict[med] += 1
319
            else:
320
                meds_dict[med] = 1
321
322
        return meds_dict
323
324
    def get_diagnoses(self):
        """Adds diagnoses (converted from ICD-9 code) from DIAGNOSES_ICD to patient.diagnoses for each patient in patients dictionary.
        """
        diags = {}  # ICD-9 code -> fetched LONG_TITLE rows, cached across patients
        for patient in self.patients.values():
            self.cur.execute("select ICD9_CODE from mimic.DIAGNOSES_ICD where SUBJECT_ID = %d" % patient.id)
            for row in self.cur.fetchall():
                # Fix: fetchall yields 1-tuples; the original used the whole
                # tuple as the cache key and as the SQL substitution value.
                code = row[0]
                if code not in diags:
                    self.cur.execute("select LONG_TITLE from mimic.D_ICD_DIAGNOSES where ICD9_CODE = %s", (code,))
                    diags[code] = self.cur.fetchall()
                patient.diagnose(diags[code])
337
338
    def get_procedures(self):
        """Adds procedures (converted from ICD-9 code) from PROCEDURES_ICD to patient.procedures for each patient in patients dictionary.
        """
        procs = {}  # ICD-9 code -> fetched LONG_TITLE rows, cached across patients
        for patient in self.patients.values():
            self.cur.execute("select ICD9_CODE from mimic.PROCEDURES_ICD where SUBJECT_ID = %d" % patient.id)
            for row in self.cur.fetchall():
                # Fix: fetchall yields 1-tuples; the original used the whole
                # tuple as the cache key and as the SQL substitution value.
                code = row[0]
                if code not in procs:
                    self.cur.execute("select LONG_TITLE from mimic.D_ICD_PROCEDURES where ICD9_CODE = %s", (code,))
                    procs[code] = self.cur.fetchall()
                patient.add_procedure(procs[code])
351
352
    def extract_patient_words(self, patientID):
        """Collect a patient's note sentences and prescriptions and score the
        words with TF-IDF.

        Args:
            patientID (int): key into self.patients (populated by get_patients).

        Returns:
            list of (feature_index, tfidf_score) tuples for the first document.
        """
        # will hold all text to be processed
        text = []

        if patientID in self.patients:
            patient = self.patients[patientID]

            # Adds note_events to text
            if not patient.note_events:
                self.get_note_events()
            for doc in patient.note_events:
                text.extend(doc[1])

            # Adds prescriptions to text
            if not patient.prescriptions:
                self.get_prescription()
            text.extend(patient.prescriptions)

        ### Cleans the documents of punctuation ###
        text = [sent.translate(str.maketrans('', '', string.punctuation)) for sent in text]
        vectorizer = TfidfVectorizer()
        tfidf_matrix = vectorizer.fit_transform(text)
        # NOTE(review): get_feature_names() was removed in scikit-learn 1.2;
        # newer versions require get_feature_names_out() — confirm pinned version.
        names = vectorizer.get_feature_names()
        doc = 0
        feature_index = tfidf_matrix[doc, :].nonzero()[1]
        # Fix: materialize the (index, score) pairs as a list — the original
        # returned a zip iterator that the print loop below had already
        # exhausted, so callers always received an empty iterator.
        scores = [(i, tfidf_matrix[doc, i]) for i in feature_index]

        print("TEMPORARY OUTPUT FOR TASK T4.4")
        for w, s in [(names[i], s) for (i, s) in scores]:
            print(w, s)
        return scores
396
    def extract_key_words(self, text):
        """Return the 50 most frequent non-stopword tokens of `text` as
        (token, count) pairs, most frequent first.

        Single characters, pure digits and a large English stopword list are
        excluded; counting is case-insensitive (tokens are lowercased).
        """
        # code from AAN Keyword Cloud
        def remove_common_words_and_count(tokens):
            # One-time stopword set; kept verbatim from the AAN keyword cloud.
            common_words = {'figure','a','able','about','above','abroad','according','accordingly','across','actually','adj','after','afterwards','again','against','ago','ahead','ain\'t','all','allow','allows','almost','alone','along','alongside','already','also','although','always','am','amid','amidst','among','amongst','an','and','another','any','anybody','anyhow','anyone','anything','anyway','anyways','anywhere','apart','appear','appreciate','appropriate','are','aren\'t','around','as','a\'s','aside','ask','asking','associated','at','available','away','awfully','b','back','backward','backwards','be','became','because','become','becomes','becoming','been','before','beforehand','begin','behind','being','believe','below','beside','besides','best','better','between','beyond','both','brief','but','by','c','came','can','cannot','cant','can\'t','caption','cause','causes','certain','certainly','changes','clearly','c\'mon','co','co.','com','come','comes','concerning','consequently','consider','considering','contain','containing','contains','corresponding','could','couldn\'t','course','c\'s','currently','d','dare','daren\'t','definitely','described','despite','did','didn\'t','different','directly','do','does','doesn\'t','doing','done','don\'t','down','downwards','during','e','each','edu','eg','eight','eighty','either','else','elsewhere','end','ending','enough','entirely','especially','et','etc','even','ever','evermore','every','everybody','everyone','everything','everywhere','ex','exactly','example','except','f','fairly','far','farther','few','fewer','fifth','first','five','followed','following','follows','for','forever','former','formerly','forth','forward','found','four','from','further','furthermore','g','get','gets','getting','given','gives','go','goes','going','gone','got','gotten','greetings','h','had','hadn\'t','half','happens','hardly','has','hasn\'t','have','haven\'t','having','he','he\'d','he\'ll','hello','help','hence','her','here','hereafter','hereby','herein','here\'s','hereupon','hers','herself','he\'s','hi','him','himself','his','hither','hopefully','how','howbeit','however','hundred','i','i\'d','ie','if','ignored','i\'ll','i\'m','immediate','in','inasmuch','inc','inc.','indeed','indicate','indicated','indicates','inner','inside','insofar','instead','into','inward','is','isn\'t','it','it\'d','it\'ll','its','it\'s','itself','i\'ve','j','just','k','keep','keeps','kept','know','known','knows','l','last','lately','later','latter','latterly','least','less','lest','let','let\'s','like','liked','likely','likewise','little','look','looking','looks','low','lower','ltd','m','made','mainly','make','makes','many','may','maybe','mayn\'t','me','mean','meantime','meanwhile','merely','might','mightn\'t','mine','minus','miss','more','moreover','most','mostly','mr','mrs','much','must','mustn\'t','my','myself','n','name','namely','nd','near','nearly','necessary','need','needn\'t','needs','neither','never','neverf','neverless','nevertheless','new','next','nine','ninety','no','nobody','non','none','nonetheless','noone','no-one','nor','normally','not','nothing','notwithstanding','novel','now','nowhere','o','obviously','of','off','often','oh','ok','okay','old','on','once','one','ones','one\'s','only','onto','opposite','or','other','others','otherwise','ought','oughtn\'t','our','ours','ourselves','out','outside','over','overall','own','p','particular','particularly','past','per','perhaps','placed','please','plus','possible','presumably','probably','provided','provides','q','que','quite','qv','r','rather','rd','re','really','reasonably','recent','recently','regarding','regardless','regards','relatively','respectively','right','round','s','said','same','saw','say','saying','says','second','secondly','see','seeing','seem','seemed','seeming','seems','seen','self','selves','sensible','sent','serious','seriously','seven','several','shall','shan\'t','she','she\'d','she\'ll','she\'s','should','shouldn\'t','since','six','so','some','somebody','someday','somehow','someone','something','sometime','sometimes','somewhat','somewhere','soon','sorry','specified','specify','specifying','still','sub','such','sup','sure','t','take','taken','taking','tell','tends','th','than','thank','thanks','thanx','that','that\'ll','thats','that\'s','that\'ve','the','their','theirs','them','themselves','then','thence','there','thereafter','thereby','there\'d','therefore','therein','there\'ll','there\'re','theres','there\'s','thereupon','there\'ve','these','they','they\'d','they\'ll','they\'re','they\'ve','thing','things','think','third','thirty','this','thorough','thoroughly','those','though','three','through','throughout','thru','thus','till','to','together','too','took','toward','towards','tried','tries','truly','try','trying','t\'s','twice','two','u','un','under','underneath','undoing','unfortunately','unless','unlike','unlikely','until','unto','up','upon','upwards','us','use','used','useful','uses','using','usually','v','value','various','versus','very','via','viz','vs','w','want','wants','was','wasn\'t','way','we','we\'d','welcome','well','we\'ll','went','were','we\'re','weren\'t','we\'ve','what','whatever','what\'ll','what\'s','what\'ve','when','whence','whenever','where','whereafter','whereas','whereby','wherein','where\'s','whereupon','wherever','whether','which','whichever','while','whilst','whither','who','who\'d','whoever','whole','who\'ll','whom','whomever','who\'s','whose','why','will','willing','wish','with','within','without','wonder','won\'t','would','wouldn\'t','x','y','yes','yet','you','you\'d','you\'ll','your','you\'re','yours','yourself','yourselves','you\'ve','z','zero'}
            token_counts = {}
            for token in tokens:
                token = token.lower()
                if token in common_words or token.isdigit() or len(token) == 1:
                    pass
                elif token in token_counts:
                    token_counts[token] += 1
                else:
                    token_counts[token] = 1
            return token_counts
        token_counts = remove_common_words_and_count(re.findall('[\w\-]+', text))
        # Sort tokens with highest counts first, and take the top 50 only.
        # Fix: the original also derived font sizes from
        # sorted_token_counts[0][1], which crashed on empty input
        # (IndexError / ZeroDivisionError) and was never used by callers;
        # that dead computation has been removed.
        return sorted(token_counts.items(), key=lambda x: x[1], reverse=True)[:50]
417
418
419
    def extract_phrases(self, docID):
        """Dump note docID to the phrase-at-scale input file and run its Spark
        phrase generator via os.system.

        NOTE(review): the spark-submit path below is machine-specific
        (~/venv/python3.6) — confirm it matches the deployment environment.
        """
        self.cur.execute("SELECT TEXT FROM mimic.NOTEEVENTS WHERE ROW_ID = %d" % docID)
        doc = self.cur.fetchall()
        # Input file lives two directories above this module.
        upperdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        f = open(upperdir+"/external/phrase-at-scale/data/raw_doc.txt", "w+")
        f.write(doc[0][0])
        f.close()

        cmd = '~/venv/lib/python3.6/site-packages/pyspark/bin/spark-submit --master local[200] --driver-memory 4G external/phrase-at-scale/phrase_generator.py'
        os.system(cmd)
429
430
    def output_note_events_file_by_patients(self, directory):
        '''
        input: file path like EHRKit/output/patients
        return: none
        output: Noteevents Text fields saved in EHRKit/output/patients/patient[SUBJECT_ID]/[ROW_ID].txt files

        Considers only the first 10000 NOTEEVENTS rows and up to 10 patients
        with more than 10 notes each.
        '''
        #self.cur.execute('select SUBJECT_ID, count(ROW_ID) from mimic.NOTEEVENTS group by SUBJECT_ID having count(ROW_ID) > 10 limit 1')
        self.cur.execute('select SUBJECT_ID, count(ROW_ID) from (select SUBJECT_ID, ROW_ID from mimic.NOTEEVENTS limit 10000) as SMALLNE group by SUBJECT_ID having count(ROW_ID) > 10 limit 10')
        patients = self.cur.fetchall()
        print('Format: (Patient ID, Document count) \n', patients)
        for patient in patients:
            pid = patient[0]
            print('patient %d' %pid)
            # Fix: the original created the directory via string concatenation
            # (directory + 'patient<id>') but wrote files to the os.path.join()
            # path — these only match when `directory` ends with a separator.
            # Build the path once, consistently, and hoist it out of the doc loop.
            docpath = os.path.join(directory, 'patient%d' % pid)
            os.makedirs(docpath, exist_ok=True)
            self.cur.execute('select ROW_ID from (select SUBJECT_ID, ROW_ID from mimic.NOTEEVENTS limit 10000) as SMALLNE where SUBJECT_ID = %d' %pid)
            docids = self.cur.fetchall()
            for doctup in docids:
                docid = doctup[0]
                self.cur.execute('select TEXT from mimic.NOTEEVENTS where ROW_ID = %d' %docid)
                doctext = self.cur.fetchall()
                with open(os.path.join(docpath, '%d.txt' %docid), 'w+') as f:
                    f.write(doctext[0][0])
                print('patient document %d saved' %docid)
        print('Done, please check EHRKit/Output/patients/ for files')
459
460
    def output_note_events_discharge_summary(self, directory):
        '''
        input: file path like EHRKit/output/
        return: none
        output: Noteevents Text fields saved in EHRKit/output/discharge_summary/[ROW_ID].txt files

        Considers the first 100 discharge summaries among the first 10000
        NOTEEVENTS rows.
        '''
        self.cur.execute("select ROW_ID, TEXT from (select * from mimic.NOTEEVENTS limit 10000) as SMALLNE where CATEGORY = 'Discharge summary' limit 100")
        raw = self.cur.fetchall()
        # Idiom: makedirs(exist_ok=True) replaces the per-document
        # try/except FileExistsError dance and is hoisted out of the loop.
        os.makedirs(directory, exist_ok=True)
        for doc in raw:
            docid = doc[0]
            doctext = doc[1]
            print('Discharge Summary %d' %docid)
            with open(os.path.join(directory, '%d.txt' %docid), 'w+') as f:
                f.write(doctext)
                print('discharge summary %d saved' %docid)
        print('Done, please check EHRKit/output/discharge_summary for files')
483
484
    def outputAbbreviation(self, directory):
        '''
        input: file path like EHRKit/output/
        return: none
        output: Noteevents Text files containing abbreviation “AB” in e.g. EHRKit/output/AB/194442.txt
        '''
        # TODO(review): not implemented — the method body is empty and
        # currently does nothing.
490
491
492
    def count_gender(self, gender):
493
        ''' Counts how many patients there are of a certain gender in the database.
494
            Argument gender must be a capitalized single-letter string.
495
        '''
496
497
        self.cur.execute('SELECT COUNT(*) FROM mimic.PATIENTS WHERE GENDER = \'%s\'' % gender)
498
        count = self.cur.fetchall()
499
500
        return count[0][0]
501
502
    def docs_with_phrase(self, phrase):
        ''' Writes document text containing phrase to files named with document IDs
            inside a new directory docs_with_phrase_[phrase].
        '''
        # Parameterized LIKE avoids SQL injection via `phrase`.
        self.cur.execute('SELECT ROW_ID, TEXT FROM mimic.NOTEEVENTS WHERE TEXT LIKE %s LIMIT 1', ('%' + phrase + '%',))
        docs = self.cur.fetchall()
        outdir = "docs_with_phrase_%s" % phrase
        os.mkdir(outdir)
        # Fix: the original created the directory but never wrote the files
        # its docstring promises.
        for row_id, doc_text in docs:
            with open(os.path.join(outdir, '%d.txt' % row_id), 'w') as f:
                f.write(doc_text)
509
510
    #TODO: bert tokenize
511
    def get_bert_tokenize(self, doc_id):
        """Return the BERT WordPiece tokens of note doc_id."""
        text = self.get_document(doc_id)
        # Fix: BertTokenizer was never imported (NameError at runtime);
        # AutoTokenizer is imported at module level and resolves the same
        # pretrained checkpoint.
        tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        return tokenizer.tokenize(text)
516
517
    # TODO: test BART summarization
518
    def summarize_huggingface(self, text, model_name):
        """Summarize `text` with a Hugging Face seq2seq model.

        Args:
            text (str): document to summarize.
            model_name (str): HF model id, e.g. 't5-small' or 'facebook/bart-large-cnn'.

        Returns:
            str: generated summary.
        """
        # Cache under huggingface/<model>/tokenizer|model next to this package.
        if '/' in model_name:
            path = model_name.split('/')[1]
        else:
            path = model_name

        tokenizer_path = os.path.join(os.path.dirname(__file__), '..', 'huggingface', path, 'tokenizer')
        model_path = os.path.join(os.path.dirname(__file__), '..', 'huggingface', path, 'model')
        # Fix: the tokenizer was hard-coded to 't5-small' regardless of
        # model_name; tokenizer and model must come from the same checkpoint.
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=tokenizer_path)
        model = AutoModelWithLMHead.from_pretrained(model_name, cache_dir=model_path)

        inputs = tokenizer([text], max_length=1024, return_tensors='pt')
        # early_stopping=True produces shorter summaries. Changing max_ and min_length doesn't change anything.
        summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=inputs['input_ids'].shape[1], early_stopping=False)
        summary = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]
        summary = " ".join(summary)
        return summary
535
536
    def bert_predict_masked(self, doc_id, sentence_id, mask_id):
        """Mask one word of a note sentence and fill it in with BERT.

        Args:
            doc_id: NOTEEVENTS ROW_ID of the note.
            sentence_id: index of the sentence within the tokenized note.
            mask_id: index (whitespace split) of the word to mask.

        Returns:
            The sentence with the masked word replaced by BERT's top prediction.
        """
        #TODO: FROM HUGGINGFACE LIBRARY
        tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking")
        model = AutoModelWithLMHead.from_pretrained("bert-large-uncased-whole-word-masking")

        kit_doc = self.get_document_sents(doc_id) #retrieve that doc
        sentence = kit_doc[sentence_id] #choose that particular sentence
        #print(sentence)

        # Replace the chosen word with the tokenizer's mask symbol.
        sentence_list = sentence.split(' ')
        sentence_list[mask_id] = tokenizer.mask_token
        sequence = ' '.join(sentence_list)

        input = tokenizer.encode(sequence, return_tensors="pt")
        # Position(s) of the mask token within the encoded sequence.
        mask_token_index = torch.where(input == tokenizer.mask_token_id)[1]

        token_logits = model(input)[0]
        mask_token_logits = token_logits[0, mask_token_index, :]

        # Top-1 candidate for the masked position (k=1, so one token).
        top_token = torch.topk(mask_token_logits, 1, dim=1).indices[0].tolist()

        for token in top_token:
            return sequence.replace(tokenizer.mask_token, tokenizer.decode([token]))
560
561
    def close_session(self):
        """Ends the DB session by closing the MySQL database connection.

        The SSH tunnel teardown is disabled along with the sshtunnel import
        at the top of the module.
        """
        #self.server.stop()
        self.cnx.close()
567
568
### ---------------- ###
569
### HELPER FUNCTIONS ###
570
### ---------------- ###
571
572
def start_session(db_user, db_pass):
    """Opens the SQL connection and creates a cursor for executing queries.

    Args:
        db_user (str): Username for MySQL DB on Tangra
        db_pass (str): Password for MySQL DB on Tangra

    Returns:
        ehr_db: session object wrapping the pymysql connection and cursor
        (the old docstring claimed a dict was returned — it is not).
    """
    # NOTE(review): 0.0.0.0 is normally a bind address, not a connect target —
    # presumably a local tunnel endpoint; confirm.
    cnx = pymysql.connect(host='0.0.0.0',
                             user=db_user,
                             password=db_pass,port = 3306)
                             #port=8080)
    # Session Dictionary: Stores SSH Tunnel (server), MySQL Connection (cnx),
    # and DB Cursor(cursor).
    #sess_dict = {'server': server, 'cnx':cnx, 'cur':cnx.cursor()}
    sess_dict = {'cnx':cnx, 'cur':cnx.cursor()}
    # Create Session Object:
    sess = ehr_db(sess_dict)

    sess.cur.execute("use mimic")

    return sess
601
602
def createPatient(data):
    """Creates a single Patient object.

    NOTE(review): getDiagnoses, getMeds and medicalHistory are not defined or
    imported anywhere in this module, so this function raises NameError as
    written — confirm where these helpers were meant to come from.

    Args:
        data (dict): Dictionary containing patient data

    Returns:
        patient: Patient object
    """
    data["diagnosis"] = getDiagnoses(data["id"], current=True)
    data["current_prescriptions"] = getMeds(data["id"], current=True)
    history = medicalHistory(data["id"])
    data["past_prescriptions"] = history["past_prescriptions"]
    data["past_diagnoses"] = history["past_diagnoses"]
    data["procedures"] = history["procedures"]

    patient = Patient(data)

    return patient
621
622
def flatten(lst):
    """Return a flat list of the items in a one-level nested list.

    A falsy input (None, empty list/tuple) is returned unchanged.
    """
    if not lst:
        return lst
    flat = []
    for sub in lst:
        flat.extend(sub)
    return flat
627
628
def numbered_print(lst):
    """Print each element of lst preceded by its 1-based index."""
    counter = 1
    for item in lst:
        print(counter, '\n', item)
        counter += 1
631
632
633
def init_embedding_model():
    """Train (or retrain) the word2vec embedding model via scripts.train_word2vec."""
    train_word2vec()
635
636
def get_abbs_sent_ids(text):
    ''' Returns a list of the abbreviations in a document along with the sentence ID they appear in
        in the format [(abbreviation, sent_id), ...]
    '''
    abb_list = []
    # enumerate replaces the original zip(range(0, len(...)), ...) idiom.
    for i, sent in enumerate(sent_tokenize(text)):
        for word in word_tokenize(sent):
            # Two leading uppercase letters mark an abbreviation candidate.
            if re.match(r'[A-Z]{2}', word):
                abb_list.append((word, i))

    return abb_list
649
def post_single_dict_to_solr(d: dict, core: str) -> None:
    """POST a single document dict to the given Solr core's JSON update endpoint.

    NOTE(review): the HTTP response is assigned but never checked, so failures
    are silent — consider inspecting response.status_code.
    """
    response = requests.post('http://tangra.cs.yale.edu:8983/solr/{}/update/json/docs'.format(core), json=d)
651
652
def abbs_disambiguate(ABB):
    """Return candidate long forms for abbreviation ABB from the UMN Solr core."""
    # The accompanying score map is currently unused.
    long_forms, long_form_to_score_map = get_solr_response_umn_wrap(ABB)
    return long_forms
655
656
def get_documents_solr(query):
    """Return the MIMIC document IDs matching `query` via Solr, ascending."""
    doc_ids, _scores = get_solr_response_mimic(query)
    if not doc_ids:
        print("No documents found")
    return sorted(doc_ids)
661
662
663
664
### ------------------- ###
665
### Tangra DB Structure ###
666
### ------------------- ###
667
668
### DIAGNOSES_ICD Table ###
669
# Description: Stores ICD-9 Diagnosis Codes for patients
670
# Source: https://mimic.physionet.org/mimictables/diagnoses_icd/
671
# ATTRIBUTES:
672
# HADM_ID = unique ID for hospital ID (possibly more than 1 per patient)
673
# SEQ_NUM = Order of priority for ICD diagnoses
674
# ICD9_CODE = ICD-9 code for patient diagnosis
675
# SUBJECT_ID = unique ID for each patient
676
677
### D_ICD_DIAGNOSES Table ###
678
# Description: Definition Table for ICD Diagnoses
679
# Source: https://mimic.physionet.org/mimictables/d_icd_diagnoses/
680
# ATTRIBUTES:
681
# SHORT_TITLE
682
# LONG_TITLE
683
# ICD9_CODE: FK on DIAGNOSES_ICD.ICD9_CODE
684
685
### D_ICD_PROCEDURES Table ###
686
# Description: Definition Table for ICD procedures
687
# Source: https://mimic.physionet.org/mimictables/d_icd_procedures/
688
# ATTRIBUTES:
689
# SHORT_TITLE
690
# LONG_TITLE
691
# ICD9_CODE: FK on DIAGNOSES_ICD.ICD9_CODE
692
693
### NOTEEVENTS Table ###
694
# Description: Stores all notes for patients
695
# Source: https://mimic.physionet.org/mimictables/noteevents/
696
# ATTRIBUTES:
697
# SUBJECT_ID = unique ID for patient
698
# HADM_ID = unique hospital admission ID
699
# CHARTDATE = timestamp for date when note was charted
700
# CATEGORY and DESCRIPTION: describe type of note
701
# CGID = unique ID for caregiver
702
# ISERROR = if 1, means physician identified note as erroneous
703
# TEXT = note text
704
705
### PATIENTS Table ###
706
# Description: Demographic chart data for all patients
707
# Source: https://mimic.physionet.org/mimictables/patients/
708
# ATTRIBUTES:
709
# SUBJECT_ID = unique ID for patient
710
# GENDER
711
# DOB
712
# DOD_HOSP: Date of death as recorded by hospital (null if alive)
713
# DOD_SSN: Date of death as recorded in social security DB. (null if alive)
714
# DOD_HOSP takes priority over DOD_SSN if both present
715
# EXPIRE_FLAG = 1 if patient dead
716
717
### PROCEDURES_ICD Table ###
718
# Description: Stores ICD-9 procedures for patients (similar to DIAGNOSES_ICD)
719
# Source: https://mimic.physionet.org/mimictables/procedures_icd/
720
# ATTRIBUTES:
721
# SUBJECT_ID = unique patient ID
722
# HADM_ID = unique hospital admission ID
723
# SEQ_NUM = order in which procedures were performed
724
# ICD9_CODE = ICD-9 code for procedure