Diff of /tests/tests.py [000000] .. [2d4573]

Switch to side-by-side view

--- a
+++ b/tests/tests.py
@@ -0,0 +1,472 @@
+import unittest
+import random
+import sys, os
+import re
+import nltk
+import json
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'allennlp')))
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'summarization', 'pubmed_summarization')))
+# print(sys.path)
+from ehrkit import ehrkit
+from getpass import getpass
+
+try: 
+    from config import USERNAME, PASSWORD
+except:
+    print("Please put your username and password in config.py")
+    USERNAME = input('DB_username?')
+    PASSWORD = getpass('DB_password?')
+
+
+DOC_ID = 1354526 # Temporary!!!
+
+
+# Number of documents in NOTEEVENTS.
+NUM_DOCS = 2083180
+
+# Number of patients in PATIENTS.
+NUM_PATIENTS = 46520
+
+# Number of diagnoses in DIAGNOSES_ICD.
+NUM_DIAGS = 823933
+
+
+def select_ehr(ehrdb, requires_long=False, recursing=False):
+    if recursing:
+        doc_id = ''
+    else:
+        #doc_id = input("MIMIC Document ID [press Enter for random]: ")
+        
+        doc_id = ''
+    
+    if doc_id == '':
+        # Picks random document
+        ehrdb.cur.execute("SELECT ROW_ID FROM mimic.NOTEEVENTS ORDER BY RAND() LIMIT 1")
+        doc_id = ehrdb.cur.fetchall()[0][0]
+        text = ehrdb.get_document(int(doc_id))
+        if len(text.split()) > 200 or not requires_long:
+            return doc_id, text
+        else:
+            return select_ehr(ehrdb, requires_long, True)
+    else:
+        # Get inputted document
+        try:
+            text = ehrdb.get_document(int(doc_id))
+            return doc_id, text
+        except:
+            message = 'Error: There is no document with ID \'' + doc_id + '\' in mimic.NOTEEVENTS'
+            sys.exit(message)
+
+
+def get_nb_dir(ending, SUMM_DIR):
+    # Gets path of Naive Bayes model trained on most examples
+    dir_nums = []
+    for dir in os.listdir(SUMM_DIR):
+        if os.path.isdir(os.path.join(SUMM_DIR, dir)) and dir.endswith('_exs_' + ending):
+            if os.path.exists(os.path.join(SUMM_DIR, dir, 'nb')):  
+                try:
+                    dir_nums.append(int(dir.split('_')[0]))
+                except:
+                    continue
+    if len(dir_nums) > 0:
+        best_dir_name = str(max(dir_nums)) + '_exs_' + ending
+        return best_dir_name
+    else:
+        return None
+
+def show_summary(doc_id, text, summary, model_name):
+    #x = input('Show full EHR (DOC ID %s)? [DEFAULT=Yes]' % doc_id)
+    x = ''
+    if x.lower() in ['y', 'yes', '']:
+        print('\n\n' + '-'*30 + 'Full EHR' + '-'*30)
+        print(text + '\n')
+        print('-'*80 + '\n\n')
+
+    print('-'*30 + 'Predicted Summary ' + model_name + '-'*30)
+    print(summary)
+    print('-'*80 + '\n\n')
+
+
+class tests(unittest.TestCase):
+    def setUp(self):
+        self.ehrdb = ehrkit.start_session(USERNAME, PASSWORD)
+        self.ehrdb.get_patients(3)
+
+
+''' Runs tests 1.1-1.4 '''
+class t1(tests):
+    def test1_1_count_patients(self):
+        kit_count = self.ehrdb.count_patients()
+        print("Patient count: ", kit_count)
+
+        self.ehrdb.cur.execute("SELECT COUNT(*) FROM mimic.PATIENTS")
+        raw = self.ehrdb.cur.fetchall()
+        test_count = int(raw[0][0])
+
+        self.assertEqual(test_count, kit_count)
+
+    # def test1_2_count_docs(self):
+    #     # Fails! count_docs returns 1573339, but mimic.NOTEEVENTS has 2083180 documents. 
+    #     # TO DO: Fix whatever is wrong here
+    #     kit_count = self.ehrdb.count_docs(['NOTEEVENTS'])
+    #     print("Document count: ", kit_count)
+
+    #     self.ehrdb.cur.execute("SELECT COUNT(*) FROM mimic.NOTEEVENTS")
+    #     raw = self.ehrdb.cur.fetchall()
+    #     test_count = int(raw[0][0])
+
+    #     self.assertEqual(test_count, kit_count)
+
+    def test1_3_note_info(self):
+        self.ehrdb.get_note_events()
+        print('output format: SUBJECT_ID, ROW_ID, NoteEvent length')
+        lens = [(patient.id, note[0], len(note[1])) for patient in self.ehrdb.patients.values() for note in patient.note_events]
+        print(lens)
+
+        # placeholder, this output cannot be checked easily
+        self.assertEqual(1, 1)
+
+    def test1_4_longest_note(self):
+        # Gets longest note among the patient notes queued by get_note_events()
+        self.ehrdb.get_note_events()
+        pid, rowid, doclen = self.ehrdb.longest_NE()
+        print('patient id is:', pid, 'NoteEvent id is:', rowid, 'length: ', doclen)
+
+        # placeholder, this output cannot be checked easily
+        self.assertEqual(1, 1)
+
+
+class t2(tests):
+    def test2_1_print_note(self):
+        ### There are 2083180 patient records in NOTEEVENTS. ###
+        record_id = random.randint(1, NUM_DOCS + 1)
+        kit_rec = self.ehrdb.get_document(record_id)
+        print("Document with ID %d\n: " % record_id, kit_rec)
+
+        self.ehrdb.cur.execute("select TEXT from mimic.NOTEEVENTS where ROW_ID = %d" % record_id)
+        test_rec = self.ehrdb.cur.fetchall()[0][0]
+
+        self.assertEqual(kit_rec, test_rec)
+
+    def test2_2_patient_info(self):
+        ### There are records from 46520 unique patients in MIMIC. ###
+        patient_id = random.randint(1, NUM_PATIENTS + 1)
+        kit_ids = self.ehrdb.get_all_patient_document_ids(patient_id)
+        print('Document IDs related to Patient %d: ' % patient_id, kit_ids)
+        print("Number of docs related to Patient %d: " % patient_id, len(kit_ids))
+
+        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where SUBJECT_ID = %d" % patient_id)
+        raw = self.ehrdb.cur.fetchall()
+        test_ids = ehrkit.flatten(raw)
+
+        self.assertEqual(kit_ids, test_ids)
+
+    #@unittest.skipIf("t2.test2_3" not in sys.argv, "Test 2_3 must be run explicitly due to runtime.")
+    def test2_3_doc_ids(self):
+        kit_ids = self.ehrdb.list_all_document_ids()
+
+        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS")
+        raw = self.ehrdb.cur.fetchall()
+        test_ids = ehrkit.flatten(raw)
+        print('test_ids')
+        print(test_ids[:30])
+        print('...')
+
+        self.assertEqual(kit_ids, test_ids)
+
+    def test2_4_patient_ids(self):
+        kit_ids = self.ehrdb.list_all_patient_ids()
+
+        self.ehrdb.cur.execute("select SUBJECT_ID from mimic.PATIENTS")
+        raw = self.ehrdb.cur.fetchall()
+        test_ids = ehrkit.flatten(raw)
+        print(test_ids[:30])
+        print('...')
+
+        self.assertEqual(kit_ids, test_ids)
+
+    #@unittest.skipIf("t2.test2_5" not in sys.argv, "Test 2_5 must be run explicitly due to runtime.")
+    def test2_5_docs_on_date(self):
+        ### Select random date from a date in the database. 
+        ### Dates are shifted to future but preserve time, weekday, and seasonality.
+        random_id = random.randint(1, NUM_DOCS + 1)
+        self.ehrdb.cur.execute("select CHARTDATE from mimic.NOTEEVENTS where ROW_ID = %d" % random_id)
+        date = self.ehrdb.cur.fetchall()[0][0]
+
+        kit_ids = self.ehrdb.get_documents_d(date)
+
+        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where CHARTDATE = \"%s\"" % date)
+        raw = self.ehrdb.cur.fetchall()
+        test_ids = ehrkit.flatten(raw)
+        print(f"Selected date: {date}")
+        print(f"Test ids {test_ids[:30]} ...")
+
+        self.assertEqual(kit_ids, test_ids)
+
+
+class t3(tests):
+    def test3_1_extract_abbreviations(self):
+        # Defines abbreviation as a string of capitalized letters
+        random_id = random.randint(1, NUM_DOCS + 1)
+        print("Collecting abbreviations for document %d..." % random_id)
+        kit_abbs = self.ehrdb.get_abbreviations(random_id)
+
+        sents = self.ehrdb.get_document_sents(random_id)
+        test_abbs = set()
+        for sent in sents:
+            for word in ehrkit.word_tokenize(sent):
+                print(word)
+                pattern = r'[A-Z]{2}'  # Only selects words in ALL CAPS
+                if re.match(pattern, word):
+                    test_abbs.add(word)
+
+        print(kit_abbs)
+
+        self.assertEqual(kit_abbs, list(test_abbs))
+
+    #@unittest.skipIf("t3.test3_2" not in sys.argv, "Test 3_2 must be run explicitly due to runtime.")
+    def test3_2_docs_with_query(self):
+        query = "meningitis"
+        print('Printing a list of all document ids including query like ', query)
+        kit_ids = self.ehrdb.get_documents_q(query)
+        print(kit_ids[:30])  # Extremely long list of DOC_IDs
+        print("...")
+
+        query = "%"+query+"%"
+        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where TEXT like \'%s\'" % query)
+        raw = self.ehrdb.cur.fetchall()
+        test_ids = ehrkit.flatten(raw)
+
+        self.assertEqual(kit_ids, test_ids)
+
+    #@unittest.skipIf("t3.test3_3" not in sys.argv, "Test 3_3 must be run explicitly due to runtime. Also, this is essentially a duplicate of task 3.2.")
+    def test3_3_query_docs(self):
+        ### Task 3.3 is the same as task 3.2 with a different query. ###
+        query = "Service: SURGERY"
+        print('Printing a list of all document ids including query like ', query)
+        kit_ids = self.ehrdb.get_documents_q(query)
+        print(kit_ids[:30])  # Extremely long list of DOC_IDs
+        print("...")
+
+        query = "%"+query+"%"
+        self.ehrdb.cur.execute("select ROW_ID from mimic.NOTEEVENTS where TEXT like \'%s\'" % query)
+        raw = self.ehrdb.cur.fetchall()
+        test_ids = ehrkit.flatten(raw)
+
+        self.assertEqual(kit_ids, test_ids)
+
+    def test3_4_doc_sentences(self):
+        doc_id = random.randint(1, NUM_DOCS + 1)
+        print('Kit function printing a numbered list of all sentences in document %d' % doc_id)
+        # MIMIC EHRs are very messy and sentence tokenizaton often doesn't work
+        kit_doc = self.ehrdb.get_document_sents(doc_id)
+        ehrkit.numbered_print(kit_doc)
+
+        self.ehrdb.cur.execute("select TEXT from mimic.NOTEEVENTS where ROW_ID = %d " % doc_id)
+        raw = self.ehrdb.cur.fetchall()
+        test_doc = ehrkit.sent_tokenize(raw[0][0])
+        print(test_doc)
+
+        self.assertEqual(kit_doc, test_doc)
+
+    #@unittest.skipIf("t3.test3_7" not in sys.argv, "Test 3_7 must be run explicitly due to runtime.")
+    def test3_7_medications(self):
+        kit_meds = self.ehrdb.count_all_prescriptions()
+
+        test_meds = {}
+        self.ehrdb.cur.execute("select DRUG from mimic.PRESCRIPTIONS")
+        raw = self.ehrdb.cur.fetchall()
+        meds_list = ehrkit.flatten(raw)
+        for med in meds_list:
+            if med in test_meds:
+                test_meds[med] += 1
+            else:
+                test_meds[med] = 1
+
+        print(meds_list[:30])
+        print("...")
+
+        self.assertEqual(kit_meds, test_meds)
+
+
+class t4(tests):
+    @unittest.skip("Task 4.1 is not ready to be tested yet.")
+    def test4_1(self):
+        d = self.ehrdb.get_documents_icd9()
+        print(d)
+        self.assertIsNotNone(d['code'])
+
+    @unittest.skip("Task 4.4 is not ready to be tested yet.")
+    def test4_4(self):
+        pass
+
+
+class t5(tests):
+    @unittest.skipIf("t5.test5_1" not in sys.argv, "Test 5_1 must be run explicitly due to runtime.")
+    def test5_1_extract_phrases(self):
+        doc_id = random.randint(1, NUM_DOCS + 1)
+        kit_phrases = self.ehrdb.extract_phrases(doc_id)
+
+        print("Testing task 5.1\n Check phrases manually: ", kit_phrases)
+
+        self.assertIsNotNone(kit_phrases)
+
+    def test5_4_count_gender(self):
+        gender = random.choice(['M', 'F'])
+        kit_count = self.ehrdb.count_gender(gender)
+
+        self.ehrdb.cur.execute('SELECT COUNT(*) FROM mimic.PATIENTS WHERE GENDER = \'%s\'' % gender)
+        raw = self.ehrdb.cur.fetchall()
+        test_count = raw[0][0]
+        print('Gender:', gender, '\tCount:', str(test_count))
+
+        self.assertEqual(kit_count, test_count)
+
+
+class t6(tests):
+    @unittest.skipIf("t6.test6_1_sentiment_classification" not in sys.argv, "Test 6_1 must be run explicitly due to verbosity.")
+    def test6_1_sentiment_classification(self):
+        import loader
+
+        doc_id, text = select_ehr(self.ehrdb)
+
+        x = input('GloVe or RoBERTa predictor [g=GloVe, r=RoBERTa]? ')
+        if x == 'g':
+            glove_predictor = loader.load_glove()
+            probs = glove_predictor.predict(text)['probs']
+        elif x == 'r':
+            roberta_predictor = loader.load_roberta()
+            try:
+                probs = roberta_predictor.predict(text)['probs']
+            except:
+                print('Document too long for RoBERTa model. Using GLoVe instead.')
+                glove_predictor = loader.load_glove()
+                probs = glove_predictor.predict(text)['probs']
+        else:
+            sys.exit('Error: Must input \'g\' or  \'r\'')
+
+        classification = 'Positive' if probs[0] >= 0.5 else 'Negative'
+        print("Document ID: ", doc_id, "\tPredicted Sentiment: ", classification)
+
+    @unittest.skipIf("t6.test6_2_ner" not in sys.argv, "Test 6_2 must be run explicitly due to verbosity.")
+    def test6_2_ner(self):
+        import loader
+
+        doc_id, text = select_ehr(self.ehrdb)
+
+        if os.path.exists("../allennlp/elmo-ner/whole_model.pt"):
+            predictor = loader.load_ner()
+        else:
+            predictor = loader.download_ner()
+
+        text = self.ehrdb.get_document(int(doc_id))
+        pred = predictor.predict(text)
+        # pred = predictor.predict("John likes and Bill hates ice cream")
+        print_results = input("Prediction complete. Print results? (y/n): ")
+        if print_results == "y":
+            print("Document ID: ", doc_id, "  Results: ", pred['tags'])
+    
+    @unittest.skipIf("t6.test6_3_tokenize" not in sys.argv, "Test 6_3 must be run explicitly due to runtime.")
+    def test6_3_tokenize(self):
+        import torch
+        from transformers import BertTokenizer#, BertModel, BertForMaskedLM
+
+        doc_id, text = select_ehr(self.ehrdb)
+        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+        bert_tokenized_text = tokenizer.tokenize(text)
+        print('\n' + '-'*20 + 'text' + '-'*20)
+        print(text)
+        print('\n' + '-'*20 + 'Tokenized text from Huggingface BERT Tokenizer' + '-'*20)    
+        print(bert_tokenized_text)
+
+
+        # library function
+        ehr_bert_tokenized_text = self.ehrdb.get_bert_tokenize(doc_id)
+        self.assertEqual(bert_tokenized_text, ehr_bert_tokenized_text)
+
+
+class t7(tests):
+    # Summarization algorithms
+    #@unittest.skipIf("t7.test7_1_naive_bayes" not in sys.argv, "Test 7_1 must be run explicitly due to verbosity.")
+    def test7_1_naive_bayes(self):
+        from pubmed_naive_bayes import classify_nb
+        from get_pubmed_nb_data import build_vecs
+        from sklearn.naive_bayes import GaussianNB
+
+        doc_id, text = select_ehr(self.ehrdb)
+        #body_type = input('Use Naive Bayes model trained from whole body sections or just their body introductions?\n\t'\
+        #                '[w=whole body, j=just intro, DEFAULT=just intro]: ')
+        
+        body_type = 'j'
+
+        if body_type == 'w':
+            ending = 'body'
+        elif body_type in ['j', '']:
+            ending = 'intro'
+        else:
+            sys.exit('Error: Must input \'w\' or \'j.\'')
+        SUMM_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'summarization', 'pubmed_summarization'))
+        best_dir_name = get_nb_dir(ending, SUMM_DIR)
+        if not best_dir_name:
+            message = 'No Naive Bayes models of this type have been fit. '\
+                        'Would you like to do so now?\n\t[DEFAULT=Yes] '
+            #response = input(message)
+            
+            response = 'y'
+
+            if response.lower() in ['y', 'yes', '']:
+                command = 'python ' + os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'summarization', 'pubmed_summarization', 'pubmed_naive_bayes.py'))
+                os.system(command)
+                best_dir_name = get_nb_dir(ending)
+            if response.lower() not in ['y', 'yes', ''] or not best_dir_name:
+                sys.exit('Exiting.')
+
+        # Fits model to data        
+        NB_DIR = os.path.join(SUMM_DIR, best_dir_name, 'nb')
+        with open(os.path.join(NB_DIR, 'feature_vecs.json'), 'r') as f:
+            data = json.load(f)
+        xtrain, ytrain = data['train_features'], data['train_outputs']
+        gnb = GaussianNB()
+        gnb.fit(xtrain, ytrain)
+
+        # Evaluates on model
+        tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
+        feature_vecs, _ = build_vecs(text, None, tokenizer)
+        PCT_SUM = 0.3
+        preds = classify_nb(feature_vecs, PCT_SUM, gnb)
+        sents = tokenizer.tokenize(text)
+        summary = ''
+        for i in range(len(preds)):
+            if preds[i] == 1:
+                summary += sents[i]
+
+        show_summary(doc_id, text, summary, 'Naive Bayes')
+        
+    #@unittest.skipIf("t7.test7_2_distilbart_summary" not in sys.argv, "Test 7_2 must be run explicitly due to runtime.")
+    def test7_2_distilbart_summary(self):
+        # Distilbart for summarization. Trained on CNN/ Daily Mail (~4x longer summaries than XSum)
+        doc_id, text = select_ehr(self.ehrdb, requires_long=True)
+        model_name = 'sshleifer/distilbart-cnn-12-6'
+        summary = self.ehrdb.summarize_huggingface(text, model_name)
+
+        show_summary(doc_id, text, summary, model_name)
+        print('Number of Words in Full EHR: %d' % len(text.split()))
+        print('Number of Words in %s Summary: %d' % (model_name, len(summary.split())))
+
+    #@unittest.skipIf("t7.test7_3_t5_summary" not in sys.argv, "Test 7_3 must be run explicitly due to runtime.")
+    def test7_3_t5_summary(self):
+        # T5 for summarization. Trained on CNN/ Daily Mail
+        doc_id, text = select_ehr(self.ehrdb, requires_long=True)
+        model_name = 't5-small'
+        summary = self.ehrdb.summarize_huggingface(text, model_name)
+
+        show_summary(doc_id, text, summary, model_name)
+        print('Number of Words in Full EHR: %d' % len(text.split()))
+        print('Number of Words in %s Summary: %d' % (model_name, len(summary.split())))
+
+def testing():
+    unittest.main()
+
+if __name__ == '__main__':
+    unittest.main()