"""Unit tests for the ehrkit MIMIC-III interface.

Tests are grouped by task family (t1-t7).  Most tests cross-check an
ehrkit library call against a raw SQL query on the same MIMIC tables.
Several tests are skipped unless named explicitly on the command line,
because of runtime or verbosity.
"""
import unittest
import random
import sys, os
import re
import nltk
import json

# Make the project root and the model/summarization packages importable
# when the tests are run from inside tests/.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'allennlp')))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'summarization', 'pubmed_summarization')))
# print(sys.path)
from ehrkit import ehrkit
from getpass import getpass

try:
    from config import USERNAME, PASSWORD
except ImportError:
    # No config.py available -- fall back to interactive credentials.
    print("Please put your username and password in config.py")
    USERNAME = input('DB_username?')
    PASSWORD = getpass('DB_password?')


DOC_ID = 1354526  # Temporary!!!


# Number of documents in NOTEEVENTS.
NUM_DOCS = 2083180

# Number of patients in PATIENTS.
NUM_PATIENTS = 46520

# Number of diagnoses in DIAGNOSES_ICD.
NUM_DIAGS = 823933


def select_ehr(ehrdb, requires_long=False, recursing=False):
    """Pick a document from mimic.NOTEEVENTS and return ``(doc_id, text)``.

    If ``requires_long`` is True, recurse until a note longer than 200
    words is drawn.  ``recursing`` is an internal flag marking retry
    calls (the interactive prompt is disabled either way).  Exits the
    process if an explicitly requested document ID does not exist.
    """
    if recursing:
        doc_id = ''
    else:
        # doc_id = input("MIMIC Document ID [press Enter for random]: ")
        doc_id = ''

    if doc_id == '':
        # Picks random document
        ehrdb.cur.execute("SELECT ROW_ID FROM mimic.NOTEEVENTS ORDER BY RAND() LIMIT 1")
        doc_id = ehrdb.cur.fetchall()[0][0]
        text = ehrdb.get_document(int(doc_id))
        if len(text.split()) > 200 or not requires_long:
            return doc_id, text
        else:
            return select_ehr(ehrdb, requires_long, True)
    else:
        # Get inputted document
        try:
            text = ehrdb.get_document(int(doc_id))
            return doc_id, text
        except Exception:
            # Covers both a non-numeric ID and a missing row.
            message = 'Error: There is no document with ID \'' + doc_id + '\' in mimic.NOTEEVENTS'
            sys.exit(message)


def get_nb_dir(ending, SUMM_DIR):
    """Return the Naive Bayes model directory trained on the most examples.

    Model directories under ``SUMM_DIR`` are named ``<N>_exs_<ending>``
    and must contain a fitted ``nb`` subdirectory.  Returns the best
    directory name, or None if no fitted model exists.
    """
    dir_nums = []
    for d in os.listdir(SUMM_DIR):
        if os.path.isdir(os.path.join(SUMM_DIR, d)) and d.endswith('_exs_' + ending):
            if os.path.exists(os.path.join(SUMM_DIR, d, 'nb')):
                try:
                    dir_nums.append(int(d.split('_')[0]))
                except ValueError:
                    # Directory name does not start with an example count.
                    continue
    if len(dir_nums) > 0:
        best_dir_name = str(max(dir_nums)) + '_exs_' + ending
        return best_dir_name
    else:
        return None


def show_summary(doc_id, text, summary, model_name):
    """Print the full EHR text followed by the model's predicted summary."""
    # x = input('Show full EHR (DOC ID %s)? [DEFAULT=Yes]' % doc_id)
    x = ''
    if x.lower() in ['y', 'yes', '']:
        print('\n\n' + '-'*30 + 'Full EHR' + '-'*30)
        print(text + '\n')
        print('-'*80 + '\n\n')

    print('-'*30 + 'Predicted Summary ' + model_name + '-'*30)
    print(summary)
    print('-'*80 + '\n\n')


class tests(unittest.TestCase):
    """Base class: opens a DB session and queues 3 patients before each test."""
    def setUp(self):
        self.ehrdb = ehrkit.start_session(USERNAME, PASSWORD)
        self.ehrdb.get_patients(3)


