[2d4573]: / tests / test_functions_no_output.ipynb

Download this file

953 lines (952 with data), 27.2 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "accbacfc",
   "metadata": {},
   "source": [
    "## EHRKit Demo\n",
    "\n",
    "In this notebook, we demonstrate some of the basic functions that you can utilize from the EHRKit."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9364f9f4",
   "metadata": {},
   "source": [
    "### Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b745701",
   "metadata": {},
   "outputs": [],
   "source": [
    "import unittest\n",
    "import random\n",
    "import sys, os\n",
    "import re\n",
    "import nltk\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbcf7369",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every path below is resolved against the current working directory (via os.path.abspath),\n",
    "# not against the location of this notebook file.\n",
    "\n",
    "# abspath('EHRKit') is <cwd>/EHRKit; the two dirname() calls strip 'EHRKit' and then the\n",
    "# last component of <cwd>, so this adds the parent of the working directory.\n",
    "sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath('EHRKit'))))\n",
    "# os.path.dirname('') is '', so these two add <parent of cwd>/allennlp and\n",
    "# <parent of cwd>/summarization/pubmed_summarization.\n",
    "sys.path.append(os.path.abspath(os.path.join(os.path.dirname(''), '..', 'allennlp')))\n",
    "sys.path.append(os.path.abspath(os.path.join(os.path.dirname(''), '..', 'summarization', 'pubmed_summarization')))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccaf4e2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ehrkit import ehrkit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13222d2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from getpass import getpass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "722a048e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Database credentials: read them from config.py when available, otherwise ask interactively.\n",
    "try:\n",
    "    from config import USERNAME, PASSWORD\n",
    "except ImportError:\n",
    "    # Raised when config.py is missing or does not define both names. Anything else\n",
    "    # (a syntax error in config.py, a kernel interrupt) is deliberately left to propagate.\n",
    "    print(\"Please put your username and password in config.py\")\n",
    "    USERNAME = input('DB_username?')\n",
    "    # getpass keeps the password out of the notebook output.\n",
    "    PASSWORD = getpass('DB_password?')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "827f0ad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard-coded sample document ID, marked temporary by the original author;\n",
    "# no other cell in this notebook reads it.\n",
    "DOC_ID = 1354526 # Temporary!!!\n",
    "\n",
    "# The counts below are hard-coded rather than queried from the database.\n",
    "\n",
    "# Number of documents in NOTEEVENTS.\n",
    "# Used as the upper bound when later cells draw a random ROW_ID.\n",
    "NUM_DOCS = 2083180\n",
    "\n",
    "# Number of patients in PATIENTS.\n",
    "# Used as the upper bound when test2_2_patient_info draws a random SUBJECT_ID.\n",
    "NUM_PATIENTS = 46520\n",
    "\n",
    "# Number of diagnoses in DIAGNOSES_ICD.\n",
    "# Not read by any other cell in this notebook.\n",
    "NUM_DIAGS = 823933"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1654c5cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_ehr(ehrdb, requires_long=False, recursing=False):\n",
    "    \"\"\"Choose an EHR note and return it as ``(doc_id, text)``.\n",
    "\n",
    "    Args:\n",
    "        ehrdb: EHRKit session object; its ``cur`` cursor and ``get_document`` method are used.\n",
    "        requires_long: When True, a randomly drawn note is accepted only if it has more\n",
    "            than 200 whitespace-separated words; shorter notes are redrawn.\n",
    "        recursing: When True, skip the prompt and draw a random note immediately.\n",
    "\n",
    "    Returns:\n",
    "        ``(doc_id, text)``. For a random note ``doc_id`` is the ROW_ID value returned by\n",
    "        the query; for a typed ID it is the string the user entered.\n",
    "\n",
    "    Raises:\n",
    "        SystemExit: If the typed ID cannot be fetched.\n",
    "    \"\"\"\n",
    "    if recursing:\n",
    "        doc_id = ''\n",
    "    else:\n",
    "        doc_id = input(\"MIMIC Document ID [press Enter for random]: \")\n",
    "\n",
    "    if doc_id == '':\n",
    "        # Redraw in a loop instead of recursing, so a long run of short notes cannot\n",
    "        # exhaust the interpreter's recursion limit.\n",
    "        while True:\n",
    "            ehrdb.cur.execute(\"SELECT ROW_ID FROM mimic.NOTEEVENTS ORDER BY RAND() LIMIT 1\")\n",
    "            doc_id = ehrdb.cur.fetchall()[0][0]\n",
    "            text = ehrdb.get_document(int(doc_id))\n",
    "            if len(text.split()) > 200 or not requires_long:\n",
    "                return doc_id, text\n",
    "\n",
    "    # The user typed an ID: fetch exactly that document.\n",
    "    try:\n",
    "        text = ehrdb.get_document(int(doc_id))\n",
    "    except Exception:\n",
    "        # Catching Exception rather than everything lets kernel interrupts propagate.\n",
    "        message = 'Error: There is no document with ID \\'' + doc_id + '\\' in mimic.NOTEEVENTS'\n",
    "        sys.exit(message)\n",
    "    return doc_id, text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43f13fc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_nb_dir(ending, SUMM_DIR):\n",
    "    \"\"\"Return the name of the Naive Bayes model directory trained on the most examples.\n",
    "\n",
    "    Scans ``SUMM_DIR`` for sub-directories whose name ends with ``_exs_<ending>``, that\n",
    "    contain an ``nb`` entry, and whose text before the first underscore parses as an integer.\n",
    "\n",
    "    Args:\n",
    "        ending: Model-type suffix (the callers in this notebook pass 'body' or 'intro').\n",
    "        SUMM_DIR: Directory to scan.\n",
    "\n",
    "    Returns:\n",
    "        The name of the matching directory with the largest leading integer, or None\n",
    "        when no directory matches.\n",
    "    \"\"\"\n",
    "    suffix = '_exs_' + ending\n",
    "    candidates = []\n",
    "    for name in os.listdir(SUMM_DIR):\n",
    "        if not name.endswith(suffix):\n",
    "            continue\n",
    "        if not os.path.isdir(os.path.join(SUMM_DIR, name)):\n",
    "            continue\n",
    "        if not os.path.exists(os.path.join(SUMM_DIR, name, 'nb')):\n",
    "            continue\n",
    "        try:\n",
    "            n_examples = int(name.split('_')[0])\n",
    "        except ValueError:\n",
    "            # The prefix is not a number, so this is not a model directory.\n",
    "            continue\n",
    "        # Keep the real directory name: rebuilding it from the integer would drop\n",
    "        # zero padding or extra name parts and point at a path that does not exist.\n",
    "        candidates.append((n_examples, name))\n",
    "    if not candidates:\n",
    "        return None\n",
    "    return max(candidates)[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272882d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_summary(doc_id, text, summary, model_name):\n",
    "    \"\"\"Print the predicted summary, optionally preceded by the full EHR text.\n",
    "\n",
    "    Prompts once; answering 'y' or 'yes' (any letter case) or just pressing Enter\n",
    "    prints the full note before the summary.\n",
    "    \"\"\"\n",
    "    short_rule = '-' * 30\n",
    "    long_rule = '-' * 80\n",
    "\n",
    "    answer = input('Show full EHR (DOC ID %s)? [DEFAULT=Yes]' % doc_id)\n",
    "    wants_full_text = answer.lower() in ('y', 'yes', '')\n",
    "    if wants_full_text:\n",
    "        print('\\n\\n' + short_rule + 'Full EHR' + short_rule)\n",
    "        print(text + '\\n')\n",
    "        print(long_rule + '\\n\\n')\n",
    "\n",
    "    print(short_rule + 'Predicted Summary ' + model_name + short_rule)\n",
    "    print(summary)\n",
    "    print(long_rule + '\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bcc571c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests():\n",
    "    \"\"\"Base class for the demo checks; later cells subclass it to add one method each.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # Opens an EHRKit session with the configured credentials, then calls get_patients(3)\n",
    "        # (presumably preloading three patients into self.ehrdb.patients - TODO confirm).\n",
    "        self.ehrdb = ehrkit.start_session(USERNAME, PASSWORD)\n",
    "        self.ehrdb.get_patients(3)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e665759",
   "metadata": {},
   "source": [
    "**1.1** Count the total number of patients."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "578dab47",
   "metadata": {},
   "source": [
    "### Test T1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ac56588",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test1_1_count_patients(self):\n",
    "        \"\"\"Print EHRKit's patient count, then run the equivalent COUNT(*) query directly.\"\"\"\n",
    "        db = self.ehrdb\n",
    "        kit_count = db.count_patients()\n",
    "        print(\"Patient count: \", kit_count)\n",
    "\n",
    "        # Reference count straight from the PATIENTS table.\n",
    "        # NOTE(review): test_count is computed but never compared with kit_count.\n",
    "        db.cur.execute(\"SELECT COUNT(*) FROM mimic.PATIENTS\")\n",
    "        rows = db.cur.fetchall()\n",
    "        test_count = int(rows[0][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "709a054f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method; the\n",
    "# constructor also opens a new session), then run check 1.1.\n",
    "test = tests()\n",
    "test.test1_1_count_patients()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f04dc8bc",
   "metadata": {},
   "source": [
    "**1.3** Count the total number of sentences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a77086e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test1_3_note_info(self):\n",
    "        \"\"\"Print one (patient.id, note[0], len(note[1])) tuple per note event of the loaded patients.\"\"\"\n",
    "        self.ehrdb.get_note_events()\n",
    "        print('output format: SUBJECT_ID, ROW_ID, NoteEvent length')\n",
    "\n",
    "        note_sizes = []\n",
    "        for patient in self.ehrdb.patients.values():\n",
    "            for note in patient.note_events:\n",
    "                note_sizes.append((patient.id, note[0], len(note[1])))\n",
    "        print(note_sizes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48b05b86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run check 1.3.\n",
    "test = tests()\n",
    "test.test1_3_note_info()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4c4c08c",
   "metadata": {},
   "source": [
    "**1.4** Print the record with the most sentences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e053c31",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test1_4_longest_note(self):\n",
    "        \"\"\"Print the longest note among the patient notes queued by get_note_events().\"\"\"\n",
    "        self.ehrdb.get_note_events()\n",
    "        # longest_NE() is unpacked as (patient id, note row id, length).\n",
    "        longest = self.ehrdb.longest_NE()\n",
    "        pid, rowid, doclen = longest\n",
    "        print('patient id is:', pid, '\\nNoteEvent id is:', rowid, '\\nlength: ', doclen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "add51bdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run check 1.4.\n",
    "test = tests()\n",
    "test.test1_4_longest_note()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bd66e85",
   "metadata": {},
   "source": [
    "### Test 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4e3b881",
   "metadata": {},
   "source": [
    "**2.1** Display document given a document ID."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe0a7b21",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test2_1_print_note(self):\n",
    "        \"\"\"Print a randomly chosen note via EHRKit and fetch the same row with raw SQL.\"\"\"\n",
    "        ### There are 2083180 patient records in NOTEEVENTS. ###\n",
    "        # random.randint includes both endpoints, so the upper bound is NUM_DOCS itself;\n",
    "        # NUM_DOCS + 1 could draw an ID beyond the stated document count.\n",
    "        record_id = random.randint(1, NUM_DOCS)\n",
    "        kit_rec = self.ehrdb.get_document(record_id)\n",
    "        print(\"Document with ID %d\\n: \" % record_id, kit_rec)\n",
    "\n",
    "        self.ehrdb.cur.execute(\"select TEXT from mimic.NOTEEVENTS where ROW_ID = %d\" % record_id)\n",
    "        rows = self.ehrdb.cur.fetchall()\n",
    "        # An ID with no matching row yields an empty result; avoid an IndexError on rows[0].\n",
    "        # NOTE(review): test_rec is not compared with kit_rec.\n",
    "        test_rec = rows[0][0] if rows else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfe5c0cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run check 2.1.\n",
    "test = tests()\n",
    "test.test2_1_print_note()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52c4142a",
   "metadata": {},
   "source": [
    "**2.2** Count the number of documents associated with a given patient given patient ID."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f6208b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test2_2_patient_info(self):\n",
    "        \"\"\"Print the document IDs EHRKit returns for a random patient ID, and how many there are.\"\"\"\n",
    "        ### There are records from 46520 unique patients in MIMIC. ###\n",
    "        # random.randint includes both endpoints, so the upper bound is NUM_PATIENTS,\n",
    "        # not NUM_PATIENTS + 1.\n",
    "        patient_id = random.randint(1, NUM_PATIENTS)\n",
    "        kit_ids = self.ehrdb.get_all_patient_document_ids(patient_id)\n",
    "        print('Document IDs related to Patient %d: ' % patient_id, kit_ids)\n",
    "        print(\"Number of docs related to Patient %d: \" % patient_id, len(kit_ids))\n",
    "\n",
    "        # Same lookup with raw SQL.\n",
    "        # NOTE(review): test_ids is not compared with kit_ids.\n",
    "        self.ehrdb.cur.execute(\"select ROW_ID from mimic.NOTEEVENTS where SUBJECT_ID = %d\" % patient_id)\n",
    "        raw = self.ehrdb.cur.fetchall()\n",
    "        test_ids = ehrkit.flatten(raw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ae06b45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run check 2.2.\n",
    "test = tests()\n",
    "test.test2_2_patient_info()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f571db23",
   "metadata": {},
   "source": [
    "**2.3** List all document IDs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bab71df9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test2_3_doc_ids(self):\n",
    "        \"\"\"Fetch all document IDs via EHRKit and via raw SQL; print the first 30 SQL results.\"\"\"\n",
    "        # NOTE(review): kit_ids is fetched but neither printed nor compared.\n",
    "        kit_ids = self.ehrdb.list_all_document_ids()\n",
    "\n",
    "        cursor = self.ehrdb.cur\n",
    "        cursor.execute(\"select ROW_ID from mimic.NOTEEVENTS\")\n",
    "        test_ids = ehrkit.flatten(cursor.fetchall())\n",
    "\n",
    "        print('All document ids: (truncated)')\n",
    "        preview = test_ids[:30]\n",
    "        print(preview)\n",
    "        print('...')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29d8fb2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run check 2.3.\n",
    "test = tests()\n",
    "test.test2_3_doc_ids()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "549c2a16",
   "metadata": {},
   "source": [
    "**2.4** List all patient IDs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1896914",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test2_4_patient_ids(self):\n",
    "        \"\"\"Fetch all patient IDs via EHRKit and via raw SQL; print the first 30 SQL results.\"\"\"\n",
    "        # NOTE(review): kit_ids is fetched but neither printed nor compared.\n",
    "        kit_ids = self.ehrdb.list_all_patient_ids()\n",
    "\n",
    "        cursor = self.ehrdb.cur\n",
    "        cursor.execute(\"select SUBJECT_ID from mimic.PATIENTS\")\n",
    "        test_ids = ehrkit.flatten(cursor.fetchall())\n",
    "\n",
    "        print(\"All patient ids: (truncated)\")\n",
    "        preview = test_ids[:30]\n",
    "        print(preview)\n",
    "        print('...')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43026f9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run check 2.4.\n",
    "test = tests()\n",
    "test.test2_4_patient_ids()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "898ddf27",
   "metadata": {},
   "source": [
    "**2.5** List all document IDs for a given admission date."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf636ec2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test2_5_docs_on_date(self):\n",
    "        \"\"\"Pick the CHARTDATE of a random note and print the IDs of notes sharing that date.\"\"\"\n",
    "        ### Select random date from a date in the database. \n",
    "        ### Dates are shifted to future but preserve time, weekday, and seasonality.\n",
    "        # random.randint includes both endpoints, so the upper bound is NUM_DOCS, not NUM_DOCS + 1.\n",
    "        random_id = random.randint(1, NUM_DOCS)\n",
    "        self.ehrdb.cur.execute(\"select CHARTDATE from mimic.NOTEEVENTS where ROW_ID = %d\" % random_id)\n",
    "        rows = self.ehrdb.cur.fetchall()\n",
    "        if not rows:\n",
    "            # No note has this ROW_ID, so there is no date to look up.\n",
    "            print(f\"No note found with ROW_ID {random_id}\")\n",
    "            return\n",
    "        date = rows[0][0]\n",
    "\n",
    "        # NOTE(review): kit_ids is fetched but not compared with test_ids.\n",
    "        kit_ids = self.ehrdb.get_documents_d(date)\n",
    "\n",
    "        self.ehrdb.cur.execute(\"select ROW_ID from mimic.NOTEEVENTS where CHARTDATE = \\\"%s\\\"\" % date)\n",
    "        raw = self.ehrdb.cur.fetchall()\n",
    "        test_ids = ehrkit.flatten(raw)\n",
    "        print(f\"Selected date: {date}\")\n",
    "        print(f\"Test ids {test_ids[:30]} ...\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0da163eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run check 2.5.\n",
    "test = tests()\n",
    "test.test2_5_docs_on_date()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "378be2d9",
   "metadata": {},
   "source": [
    "### Test T3 ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b49343e",
   "metadata": {},
   "source": [
    "**3.1** Extract all abbreviations from a document, given the document ID."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "177712a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test3_1_extract_abbreviations(self):\n",
    "        \"\"\"Print the abbreviations EHRKit extracts from a randomly chosen document.\n",
    "\n",
    "        A reference set is also built from the document's tokens with a simple regex.\n",
    "        \"\"\"\n",
    "        # random.randint includes both endpoints, so the upper bound is NUM_DOCS, not NUM_DOCS + 1.\n",
    "        random_id = random.randint(1, NUM_DOCS)\n",
    "        print(\"Collecting abbreviations for document %d...\" % random_id)\n",
    "        kit_abbs = self.ehrdb.get_abbreviations(random_id)\n",
    "\n",
    "        sents = self.ehrdb.get_document_sents(random_id)\n",
    "        # re.match anchors only at the start of the token, so this accepts any token that\n",
    "        # begins with two capital letters, not only tokens written entirely in capitals.\n",
    "        # Compiled once here instead of once per token.\n",
    "        starts_with_two_caps = re.compile(r'[A-Z]{2}')\n",
    "        # NOTE(review): test_abbs is built but not compared with kit_abbs.\n",
    "        test_abbs = {\n",
    "            word\n",
    "            for sent in sents\n",
    "            for word in ehrkit.word_tokenize(sent)\n",
    "            if starts_with_two_caps.match(word)\n",
    "        }\n",
    "\n",
    "        print(kit_abbs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "319d8a75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run check 3.1.\n",
    "test = tests()\n",
    "test.test3_1_extract_abbreviations()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27f0fce0",
   "metadata": {},
   "source": [
    "**3.2** List all document IDs that include the keyword \"meningitis\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3620b640",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test3_2_docs_with_query(self):\n",
    "        \"\"\"Print the first 30 document IDs EHRKit returns for the keyword 'meningitis'.\"\"\"\n",
    "        keyword = \"meningitis\"\n",
    "        print('Printing a list of all document ids including query like ', keyword)\n",
    "        kit_ids = self.ehrdb.get_documents_q(keyword)\n",
    "        # The full list of DOC_IDs is extremely long, so only a preview is shown.\n",
    "        print(kit_ids[:30])\n",
    "        print(\"...\")\n",
    "\n",
    "        # Same search with a raw SQL LIKE pattern.\n",
    "        # NOTE(review): test_ids is not compared with kit_ids.\n",
    "        like_pattern = \"%\" + keyword + \"%\"\n",
    "        sql = \"select ROW_ID from mimic.NOTEEVENTS where TEXT like '%s'\" % like_pattern\n",
    "        self.ehrdb.cur.execute(sql)\n",
    "        test_ids = ehrkit.flatten(self.ehrdb.cur.fetchall())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07163b53",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run check 3.2.\n",
    "test = tests()\n",
    "test.test3_2_docs_with_query()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4deaa30",
   "metadata": {},
   "source": [
    "**3.3** List all document IDs that include keywords \"Service: SURGERY\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "570b1ce7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test3_3_query_docs(self):\n",
    "        \"\"\"Same as task 3.2 with a different query: 'Service: SURGERY'.\"\"\"\n",
    "        keyword = \"Service: SURGERY\"\n",
    "        print('Printing a list of all document ids including query like ', keyword)\n",
    "        kit_ids = self.ehrdb.get_documents_q(keyword)\n",
    "        # The full list of DOC_IDs is extremely long, so only a preview is shown.\n",
    "        print(kit_ids[:30])\n",
    "        print(\"...\")\n",
    "\n",
    "        # Same search with a raw SQL LIKE pattern.\n",
    "        # NOTE(review): test_ids is not compared with kit_ids.\n",
    "        like_pattern = \"%\" + keyword + \"%\"\n",
    "        sql = \"select ROW_ID from mimic.NOTEEVENTS where TEXT like '%s'\" % like_pattern\n",
    "        self.ehrdb.cur.execute(sql)\n",
    "        test_ids = ehrkit.flatten(self.ehrdb.cur.fetchall())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2923f60a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run check 3.3.\n",
    "test = tests()\n",
    "test.test3_3_query_docs()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b073b5b6",
   "metadata": {},
   "source": [
    "**3.4** Given a document ID, show a numbered list of all sentences in that document."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e4eb5dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test3_4_doc_sentences(self):\n",
    "        \"\"\"Print a numbered sentence list for a random document, then the raw-SQL tokenization.\"\"\"\n",
    "        # random.randint includes both endpoints, so the upper bound is NUM_DOCS, not NUM_DOCS + 1.\n",
    "        doc_id = random.randint(1, NUM_DOCS)\n",
    "        print('Kit function printing a numbered list of all sentences in document %d' % doc_id)\n",
    "        # MIMIC EHRs are very messy and sentence tokenization often doesn't work\n",
    "        kit_doc = self.ehrdb.get_document_sents(doc_id)\n",
    "        ehrkit.numbered_print(kit_doc)\n",
    "\n",
    "        self.ehrdb.cur.execute(\"select TEXT from mimic.NOTEEVENTS where ROW_ID = %d \" % doc_id)\n",
    "        raw = self.ehrdb.cur.fetchall()\n",
    "        # No row for this ID means nothing to tokenize; an empty list avoids an IndexError on raw[0].\n",
    "        test_doc = ehrkit.sent_tokenize(raw[0][0]) if raw else []\n",
    "        print(test_doc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d09aeace",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run check 3.4.\n",
    "test = tests()\n",
    "test.test3_4_doc_sentences()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bba4817",
   "metadata": {},
   "source": [
    "**3.5** Count the number of prescriptions for each unique medication."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f87aaac",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test3_7_medications(self):\n",
    "        \"\"\"Tally prescriptions per drug from raw SQL and print the first 30 DRUG values.\"\"\"\n",
    "        # NOTE(review): kit_meds and test_meds are both built but neither is printed or compared.\n",
    "        kit_meds = self.ehrdb.count_all_prescriptions()\n",
    "\n",
    "        test_meds = {}\n",
    "        self.ehrdb.cur.execute(\"select DRUG from mimic.PRESCRIPTIONS\")\n",
    "        meds_list = ehrkit.flatten(self.ehrdb.cur.fetchall())\n",
    "        # Count how many times each drug name occurs.\n",
    "        for drug in meds_list:\n",
    "            test_meds[drug] = test_meds.get(drug, 0) + 1\n",
    "\n",
    "        print(meds_list[:30])\n",
    "        print(\"...\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5464edb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run the\n",
    "# medication count (heading 3.5; the method itself is named test3_7).\n",
    "test = tests()\n",
    "test.test3_7_medications()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ee9f298",
   "metadata": {},
   "source": [
    "### Test T5 ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99a030b2",
   "metadata": {},
   "source": [
    "**5.4** Count how many patients are labeled as \"male\" or \"female\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a05ce6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test5_4_count_gender(self):\n",
    "        \"\"\"Pick 'M' or 'F' at random and print the raw-SQL patient count for that gender.\"\"\"\n",
    "        gender = random.choice(['M', 'F'])\n",
    "        # NOTE(review): kit_count is fetched but only the SQL count is printed.\n",
    "        kit_count = self.ehrdb.count_gender(gender)\n",
    "\n",
    "        sql = \"SELECT COUNT(*) FROM mimic.PATIENTS WHERE GENDER = '%s'\" % gender\n",
    "        self.ehrdb.cur.execute(sql)\n",
    "        rows = self.ehrdb.cur.fetchall()\n",
    "        test_count = rows[0][0]\n",
    "        print('Gender:', gender, '\\tCount:', str(test_count))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1459fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run check 5.4.\n",
    "test = tests()\n",
    "test.test5_4_count_gender()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5bf4ca7a",
   "metadata": {},
   "source": [
    "### Test T7 ###"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "936f87bc",
   "metadata": {},
   "source": [
    "**7.1** Creates extractive summary of an EHR with Naive Bayes Algorithm trained on PubMed articles."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40764971",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test7_1_naive_bayes(self):\n",
    "        \"\"\"Build an extractive summary of an EHR with a Gaussian Naive Bayes classifier.\n",
    "\n",
    "        Prompts for a document and a model type, loads the training feature vectors stored\n",
    "        in the model directory chosen by get_nb_dir, fits GaussianNB on them, and prints the\n",
    "        sentences whose prediction is 1.\n",
    "\n",
    "        Raises:\n",
    "            SystemExit: On an invalid model-type answer, or when no fitted model is\n",
    "                available and none could be produced.\n",
    "        \"\"\"\n",
    "        import subprocess\n",
    "\n",
    "        from pubmed_naive_bayes import classify_nb\n",
    "        from get_pubmed_nb_data import build_vecs\n",
    "        from sklearn.naive_bayes import GaussianNB\n",
    "\n",
    "        doc_id, text = select_ehr(self.ehrdb)\n",
    "        body_type = input(\n",
    "            'Use Naive Bayes model trained from whole body sections or just their body introductions?\\n\\t'\n",
    "            '[w=whole body, j=just intro, DEFAULT=just intro]: '\n",
    "        )\n",
    "        if body_type == 'w':\n",
    "            ending = 'body'\n",
    "        elif body_type in ['j', '']:\n",
    "            ending = 'intro'\n",
    "        else:\n",
    "            sys.exit('Error: Must input \\'w\\' or \\'j.\\'')\n",
    "\n",
    "        SUMM_DIR = os.path.abspath(os.path.join(os.path.dirname('EHRKit'), '..', 'summarization', 'pubmed_summarization'))\n",
    "        best_dir_name = get_nb_dir(ending, SUMM_DIR)\n",
    "        if not best_dir_name:\n",
    "            message = ('No Naive Bayes models of this type have been fit. '\n",
    "                       'Would you like to do so now?\\n\\t[DEFAULT=Yes] ')\n",
    "            response = input(message)\n",
    "            if response.lower() not in ['y', 'yes', '']:\n",
    "                sys.exit('Exiting.')\n",
    "            # Run the training script with the interpreter that runs this kernel; passing an\n",
    "            # argument list (no shell) also keeps paths containing spaces intact.\n",
    "            script = os.path.join(SUMM_DIR, 'pubmed_naive_bayes.py')\n",
    "            subprocess.run([sys.executable, script])\n",
    "            # get_nb_dir needs the directory to scan as its second argument; omitting it\n",
    "            # raised TypeError right after training.\n",
    "            best_dir_name = get_nb_dir(ending, SUMM_DIR)\n",
    "            if not best_dir_name:\n",
    "                sys.exit('Exiting.')\n",
    "\n",
    "        # Fits model to data\n",
    "        NB_DIR = os.path.join(SUMM_DIR, best_dir_name, 'nb')\n",
    "        with open(os.path.join(NB_DIR, 'feature_vecs.json'), 'r') as f:\n",
    "            data = json.load(f)\n",
    "        xtrain, ytrain = data['train_features'], data['train_outputs']\n",
    "        gnb = GaussianNB()\n",
    "        gnb.fit(xtrain, ytrain)\n",
    "\n",
    "        # Evaluates on model\n",
    "        tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')\n",
    "        feature_vecs, _ = build_vecs(text, None, tokenizer)\n",
    "        # Passed to classify_nb; presumably the share of sentences to keep - TODO confirm.\n",
    "        PCT_SUM = 0.3\n",
    "        preds = classify_nb(feature_vecs, PCT_SUM, gnb)\n",
    "        sents = tokenizer.tokenize(text)\n",
    "        # Selected sentences are concatenated without a separator, as before.\n",
    "        summary = ''.join(sents[i] for i in range(len(preds)) if preds[i] == 1)\n",
    "\n",
    "        show_summary(doc_id, text, summary, 'Naive Bayes')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8e05bc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run check 7.1.\n",
    "test = tests()\n",
    "test.test7_1_naive_bayes()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa265304",
   "metadata": {},
   "source": [
    "**7.2** Generates abstractive summary of an EHR with pre-trained Distilbart model from Huggingface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "992ce1f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test7_2_distilbart_summary(self):\n",
    "        \"\"\"Summarize a long EHR with the Distilbart checkpoint and print both word counts.\"\"\"\n",
    "        # Distilbart for summarization. Trained on CNN/ Daily Mail (~4x longer summaries than XSum)\n",
    "        model_name = 'sshleifer/distilbart-cnn-12-6'\n",
    "        doc_id, text = select_ehr(self.ehrdb, requires_long=True)\n",
    "        summary = self.ehrdb.summarize_huggingface(text, model_name)\n",
    "\n",
    "        show_summary(doc_id, text, summary, model_name)\n",
    "        full_word_count = len(text.split())\n",
    "        print('Number of Words in Full EHR: %d' % full_word_count)\n",
    "        summary_word_count = len(summary.split())\n",
    "        print('Number of Words in %s Summary: %d' % (model_name, summary_word_count))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bd82bbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run check 7.2.\n",
    "test = tests()\n",
    "test.test7_2_distilbart_summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f678ea5",
   "metadata": {},
   "source": [
    "**7.3** Generates abstractive summary of an EHR with pre-trained T5 model from Huggingface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27653251",
   "metadata": {},
   "outputs": [],
   "source": [
    "class tests(tests):\n",
    "    def test7_3_t5_summary(self):\n",
    "        \"\"\"Summarize a long EHR with the 't5-small' checkpoint and print both word counts.\"\"\"\n",
    "        # T5 for summarization. Trained on CNN/ Daily Mail\n",
    "        model_name = 't5-small'\n",
    "        doc_id, text = select_ehr(self.ehrdb, requires_long=True)\n",
    "        summary = self.ehrdb.summarize_huggingface(text, model_name)\n",
    "\n",
    "        show_summary(doc_id, text, summary, model_name)\n",
    "        full_word_count = len(text.split())\n",
    "        print('Number of Words in Full EHR: %d' % full_word_count)\n",
    "        summary_word_count = len(summary.split())\n",
    "        print('Number of Words in %s Summary: %d' % (model_name, summary_word_count))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e60e48d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the just-extended class (instances made earlier lack this method), then run check 7.3.\n",
    "test = tests()\n",
    "test.test7_3_t5_summary()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}