tests/all_tests.py

import unittest
import random
import sys, os
import re
import nltk
import json
# Make the repo root, the allennlp loader directory, and the pubmed summarization
# modules importable from this test module.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'allennlp')))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'summarization', 'pubmed_summarization')))
# print(sys.path)
from ehrkit import ehrkit
from getpass import getpass

try:
    from config import USERNAME, PASSWORD
except ImportError:
    print("Please put your username and password in config.py")
    USERNAME = input('DB_username?')
    PASSWORD = getpass('DB_password?')


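# A minimal config.py is assumed to live on one of the paths added above and to
# define the two credentials imported here, e.g. (illustrative values only):
#
#   USERNAME = 'mimic_user'
#   PASSWORD = 'mimic_password'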
DOC_ID = 1354526  # Temporary hard-coded document ID


# Number of documents in NOTEEVENTS.
NUM_DOCS = 2083180

# Number of patients in PATIENTS.
NUM_PATIENTS = 46520

# Number of diagnoses in DIAGNOSES_ICD.
NUM_DIAGS = 823933


def select_ehr(ehrdb, requires_long=False, recursing=False):
    if recursing:
        doc_id = ''
    else:
        # doc_id = input("MIMIC Document ID [press Enter for random]: ")
        doc_id = ''
    if doc_id == '':
        # Pick a random document
        ehrdb.cur.execute("SELECT ROW_ID FROM mimic.NOTEEVENTS ORDER BY RAND() LIMIT 1")
        doc_id = ehrdb.cur.fetchall()[0][0]
        text = ehrdb.get_document(int(doc_id))
        if len(text.split()) > 200 or not requires_long:
            return doc_id, text
        else:
            return select_ehr(ehrdb, requires_long, True)
    else:
        # Fetch the requested document
        try:
            text = ehrdb.get_document(int(doc_id))
            return doc_id, text
        except Exception:
            message = 'Error: There is no document with ID \'' + doc_id + '\' in mimic.NOTEEVENTS'
            sys.exit(message)


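# Illustrative call: `doc_id, text = select_ehr(ehrdb, requires_long=True)` keeps
# drawing random NOTEEVENTS rows until it finds a note longer than 200 words.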
def get_nb_dir(ending, SUMM_DIR):
    # Returns the name of the Naive Bayes model directory trained on the most examples
    dir_nums = []
    for dir_name in os.listdir(SUMM_DIR):
        if os.path.isdir(os.path.join(SUMM_DIR, dir_name)) and dir_name.endswith('_exs_' + ending):
            if os.path.exists(os.path.join(SUMM_DIR, dir_name, 'nb')):
                try:
                    dir_nums.append(int(dir_name.split('_')[0]))
                except ValueError:
                    continue
    if len(dir_nums) > 0:
        best_dir_name = str(max(dir_nums)) + '_exs_' + ending
        return best_dir_name
    else:
        return None

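# Assumed directory layout (inferred from the checks above): model folders under
# SUMM_DIR are named '<num_examples>_exs_<ending>', e.g. '1000_exs_intro', and each
# holds an 'nb' subdirectory with the fitted Naive Bayes artifacts.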
def show_summary(doc_id, text, summary, model_name):
    # x = input('Show full EHR (DOC ID %s)? [DEFAULT=Yes]' % doc_id)
    x = ''
    if x.lower() in ['y', 'yes', '']:
        print('\n\n' + '-'*30 + 'Full EHR' + '-'*30)
        print(text + '\n')
        print('-'*80 + '\n\n')

    print('-'*30 + 'Predicted Summary ' + model_name + '-'*30)
    print(summary)
    print('-'*80 + '\n\n')


class tests(unittest.TestCase):
    def setUp(self):
        # Open a fresh database session and load three patients via get_patients(3).
        self.ehrdb = ehrkit.start_session(USERNAME, PASSWORD)
        self.ehrdb.get_patients(3)


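# Every test class below subclasses `tests`, so each test method runs against the
# session created in setUp above.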
''' Runs tests 1.1-1.4 '''
class t1(tests):
    def test1_1(self):
        kit_count = self.ehrdb.count_patients()
        print("Patient count: ", kit_count)

        self.ehrdb.cur.execute("SELECT COUNT(*) FROM mimic.PATIENTS")
        raw = self.ehrdb.cur.fetchall()
        test_count = int(raw[0][0])

        self.assertEqual(test_count, kit_count)

    def test1_2(self):
        # Known failure: count_docs(['NOTEEVENTS']) returns 1573339, but
        # mimic.NOTEEVENTS holds 2083180 documents.
        # TODO: track down the source of the discrepancy.
        kit_count = self.ehrdb.count_docs(['NOTEEVENTS'])
        print("Document count: ", kit_count)

        self.ehrdb.cur.execute("SELECT COUNT(*) FROM mimic.NOTEEVENTS")
        raw = self.ehrdb.cur.fetchall()
        test_count = int(raw[0][0])

        self.assertEqual(test_count, kit_count)

    def test1_3(self):
        self.ehrdb.get_note_events()
        print('output format: SUBJECT_ID, ROW_ID, NoteEvent length')
        lens = [(patient.id, note[0], len(note[1])) for patient in self.ehrdb.patients.values() for note in patient.note_events]
        print(lens)

        # Placeholder assertion: this output cannot be checked easily.
        self.assertEqual(1, 1)

    def test1_4(self):
        # Gets the longest note among the patient notes retrieved by get_note_events()
        self.ehrdb.get_note_events()
        pid, rowid, doclen = self.ehrdb.longest_NE()
        print('patient id is:', pid, 'NoteEvent id is:', rowid, 'length: ', doclen)

        # Placeholder assertion: this output cannot be checked easily.
        self.assertEqual(1, 1)


class t2(tests):
    def test2_1(self):
        ### There are 2083180 patient records in NOTEEVENTS. ###
        # randint is inclusive of both bounds, so use NUM_DOCS (not NUM_DOCS + 1) as the upper bound.
        record_id = random.randint(1, NUM_DOCS)
        kit_rec = self.ehrdb.get_document(record_id)
        print("Document with ID %d:\n" % record_id, kit_rec)

        self.ehrdb.cur.execute("select TEXT from mimic.NOTEEVENTS where ROW_ID = %d" % record_id)
        test_rec = self.ehrdb.cur.fetchall()[0][0]

        self.assertEqual(kit_rec, test_rec)

    def test2_2(self):
        ### There are records from 46520 unique patients in MIMIC. ###
        patient_id = random.randint(1, NUM_PATIENTS)  # randint is inclusive of both bounds
        kit_ids = self.ehrdb.get_all_patient_document_ids(patient_id)
        print('Document IDs related to Patient %d: ' % patient_id, kit_ids)
        print("Number of docs related to Patient %d: " % patient_id, len(kit_ids))

        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where SUBJECT_ID = %d" % patient_id)
        raw = self.ehrdb.cur.fetchall()
        test_ids = ehrkit.flatten(raw)

        self.assertEqual(kit_ids, test_ids)

    def test2_3(self):
        kit_ids = self.ehrdb.list_all_document_ids()
        # print(kit_ids)

        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS")
        raw = self.ehrdb.cur.fetchall()
        test_ids = ehrkit.flatten(raw)

        self.assertEqual(kit_ids, test_ids)

    def test2_4(self):
        kit_ids = self.ehrdb.list_all_patient_ids()

        self.ehrdb.cur.execute("select SUBJECT_ID from mimic.PATIENTS")
        raw = self.ehrdb.cur.fetchall()
        test_ids = ehrkit.flatten(raw)

        self.assertEqual(kit_ids, test_ids)

    def test2_5(self):
        ### Selects a random chart date from the database.
        ### MIMIC dates are shifted into the future but preserve time of day, weekday, and seasonality.
        random_id = random.randint(1, NUM_DOCS)
        self.ehrdb.cur.execute("select CHARTDATE from mimic.NOTEEVENTS where ROW_ID = %d" % random_id)
        date = self.ehrdb.cur.fetchall()[0][0]

        kit_ids = self.ehrdb.get_documents_d(date)

        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where CHARTDATE = \"%s\"" % date)
        raw = self.ehrdb.cur.fetchall()
        test_ids = ehrkit.flatten(raw)

        self.assertEqual(kit_ids, test_ids)


class t3(tests):
    def test3_1(self):
        # Defines an abbreviation as a string of capitalized letters
        random_id = random.randint(1, NUM_DOCS)
        print("Collecting abbreviations for document %d..." % random_id)
        kit_abbs = self.ehrdb.get_abbreviations(random_id)

        sents = self.ehrdb.get_document_sents(random_id)
        test_abbs = set()
        pattern = r'[A-Z]{2}'  # Matches words that start with at least two capital letters, e.g. 'EKG' but not 'Ekg'
        for sent in sents:
            for word in ehrkit.word_tokenize(sent):
                # print(word)  # Very verbose; prints every token in the document
                if re.match(pattern, word):
                    test_abbs.add(word)

        print(kit_abbs)

        self.assertEqual(kit_abbs, list(test_abbs))

    def test3_2(self):
        query = "meningitis"
        # print('Printing a list of all document ids including query like ', query)
        kit_ids = self.ehrdb.get_documents_q(query)
        # print(kit_ids)  # Extremely long list of DOC_IDs

        query = "%" + query + "%"
        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where TEXT like \'%s\'" % query)
        raw = self.ehrdb.cur.fetchall()
        test_ids = ehrkit.flatten(raw)

        self.assertEqual(kit_ids, test_ids)

    def test3_3(self):
        ### Task 3.3 is the same as task 3.2 with a different query. ###
        query = "Service: SURGERY"
        # print('Printing a list of all document ids including query like ', query)
        kit_ids = self.ehrdb.get_documents_q(query)
        # print(kit_ids)  # Extremely long list of DOC_IDs

        query = "%" + query + "%"
        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where TEXT like \'%s\'" % query)
        raw = self.ehrdb.cur.fetchall()
        test_ids = ehrkit.flatten(raw)

        self.assertEqual(kit_ids, test_ids)

    def test3_4(self):
        doc_id = random.randint(1, NUM_DOCS)
        # print('Kit function printing a numbered list of all sentences in document %d' % doc_id)
        # MIMIC EHRs are very messy, so sentence tokenization often fails
        kit_doc = self.ehrdb.get_document_sents(doc_id)
        # ehrkit.numbered_print(kit_doc)

        self.ehrdb.cur.execute("select TEXT from mimic.NOTEEVENTS where ROW_ID = %d " % doc_id)
        raw = self.ehrdb.cur.fetchall()
        test_doc = ehrkit.sent_tokenize(raw[0][0])

        self.assertEqual(kit_doc, test_doc)

    def test3_7(self):
        kit_meds = self.ehrdb.count_all_prescriptions()

        # Tally prescriptions per drug name directly from the PRESCRIPTIONS table
        test_meds = {}
        self.ehrdb.cur.execute("select DRUG from mimic.PRESCRIPTIONS")
        raw = self.ehrdb.cur.fetchall()
        meds_list = ehrkit.flatten(raw)
        for med in meds_list:
            if med in test_meds:
                test_meds[med] += 1
            else:
                test_meds[med] = 1

        self.assertEqual(kit_meds, test_meds)


class t4(tests):
    def test4_1(self):
        d = self.ehrdb.get_documents_icd9()
        print(d)
        self.assertIsNotNone(d['code'])

    def test4_4(self):
        # Not yet implemented.
        pass


class t5(tests):
    def test5_1(self):
        doc_id = random.randint(1, NUM_DOCS)
        kit_phrases = self.ehrdb.extract_phrases(doc_id)

        print("Testing task 5.1\nCheck phrases manually: ", kit_phrases)

        self.assertIsNotNone(kit_phrases)

    def test5_4(self):
        gender = random.choice(['M', 'F'])
        kit_count = self.ehrdb.count_gender(gender)

        self.ehrdb.cur.execute('SELECT COUNT(*) FROM mimic.PATIENTS WHERE GENDER = \'%s\'' % gender)
        raw = self.ehrdb.cur.fetchall()
        test_count = raw[0][0]
        print('Gender:', gender, '\tCount:', str(test_count))

        self.assertEqual(kit_count, test_count)


# class t6(tests):
#     def test6_1(self):
#         import loader

#         doc_id, text = select_ehr(self.ehrdb)

#         # x = input('GloVe or RoBERTa predictor [g=GloVe, r=RoBERTa]? ')
#         x = 'g'
#         if x == 'g':
#             glove_predictor = loader.load_glove()
#             probs = glove_predictor.predict(text)['probs']
#         elif x == 'r':
#             roberta_predictor = loader.load_roberta()
#             try:
#                 probs = roberta_predictor.predict(text)['probs']
#             except:
#                 print('Document too long for RoBERTa model. Using GloVe instead.')
#                 glove_predictor = loader.load_glove()
#                 probs = glove_predictor.predict(text)['probs']
#         else:
#             sys.exit('Error: Must input \'g\' or \'r\'')

#         classification = 'Positive' if probs[0] >= 0.5 else 'Negative'
#         print("Document ID: ", doc_id, "\tPredicted Sentiment: ", classification)

#     def test6_2(self):
#         import loader

#         doc_id, text = select_ehr(self.ehrdb)

#         if os.path.exists("../allennlp/elmo-ner/whole_model.pt"):
#             predictor = loader.load_ner()
#         else:
#             predictor = loader.download_ner()

#         text = self.ehrdb.get_document(int(doc_id))
#         pred = predictor.predict(text)
#         # pred = predictor.predict("John likes and Bill hates ice cream")
#         # print_results = input("Prediction complete. Print results? (y/n): ")
#         print_results = 'y'
#         if print_results == "y":
#             print("Document ID: ", doc_id, "  Results: ", pred['tags'])

#     def test6_3(self):
#         import torch
#         from transformers import BertTokenizer#, BertModel, BertForMaskedLM

#         doc_id, text = select_ehr(self.ehrdb)
#         tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#         bert_tokenized_text = tokenizer.tokenize(text)
#         print('\n' + '-'*20 + 'text' + '-'*20)
#         print(text)
#         print('\n' + '-'*20 + 'Tokenized text from Huggingface BERT Tokenizer' + '-'*20)
#         print(bert_tokenized_text)

#         # library function
#         ehr_bert_tokenized_text = self.ehrdb.get_bert_tokenize(doc_id)
#         self.assertEqual(bert_tokenized_text, ehr_bert_tokenized_text)


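# Note: the t6 tests above are currently disabled; re-enabling them needs the `loader`
# module (presumably from the allennlp directory added to sys.path), plus torch and
# transformers, and a trained NER model at ../allennlp/elmo-ner/whole_model.pt.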
class t7(tests):
    # Summarization algorithms
    def test7_1(self):
        from pubmed_naive_bayes import classify_nb
        from get_pubmed_nb_data import build_vecs
        from sklearn.naive_bayes import GaussianNB

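        # Rough flow of this test (as read from the helpers used below): build_vecs
        # turns each sentence of the selected EHR into a feature vector, GaussianNB is
        # fit on the vectors saved in feature_vecs.json, and classify_nb marks sentences
        # to keep; the kept sentences are concatenated and displayed as the summary.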
        doc_id, text = select_ehr(self.ehrdb)
        # body_type = input('Use Naive Bayes model trained from whole body sections or just their body introductions?\n\t'\
                        # '[w=whole body, j=just intro, DEFAULT=just intro]: ')
        body_type = 'j'
        if body_type == 'w':
            ending = 'body'
        elif body_type in ['j', '']:
            ending = 'intro'
        else:
            sys.exit('Error: Must input \'w\' or \'j\'.')
        SUMM_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'summarization', 'pubmed_summarization'))
        best_dir_name = get_nb_dir(ending, SUMM_DIR)
        if not best_dir_name:
            message = 'No Naive Bayes models of this type have been fit. '\
                        'Would you like to do so now?\n\t[DEFAULT=Yes] '
            # response = input(message)
            response = ''
            if response.lower() in ['y', 'yes', '']:
                command = 'python ' + os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'summarization', 'pubmed_summarization', 'pubmed_naive_bayes.py'))
                os.system(command)
                # Re-scan for the newly fitted model (SUMM_DIR is a required argument)
                best_dir_name = get_nb_dir(ending, SUMM_DIR)
            if response.lower() not in ['y', 'yes', ''] or not best_dir_name:
                sys.exit('Exiting.')

        # Fit the Naive Bayes model to the saved training data
        NB_DIR = os.path.join(SUMM_DIR, best_dir_name, 'nb')
        with open(os.path.join(NB_DIR, 'feature_vecs.json'), 'r') as f:
            data = json.load(f)
        xtrain, ytrain = data['train_features'], data['train_outputs']
        gnb = GaussianNB()
        gnb.fit(xtrain, ytrain)

        # Run the fitted model on the selected EHR
        tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
        feature_vecs, _ = build_vecs(text, None, tokenizer)
        PCT_SUM = 0.3
        preds = classify_nb(feature_vecs, PCT_SUM, gnb)
        sents = tokenizer.tokenize(text)
        # Keep the sentences the classifier marked as summary-worthy
        summary = ' '.join(sents[i] for i in range(len(preds)) if preds[i] == 1)

        show_summary(doc_id, text, summary, 'Naive Bayes')

    def test7_2(self):
        # DistilBART for summarization. Trained on CNN/Daily Mail (~4x longer summaries than XSum)
        doc_id, text = select_ehr(self.ehrdb, requires_long=True)
        model_name = 'sshleifer/distilbart-cnn-12-6'
        summary = self.ehrdb.summarize_huggingface(text, model_name)

        show_summary(doc_id, text, summary, model_name)
        print('Number of Words in Full EHR: %d' % len(text.split()))
        print('Number of Words in %s Summary: %d' % (model_name, len(summary.split())))

    def test7_3(self):
        # T5 for summarization. Trained on CNN/Daily Mail
        doc_id, text = select_ehr(self.ehrdb, requires_long=True)
        model_name = 't5-small'
        summary = self.ehrdb.summarize_huggingface(text, model_name)

        show_summary(doc_id, text, summary, model_name)
        print('Number of Words in Full EHR: %d' % len(text.split()))
        print('Number of Words in %s Summary: %d' % (model_name, len(summary.split())))


if __name__ == '__main__':
    unittest.main()
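# To run the full suite (assuming the MIMIC database and config.py are set up):
#   python tests/all_tests.py
# unittest.main() also accepts test names on the command line, e.g.:
#   python tests/all_tests.py t1 t2.test2_1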