''' Runs tests 1.1-1.4 '''
class t1(tests):
    def test1_1_count_patients(self):
        # ehrkit count must agree with a direct COUNT(*) on PATIENTS.
        kit_count = self.ehrdb.count_patients()
        print("Patient count: ", kit_count)

        self.ehrdb.cur.execute("SELECT COUNT(*) FROM mimic.PATIENTS")
        raw = self.ehrdb.cur.fetchall()
        test_count = int(raw[0][0])

        self.assertEqual(test_count, kit_count)

    # def test1_2_count_docs(self):
    #     # Fails! count_docs returns 1573339, but mimic.NOTEEVENTS has 2083180 documents.
    #     # TO DO: Fix whatever is wrong here
    #     kit_count = self.ehrdb.count_docs(['NOTEEVENTS'])
    #     print("Document count: ", kit_count)

    #     self.ehrdb.cur.execute("SELECT COUNT(*) FROM mimic.NOTEEVENTS")
    #     raw = self.ehrdb.cur.fetchall()
    #     test_count = int(raw[0][0])

    #     self.assertEqual(test_count, kit_count)

    def test1_3_note_info(self):
        self.ehrdb.get_note_events()
        print('output format: SUBJECT_ID, ROW_ID, NoteEvent length')
        lens = [(patient.id, note[0], len(note[1])) for patient in self.ehrdb.patients.values() for note in patient.note_events]
        print(lens)

        # placeholder, this output cannot be checked easily
        self.assertEqual(1, 1)

    def test1_4_longest_note(self):
        # Gets longest note among the patient notes queued by get_note_events()
        self.ehrdb.get_note_events()
        pid, rowid, doclen = self.ehrdb.longest_NE()
        print('patient id is:', pid, 'NoteEvent id is:', rowid, 'length: ', doclen)

        # placeholder, this output cannot be checked easily
        self.assertEqual(1, 1)


class t2(tests):
    def test2_1_print_note(self):
        ### There are 2083180 patient records in NOTEEVENTS. ###
        # randint is inclusive on both ends, so the upper bound is NUM_DOCS
        # itself -- NUM_DOCS + 1 would occasionally pick a nonexistent row.
        record_id = random.randint(1, NUM_DOCS)
        kit_rec = self.ehrdb.get_document(record_id)
        print("Document with ID %d\n: " % record_id, kit_rec)

        self.ehrdb.cur.execute("select TEXT from mimic.NOTEEVENTS where ROW_ID = %d" % record_id)
        test_rec = self.ehrdb.cur.fetchall()[0][0]

        self.assertEqual(kit_rec, test_rec)

    def test2_2_patient_info(self):
        ### There are records from 46520 unique patients in MIMIC. ###
        # Inclusive upper bound: NUM_PATIENTS, not NUM_PATIENTS + 1.
        patient_id = random.randint(1, NUM_PATIENTS)
        kit_ids = self.ehrdb.get_all_patient_document_ids(patient_id)
        print('Document IDs related to Patient %d: ' % patient_id, kit_ids)
        print("Number of docs related to Patient %d: " % patient_id, len(kit_ids))

        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where SUBJECT_ID = %d" % patient_id)
        raw = self.ehrdb.cur.fetchall()
        test_ids = ehrkit.flatten(raw)

        self.assertEqual(kit_ids, test_ids)

    #@unittest.skipIf("t2.test2_3" not in sys.argv, "Test 2_3 must be run explicitly due to runtime.")
    def test2_3_doc_ids(self):
        kit_ids = self.ehrdb.list_all_document_ids()

        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS")
        raw = self.ehrdb.cur.fetchall()
        test_ids = ehrkit.flatten(raw)
        print('test_ids')
        print(test_ids[:30])
        print('...')

        self.assertEqual(kit_ids, test_ids)

    def test2_4_patient_ids(self):
        kit_ids = self.ehrdb.list_all_patient_ids()

        self.ehrdb.cur.execute("select SUBJECT_ID from mimic.PATIENTS")
        raw = self.ehrdb.cur.fetchall()
        test_ids = ehrkit.flatten(raw)
        print(test_ids[:30])
        print('...')

        self.assertEqual(kit_ids, test_ids)

    #@unittest.skipIf("t2.test2_5" not in sys.argv, "Test 2_5 must be run explicitly due to runtime.")
    def test2_5_docs_on_date(self):
        ### Select random date from a date in the database.
        ### Dates are shifted to future but preserve time, weekday, and seasonality.
        random_id = random.randint(1, NUM_DOCS)
        self.ehrdb.cur.execute("select CHARTDATE from mimic.NOTEEVENTS where ROW_ID = %d" % random_id)
        date = self.ehrdb.cur.fetchall()[0][0]

        kit_ids = self.ehrdb.get_documents_d(date)

        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where CHARTDATE = \"%s\"" % date)
        raw = self.ehrdb.cur.fetchall()
        test_ids = ehrkit.flatten(raw)
        print(f"Selected date: {date}")
        print(f"Test ids {test_ids[:30]} ...")

        self.assertEqual(kit_ids, test_ids)


