--- a/tests/all_tests.py
+++ b/tests/all_tests.py
@@ -0,0 +1,440 @@
+import unittest
+import random
+import sys, os
+import re
+import nltk
+import json
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'allennlp')))
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'summarization', 'pubmed_summarization')))
+# print(sys.path)
+from ehrkit import ehrkit
+from getpass import getpass
+
+try:
+    from config import USERNAME, PASSWORD
+except:
+    print("Please put your username and password in config.py")
+    USERNAME = input('DB_username?')
+    PASSWORD = getpass('DB_password?')
+
+
+DOC_ID = 1354526 # Temporary!!!
+
+
+# Number of documents in NOTEEVENTS.
+NUM_DOCS = 2083180
+
+# Number of patients in PATIENTS.
+NUM_PATIENTS = 46520
+
+# Number of diagnoses in DIAGNOSES_ICD.
+NUM_DIAGS = 823933
+
+
+def select_ehr(ehrdb, requires_long=False, recursing=False):
+    if recursing:
+        doc_id = ''
+    else:
+        # doc_id = input("MIMIC Document ID [press Enter for random]: ")
+        doc_id = ''
+    if doc_id == '':
+        # Picks random document
+        ehrdb.cur.execute("SELECT ROW_ID FROM mimic.NOTEEVENTS ORDER BY RAND() LIMIT 1")
+        doc_id = ehrdb.cur.fetchall()[0][0]
+        text = ehrdb.get_document(int(doc_id))
+        if len(text.split()) > 200 or not requires_long:
+            return doc_id, text
+        else:
+            return select_ehr(ehrdb, requires_long, True)
+    else:
+        # Get inputted document
+        try:
+            text = ehrdb.get_document(int(doc_id))
+            return doc_id, text
+        except:
+            message = 'Error: There is no document with ID \'' + doc_id + '\' in mimic.NOTEEVENTS'
+            sys.exit(message)
+
+
+def get_nb_dir(ending, SUMM_DIR):
+    # Gets path of the Naive Bayes model trained on the most examples
+    dir_nums = []
+    for dir in os.listdir(SUMM_DIR):
+        if os.path.isdir(os.path.join(SUMM_DIR, dir)) and dir.endswith('_exs_' + ending):
+            if os.path.exists(os.path.join(SUMM_DIR, dir, 'nb')):
+                try:
+                    dir_nums.append(int(dir.split('_')[0]))
+                except:
+                    continue
+    if len(dir_nums) > 0:
+        best_dir_name = str(max(dir_nums)) + '_exs_' + ending
+        return best_dir_name
+    else:
+        return None
+
+
+def show_summary(doc_id, text, summary, model_name):
+    # x = input('Show full EHR (DOC ID %s)? [DEFAULT=Yes]' % doc_id)
+    x = ''
+    if x.lower() in ['y', 'yes', '']:
+        print('\n\n' + '-'*30 + 'Full EHR' + '-'*30)
+        print(text + '\n')
+        print('-'*80 + '\n\n')
+
+    print('-'*30 + 'Predicted Summary ' + model_name + '-'*30)
+    print(summary)
+    print('-'*80 + '\n\n')
+
+
+class tests(unittest.TestCase):
+    def setUp(self):
+        self.ehrdb = ehrkit.start_session(USERNAME, PASSWORD)
+        self.ehrdb.get_patients(3)
+
+
+''' Runs tests 1.1-1.4 '''
+class t1(tests):
+    def test1_1(self):
+        kit_count = self.ehrdb.count_patients()
+        print("Patient count: ", kit_count)
+
+        self.ehrdb.cur.execute("SELECT COUNT(*) FROM mimic.PATIENTS")
+        raw = self.ehrdb.cur.fetchall()
+        test_count = int(raw[0][0])
+
+        self.assertEqual(test_count, kit_count)
+
+    def test1_2(self):
+        # Fails! count_docs returns 1573339, but mimic.NOTEEVENTS has 2083180 documents.
+        # TODO: Fix whatever is wrong here
+        kit_count = self.ehrdb.count_docs(['NOTEEVENTS'])
+        print("Document count: ", kit_count)
+
+        self.ehrdb.cur.execute("SELECT COUNT(*) FROM mimic.NOTEEVENTS")
+        raw = self.ehrdb.cur.fetchall()
+        test_count = int(raw[0][0])
+
+        self.assertEqual(test_count, kit_count)
+
+    def test1_3(self):
+        self.ehrdb.get_note_events()
+        print('output format: SUBJECT_ID, ROW_ID, NoteEvent length')
+        lens = [(patient.id, note[0], len(note[1])) for patient in self.ehrdb.patients.values() for note in patient.note_events]
+        print(lens)
+
+        # placeholder, this output cannot be checked easily
+        self.assertEqual(1, 1)
+
+    def test1_4(self):
+        # Gets longest note among the patient notes queued by get_note_events()
+        self.ehrdb.get_note_events()
+        pid, rowid, doclen = self.ehrdb.longest_NE()
+        print('patient id is:', pid, 'NoteEvent id is:', rowid, 'length: ', doclen)
+
+        # placeholder, this output cannot be checked easily
+        self.assertEqual(1, 1)
+
+
+class t2(tests):
+    def test2_1(self):
+        ### There are 2083180 patient records in NOTEEVENTS. ###
+        record_id = random.randint(1, NUM_DOCS) # randint is inclusive on both ends
+        kit_rec = self.ehrdb.get_document(record_id)
+        print("Document with ID %d:\n" % record_id, kit_rec)
+
+        self.ehrdb.cur.execute("select TEXT from mimic.NOTEEVENTS where ROW_ID = %d" % record_id)
+        test_rec = self.ehrdb.cur.fetchall()[0][0]
+
+        self.assertEqual(kit_rec, test_rec)
+
+    def test2_2(self):
+        ### There are records from 46520 unique patients in MIMIC. ###
+        patient_id = random.randint(1, NUM_PATIENTS)
+        kit_ids = self.ehrdb.get_all_patient_document_ids(patient_id)
+        print('Document IDs related to Patient %d: ' % patient_id, kit_ids)
+        print("Number of docs related to Patient %d: " % patient_id, len(kit_ids))
+
+        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where SUBJECT_ID = %d" % patient_id)
+        raw = self.ehrdb.cur.fetchall()
+        test_ids = ehrkit.flatten(raw)
+
+        self.assertEqual(kit_ids, test_ids)
+
+    def test2_3(self):
+        kit_ids = self.ehrdb.list_all_document_ids()
+        # print(kit_ids)
+
+        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS")
+        raw = self.ehrdb.cur.fetchall()
+        test_ids = ehrkit.flatten(raw)
+
+        self.assertEqual(kit_ids, test_ids)
+
+    def test2_4(self):
+        kit_ids = self.ehrdb.list_all_patient_ids()
+
+        self.ehrdb.cur.execute("select SUBJECT_ID from mimic.PATIENTS")
+        raw = self.ehrdb.cur.fetchall()
+        test_ids = ehrkit.flatten(raw)
+
+        self.assertEqual(kit_ids, test_ids)
+
+    def test2_5(self):
+        ### Selects a random date from the database.
+        ### Dates are shifted to the future but preserve time, weekday, and seasonality.
+        random_id = random.randint(1, NUM_DOCS)
+        self.ehrdb.cur.execute("select CHARTDATE from mimic.NOTEEVENTS where ROW_ID = %d" % random_id)
+        date = self.ehrdb.cur.fetchall()[0][0]
+
+        kit_ids = self.ehrdb.get_documents_d(date)
+
+        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where CHARTDATE = \"%s\"" % date)
+        raw = self.ehrdb.cur.fetchall()
+        test_ids = ehrkit.flatten(raw)
+
+        self.assertEqual(kit_ids, test_ids)
+
+
+class t3(tests):
+    def test3_1(self):
+        # Defines an abbreviation as a string of capitalized letters
+        random_id = random.randint(1, NUM_DOCS)
+        print("Collecting abbreviations for document %d..." % random_id)
+        kit_abbs = self.ehrdb.get_abbreviations(random_id)
+
+        sents = self.ehrdb.get_document_sents(random_id)
+        test_abbs = set()
+        for sent in sents:
+            for word in ehrkit.word_tokenize(sent):
+                print(word)
+                pattern = r'[A-Z]{2}' # Only selects words in ALL CAPS
+                if re.match(pattern, word):
+                    test_abbs.add(word)
+
+        print(kit_abbs)
+
+        self.assertEqual(kit_abbs, list(test_abbs))
+
+    def test3_2(self):
+        query = "meningitis"
+        # print('Printing a list of all document ids including query like ', query)
+        kit_ids = self.ehrdb.get_documents_q(query)
+        # print(kit_ids) # Extremely long list of DOC_IDs
+
+        query = "%"+query+"%"
+        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where TEXT like \'%s\'" % query)
+        raw = self.ehrdb.cur.fetchall()
+        test_ids = ehrkit.flatten(raw)
+
+        self.assertEqual(kit_ids, test_ids)
+
+    def test3_3(self):
+        ### Task 3.3 is the same as task 3.2 with a different query. ###
+        query = "Service: SURGERY"
+        # print('Printing a list of all document ids including query like ', query)
+        kit_ids = self.ehrdb.get_documents_q(query)
+        # print(kit_ids) # Extremely long list of DOC_IDs
+
+        query = "%"+query+"%"
+        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where TEXT like \'%s\'" % query)
+        raw = self.ehrdb.cur.fetchall()
+        test_ids = ehrkit.flatten(raw)
+
+        self.assertEqual(kit_ids, test_ids)
+
+    def test3_4(self):
+        doc_id = random.randint(1, NUM_DOCS)
+        # print('Kit function printing a numbered list of all sentences in document %d' % doc_id)
+        # MIMIC EHRs are very messy and sentence tokenization often doesn't work
+        kit_doc = self.ehrdb.get_document_sents(doc_id)
+        # ehrkit.numbered_print(kit_doc)
+
+        self.ehrdb.cur.execute("select TEXT from mimic.NOTEEVENTS where ROW_ID = %d " % doc_id)
+        raw = self.ehrdb.cur.fetchall()
+        test_doc = ehrkit.sent_tokenize(raw[0][0])
+
+        self.assertEqual(kit_doc, test_doc)
+
+    def test3_7(self):
+        kit_meds = self.ehrdb.count_all_prescriptions()
+
+        test_meds = {}
+        self.ehrdb.cur.execute("select DRUG from mimic.PRESCRIPTIONS")
+        raw = self.ehrdb.cur.fetchall()
+        meds_list = ehrkit.flatten(raw)
+        for med in meds_list:
+            if med in test_meds:
+                test_meds[med] += 1
+            else:
+                test_meds[med] = 1
+
+        self.assertEqual(kit_meds, test_meds)
+
+
+class t4(tests):
+    def test4_1(self):
+        d = self.ehrdb.get_documents_icd9()
+        print(d)
+        self.assertIsNotNone(d['code'])
+
+    def test4_4(self):
+        pass
+
+
+class t5(tests):
+    def test5_1(self):
+        doc_id = random.randint(1, NUM_DOCS)
+        kit_phrases = self.ehrdb.extract_phrases(doc_id)
+
+        print("Testing task 5.1\n Check phrases manually: ", kit_phrases)
+
+        self.assertIsNotNone(kit_phrases)
+
+    def test5_4(self):
+        gender = random.choice(['M', 'F'])
+        kit_count = self.ehrdb.count_gender(gender)
+
+        self.ehrdb.cur.execute('SELECT COUNT(*) FROM mimic.PATIENTS WHERE GENDER = \'%s\'' % gender)
+        raw = self.ehrdb.cur.fetchall()
+        test_count = raw[0][0]
+        print('Gender:', gender, '\tCount:', str(test_count))
+
+        self.assertEqual(kit_count, test_count)
+
+
+# class t6(tests):
+#     def test6_1(self):
+#         import loader
+
+#         doc_id, text = select_ehr(self.ehrdb)
+
+#         # x = input('GloVe or RoBERTa predictor [g=GloVe, r=RoBERTa]? ')
+#         x = 'g'
+#         if x == 'g':
+#             glove_predictor = loader.load_glove()
+#             probs = glove_predictor.predict(text)['probs']
+#         elif x == 'r':
+#             roberta_predictor = loader.load_roberta()
+#             try:
+#                 probs = roberta_predictor.predict(text)['probs']
+#             except:
+#                 print('Document too long for RoBERTa model. Using GloVe instead.')
+#                 glove_predictor = loader.load_glove()
+#                 probs = glove_predictor.predict(text)['probs']
+#         else:
+#             sys.exit('Error: Must input \'g\' or \'r\'')
+
+#         classification = 'Positive' if probs[0] >= 0.5 else 'Negative'
+#         print("Document ID: ", doc_id, "\tPredicted Sentiment: ", classification)
+
+#     def test6_2(self):
+#         import loader
+
+#         doc_id, text = select_ehr(self.ehrdb)
+
+#         if os.path.exists("../allennlp/elmo-ner/whole_model.pt"):
+#             predictor = loader.load_ner()
+#         else:
+#             predictor = loader.download_ner()
+
+#         text = self.ehrdb.get_document(int(doc_id))
+#         pred = predictor.predict(text)
+#         # pred = predictor.predict("John likes and Bill hates ice cream")
+#         # print_results = input("Prediction complete. Print results? (y/n): ")
+#         print_results = 'y'
+#         if print_results == "y":
+#             print("Document ID: ", doc_id, " Results: ", pred['tags'])
+
+#     def test6_3(self):
+#         import torch
+#         from transformers import BertTokenizer #, BertModel, BertForMaskedLM
+
+#         doc_id, text = select_ehr(self.ehrdb)
+#         tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+#         bert_tokenized_text = tokenizer.tokenize(text)
+#         print('\n' + '-'*20 + 'text' + '-'*20)
+#         print(text)
+#         print('\n' + '-'*20 + 'Tokenized text from Huggingface BERT Tokenizer' + '-'*20)
+#         print(bert_tokenized_text)
+
+
+#         # library function
+#         ehr_bert_tokenized_text = self.ehrdb.get_bert_tokenize(doc_id)
+#         self.assertEqual(bert_tokenized_text, ehr_bert_tokenized_text)
+
+
+class t7(tests):
+    # Summarization algorithms
+    def test7_1(self):
+        from pubmed_naive_bayes import classify_nb
+        from get_pubmed_nb_data import build_vecs
+        from sklearn.naive_bayes import GaussianNB
+
+        doc_id, text = select_ehr(self.ehrdb)
+        # body_type = input('Use Naive Bayes model trained from whole body sections or just their body introductions?\n\t'\
+        #                   '[w=whole body, j=just intro, DEFAULT=just intro]: ')
+        body_type = 'j'
+        if body_type == 'w':
+            ending = 'body'
+        elif body_type in ['j', '']:
+            ending = 'intro'
+        else:
+            sys.exit('Error: Must input \'w\' or \'j\'.')
+        SUMM_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'summarization', 'pubmed_summarization'))
+        best_dir_name = get_nb_dir(ending, SUMM_DIR)
+        if not best_dir_name:
+            message = 'No Naive Bayes models of this type have been fit. '\
+                      'Would you like to do so now?\n\t[DEFAULT=Yes] '
+            # response = input(message)
+            response = ''
+            if response.lower() in ['y', 'yes', '']:
+                command = 'python ' + os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'summarization', 'pubmed_summarization', 'pubmed_naive_bayes.py'))
+                os.system(command)
+                best_dir_name = get_nb_dir(ending, SUMM_DIR)
+            if response.lower() not in ['y', 'yes', ''] or not best_dir_name:
+                sys.exit('Exiting.')
+
+        # Fits the model to the training data
+        NB_DIR = os.path.join(SUMM_DIR, best_dir_name, 'nb')
+        with open(os.path.join(NB_DIR, 'feature_vecs.json'), 'r') as f:
+            data = json.load(f)
+        xtrain, ytrain = data['train_features'], data['train_outputs']
+        gnb = GaussianNB()
+        gnb.fit(xtrain, ytrain)
+
+        # Evaluates the model on the EHR text
+        tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
+        feature_vecs, _ = build_vecs(text, None, tokenizer)
+        PCT_SUM = 0.3
+        preds = classify_nb(feature_vecs, PCT_SUM, gnb)
+        sents = tokenizer.tokenize(text)
+        summary = ''
+        for i in range(len(preds)):
+            if preds[i] == 1:
+                summary += sents[i]
+
+        show_summary(doc_id, text, summary, 'Naive Bayes')
+
+    def test7_2(self):
+        # Distilbart for summarization. Trained on CNN/Daily Mail (~4x longer summaries than XSum)
+        doc_id, text = select_ehr(self.ehrdb, requires_long=True)
+        model_name = 'sshleifer/distilbart-cnn-12-6'
+        summary = self.ehrdb.summarize_huggingface(text, model_name)
+
+        show_summary(doc_id, text, summary, model_name)
+        print('Number of Words in Full EHR: %d' % len(text.split()))
+        print('Number of Words in %s Summary: %d' % (model_name, len(summary.split())))
+
+    def test7_3(self):
+        # T5 for summarization. Trained on CNN/Daily Mail
+        doc_id, text = select_ehr(self.ehrdb, requires_long=True)
+        model_name = 't5-small'
+        summary = self.ehrdb.summarize_huggingface(text, model_name)
+
+        show_summary(doc_id, text, summary, model_name)
+        print('Number of Words in Full EHR: %d' % len(text.split()))
+        print('Number of Words in %s Summary: %d' % (model_name, len(summary.split())))
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
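
Note: the suite imports database credentials from a config.py on sys.path (see the try/except at the top of the file). A minimal sketch of that file, assuming only that it defines the two names the import expects; the values shown are placeholders, not real credentials:

    # config.py -- example only; substitute your own MIMIC database credentials
    USERNAME = 'your_db_username'
    PASSWORD = 'your_db_password'

With that in place, the suite can be run directly as `python tests/all_tests.py` (the `unittest.main()` call at the bottom handles discovery), and a single group can be selected by passing its class name on the command line, e.g. `python tests/all_tests.py t1`.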