class t3(tests):
    def test3_1_extract_abbreviations(self):
        # Defines abbreviation as a string of capitalized letters
        random_id = random.randint(1, NUM_DOCS)
        print("Collecting abbreviations for document %d..." % random_id)
        kit_abbs = self.ehrdb.get_abbreviations(random_id)

        sents = self.ehrdb.get_document_sents(random_id)
        test_abbs = set()
        for sent in sents:
            for word in ehrkit.word_tokenize(sent):
                print(word)
                pattern = r'[A-Z]{2}'  # Only selects words in ALL CAPS
                if re.match(pattern, word):
                    test_abbs.add(word)

        print(kit_abbs)

        self.assertEqual(kit_abbs, list(test_abbs))

    #@unittest.skipIf("t3.test3_2" not in sys.argv, "Test 3_2 must be run explicitly due to runtime.")
    def test3_2_docs_with_query(self):
        query = "meningitis"
        print('Printing a list of all document ids including query like ', query)
        kit_ids = self.ehrdb.get_documents_q(query)
        print(kit_ids[:30])  # Extremely long list of DOC_IDs
        print("...")

        query = "%"+query+"%"
        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where TEXT like \'%s\'" % query)
        raw = self.ehrdb.cur.fetchall()
        test_ids = ehrkit.flatten(raw)

        self.assertEqual(kit_ids, test_ids)

    #@unittest.skipIf("t3.test3_3" not in sys.argv, "Test 3_3 must be run explicitly due to runtime. Also, this is essentially a duplicate of task 3.2.")
    def test3_3_query_docs(self):
        ### Task 3.3 is the same as task 3.2 with a different query. ###
        query = "Service: SURGERY"
        print('Printing a list of all document ids including query like ', query)
        kit_ids = self.ehrdb.get_documents_q(query)
        print(kit_ids[:30])  # Extremely long list of DOC_IDs
        print("...")

        query = "%"+query+"%"
        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where TEXT like \'%s\'" % query)
        raw = self.ehrdb.cur.fetchall()
        test_ids = ehrkit.flatten(raw)

        self.assertEqual(kit_ids, test_ids)

    def test3_4_doc_sentences(self):
        doc_id = random.randint(1, NUM_DOCS)
        print('Kit function printing a numbered list of all sentences in document %d' % doc_id)
        # MIMIC EHRs are very messy and sentence tokenizaton often doesn't work
        kit_doc = self.ehrdb.get_document_sents(doc_id)
        ehrkit.numbered_print(kit_doc)

        self.ehrdb.cur.execute("select TEXT from mimic.NOTEEVENTS where ROW_ID = %d " % doc_id)
        raw = self.ehrdb.cur.fetchall()
        test_doc = ehrkit.sent_tokenize(raw[0][0])
        print(test_doc)

        self.assertEqual(kit_doc, test_doc)

    #@unittest.skipIf("t3.test3_7" not in sys.argv, "Test 3_7 must be run explicitly due to runtime.")
    def test3_7_medications(self):
        kit_meds = self.ehrdb.count_all_prescriptions()

        test_meds = {}
        self.ehrdb.cur.execute("select DRUG from mimic.PRESCRIPTIONS")
        raw = self.ehrdb.cur.fetchall()
        meds_list = ehrkit.flatten(raw)
        for med in meds_list:
            if med in test_meds:
                test_meds[med] += 1
            else:
                test_meds[med] = 1

        print(meds_list[:30])
        print("...")

        self.assertEqual(kit_meds, test_meds)


class t4(tests):
    @unittest.skip("Task 4.1 is not ready to be tested yet.")
    def test4_1(self):
        d = self.ehrdb.get_documents_icd9()
        print(d)
        self.assertIsNotNone(d['code'])

    @unittest.skip("Task 4.4 is not ready to be tested yet.")
    def test4_4(self):
        pass


class t5(tests):
    @unittest.skipIf("t5.test5_1" not in sys.argv, "Test 5_1 must be run explicitly due to runtime.")
    def test5_1_extract_phrases(self):
        doc_id = random.randint(1, NUM_DOCS)
        kit_phrases = self.ehrdb.extract_phrases(doc_id)

        print("Testing task 5.1\n Check phrases manually: ", kit_phrases)

        self.assertIsNotNone(kit_phrases)

    def test5_4_count_gender(self):
        gender = random.choice(['M', 'F'])
        kit_count = self.ehrdb.count_gender(gender)

        self.ehrdb.cur.execute('SELECT COUNT(*) FROM mimic.PATIENTS WHERE GENDER = \'%s\'' % gender)
        raw = self.ehrdb.cur.fetchall()
        test_count = raw[0][0]
        print('Gender:', gender, '\tCount:', str(test_count))

        self.assertEqual(kit_count, test_count)


class t6(tests):
    @unittest.skipIf("t6.test6_1_sentiment_classification" not in sys.argv, "Test 6_1 must be run explicitly due to verbosity.")
    def test6_1_sentiment_classification(self):
        import loader

        doc_id, text = select_ehr(self.ehrdb)

        x = input('GloVe or RoBERTa predictor [g=GloVe, r=RoBERTa]? ')
        if x == 'g':
            glove_predictor = loader.load_glove()
            probs = glove_predictor.predict(text)['probs']
        elif x == 'r':
            roberta_predictor = loader.load_roberta()
            try:
                probs = roberta_predictor.predict(text)['probs']
            except Exception:
                # RoBERTa has a fixed maximum input length; fall back to GloVe.
                print('Document too long for RoBERTa model. Using GLoVe instead.')
                glove_predictor = loader.load_glove()
                probs = glove_predictor.predict(text)['probs']
        else:
            sys.exit('Error: Must input \'g\' or \'r\'')

        classification = 'Positive' if probs[0] >= 0.5 else 'Negative'
        print("Document ID: ", doc_id, "\tPredicted Sentiment: ", classification)

    @unittest.skipIf("t6.test6_2_ner" not in sys.argv, "Test 6_2 must be run explicitly due to verbosity.")
    def test6_2_ner(self):
        import loader

        doc_id, text = select_ehr(self.ehrdb)

        # Use the locally saved model if present; otherwise download it.
        if os.path.exists("../allennlp/elmo-ner/whole_model.pt"):
            predictor = loader.load_ner()
        else:
            predictor = loader.download_ner()

        text = self.ehrdb.get_document(int(doc_id))
        pred = predictor.predict(text)
        # pred = predictor.predict("John likes and Bill hates ice cream")
        print_results = input("Prediction complete. Print results? (y/n): ")
        if print_results == "y":
            print("Document ID: ", doc_id, " Results: ", pred['tags'])

    @unittest.skipIf("t6.test6_3_tokenize" not in sys.argv, "Test 6_3 must be run explicitly due to runtime.")
    def test6_3_tokenize(self):
        import torch
        from transformers import BertTokenizer  # , BertModel, BertForMaskedLM

        doc_id, text = select_ehr(self.ehrdb)
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        bert_tokenized_text = tokenizer.tokenize(text)
        print('\n' + '-'*20 + 'text' + '-'*20)
        print(text)
        print('\n' + '-'*20 + 'Tokenized text from Huggingface BERT Tokenizer' + '-'*20)
        print(bert_tokenized_text)

        # library function
        ehr_bert_tokenized_text = self.ehrdb.get_bert_tokenize(doc_id)
        self.assertEqual(bert_tokenized_text, ehr_bert_tokenized_text)


class t7(tests):
    # Summarization algorithms
    #@unittest.skipIf("t7.test7_1_naive_bayes" not in sys.argv, "Test 7_1 must be run explicitly due to verbosity.")
    def test7_1_naive_bayes(self):
        from pubmed_naive_bayes import classify_nb
        from get_pubmed_nb_data import build_vecs
        from sklearn.naive_bayes import GaussianNB

        doc_id, text = select_ehr(self.ehrdb)
        # body_type = input('Use Naive Bayes model trained from whole body sections or just their body introductions?\n\t'\
        #                   '[w=whole body, j=just intro, DEFAULT=just intro]: ')
        body_type = 'j'

        if body_type == 'w':
            ending = 'body'
        elif body_type in ['j', '']:
            ending = 'intro'
        else:
            sys.exit('Error: Must input \'w\' or \'j.\'')
        SUMM_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'summarization', 'pubmed_summarization'))
        best_dir_name = get_nb_dir(ending, SUMM_DIR)
        if not best_dir_name:
            message = 'No Naive Bayes models of this type have been fit. '\
                      'Would you like to do so now?\n\t[DEFAULT=Yes] '
            # response = input(message)
            response = 'y'

            if response.lower() in ['y', 'yes', '']:
                command = 'python ' + os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'summarization', 'pubmed_summarization', 'pubmed_naive_bayes.py'))
                os.system(command)
                # BUG FIX: get_nb_dir requires SUMM_DIR; the original call
                # passed only `ending` and raised a TypeError on this path.
                best_dir_name = get_nb_dir(ending, SUMM_DIR)
            if response.lower() not in ['y', 'yes', ''] or not best_dir_name:
                sys.exit('Exiting.')

        # Fits model to data
        NB_DIR = os.path.join(SUMM_DIR, best_dir_name, 'nb')
        with open(os.path.join(NB_DIR, 'feature_vecs.json'), 'r') as f:
            data = json.load(f)
        xtrain, ytrain = data['train_features'], data['train_outputs']
        gnb = GaussianNB()
        gnb.fit(xtrain, ytrain)

        # Evaluates on model
        tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
        feature_vecs, _ = build_vecs(text, None, tokenizer)
        PCT_SUM = 0.3
        preds = classify_nb(feature_vecs, PCT_SUM, gnb)
        sents = tokenizer.tokenize(text)
        summary = ''
        for i in range(len(preds)):
            if preds[i] == 1:
                summary += sents[i]

        show_summary(doc_id, text, summary, 'Naive Bayes')

    #@unittest.skipIf("t7.test7_2_distilbart_summary" not in sys.argv, "Test 7_2 must be run explicitly due to runtime.")
    def test7_2_distilbart_summary(self):
        # Distilbart for summarization. Trained on CNN/ Daily Mail (~4x longer summaries than XSum)
        doc_id, text = select_ehr(self.ehrdb, requires_long=True)
        model_name = 'sshleifer/distilbart-cnn-12-6'
        summary = self.ehrdb.summarize_huggingface(text, model_name)

        show_summary(doc_id, text, summary, model_name)
        print('Number of Words in Full EHR: %d' % len(text.split()))
        print('Number of Words in %s Summary: %d' % (model_name, len(summary.split())))

    #@unittest.skipIf("t7.test7_3_t5_summary" not in sys.argv, "Test 7_3 must be run explicitly due to runtime.")
    def test7_3_t5_summary(self):
        # T5 for summarization. Trained on CNN/ Daily Mail
        doc_id, text = select_ehr(self.ehrdb, requires_long=True)
        model_name = 't5-small'
        summary = self.ehrdb.summarize_huggingface(text, model_name)

        show_summary(doc_id, text, summary, model_name)
        print('Number of Words in Full EHR: %d' % len(text.split()))
        print('Number of Words in %s Summary: %d' % (model_name, len(summary.split())))


def testing():
    unittest.main()


if __name__ == '__main__':
    unittest.main()