953 lines (952 with data), 27.2 kB
{
"cells": [
{
"cell_type": "markdown",
"id": "accbacfc",
"metadata": {},
"source": [
"## EHRKit Demo\n",
"\n",
"In this notebook, we demonstrate some of the basic functions that you can utilize from the EHRKit."
]
},
{
"cell_type": "markdown",
"id": "9364f9f4",
"metadata": {},
"source": [
"### Preparation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b745701",
"metadata": {},
"outputs": [],
"source": [
"import unittest\n",
"import random\n",
"import sys, os\n",
"import re\n",
"import nltk\n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dbcf7369",
"metadata": {},
"outputs": [],
"source": [
"# Extend sys.path relative to the current working directory, presumably so that the\n",
"# project-local imports used later in this notebook can be resolved.\n",
"# abspath('EHRKit') is <cwd>/EHRKit, so the two dirname() calls yield the parent of <cwd>.\n",
"sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath('EHRKit'))))\n",
"# os.path.dirname('') is '', so the next two entries resolve to <cwd>/../allennlp and\n",
"# <cwd>/../summarization/pubmed_summarization.\n",
"sys.path.append(os.path.abspath(os.path.join(os.path.dirname(''), '..', 'allennlp')))\n",
"sys.path.append(os.path.abspath(os.path.join(os.path.dirname(''), '..', 'summarization', 'pubmed_summarization')))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ccaf4e2d",
"metadata": {},
"outputs": [],
"source": [
"from ehrkit import ehrkit"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "13222d2c",
"metadata": {},
"outputs": [],
"source": [
"from getpass import getpass"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "722a048e",
"metadata": {},
"outputs": [],
"source": [
"# Load DB credentials from a local config.py; fall back to an interactive prompt\n",
"# when the module (or one of the two names) cannot be imported.\n",
"try:\n",
"    from config import USERNAME, PASSWORD\n",
"except ImportError:\n",
"    # Only a failed import triggers the prompt. Other errors raised while executing\n",
"    # config.py, and Ctrl-C, now propagate instead of being swallowed by a bare except.\n",
"    print(\"Please put your username and password in config.py\")\n",
"    USERNAME = input('DB_username?')\n",
"    PASSWORD = getpass('DB_password?')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "827f0ad3",
"metadata": {},
"outputs": [],
"source": [
"# Fixed sample document ID, marked temporary by the author; no other cell visible in\n",
"# this notebook references it.\n",
"DOC_ID = 1354526 # Temporary!!!\n",
"\n",
"# Number of documents in NOTEEVENTS.\n",
"# The demo methods use it as the bound when drawing a random ROW_ID.\n",
"NUM_DOCS = 2083180\n",
"\n",
"# Number of patients in PATIENTS.\n",
"# Used as the bound when drawing a random patient ID.\n",
"NUM_PATIENTS = 46520\n",
"\n",
"# Number of diagnoses in DIAGNOSES_ICD (not referenced by any cell visible here).\n",
"NUM_DIAGS = 823933"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1654c5cb",
"metadata": {},
"outputs": [],
"source": [
"def select_ehr(ehrdb, requires_long=False, recursing=False):\n",
"    \"\"\"Pick a MIMIC note, either by a user-supplied ID or at random.\n",
"\n",
"    Args:\n",
"        ehrdb: EHRKit session exposing `cur` (a DB cursor) and `get_document(id)`.\n",
"        requires_long: When True, a randomly drawn note is accepted only if it has\n",
"            more than 200 whitespace-separated words.\n",
"        recursing: When True, skip the prompt and draw a random note directly\n",
"            (kept so callers written against the old recursive retry still work).\n",
"\n",
"    Returns:\n",
"        (doc_id, text). doc_id is the value read from the DB for random picks and\n",
"        the entered string for explicit picks.\n",
"\n",
"    Raises:\n",
"        SystemExit: if an explicitly entered ID cannot be fetched.\n",
"    \"\"\"\n",
"    if recursing:\n",
"        doc_id = ''\n",
"    else:\n",
"        # strip() so whitespace-only input counts as \"random\" rather than a bad ID.\n",
"        doc_id = input(\"MIMIC Document ID [press Enter for random]: \").strip()\n",
"\n",
"    if doc_id != '':\n",
"        # Explicit ID: a non-numeric entry or a failed lookup both end the run.\n",
"        try:\n",
"            return doc_id, ehrdb.get_document(int(doc_id))\n",
"        except Exception:\n",
"            # `except Exception` (not a bare except) lets Ctrl-C through.\n",
"            message = \"Error: There is no document with ID '\" + doc_id + \"' in mimic.NOTEEVENTS\"\n",
"            sys.exit(message)\n",
"\n",
"    # Random pick: loop instead of recursing, so a long run of short notes cannot\n",
"    # exhaust the interpreter's recursion limit.\n",
"    while True:\n",
"        ehrdb.cur.execute(\"SELECT ROW_ID FROM mimic.NOTEEVENTS ORDER BY RAND() LIMIT 1\")\n",
"        doc_id = ehrdb.cur.fetchall()[0][0]\n",
"        text = ehrdb.get_document(int(doc_id))\n",
"        if not requires_long or len(text.split()) > 200:\n",
"            return doc_id, text"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "43f13fc3",
"metadata": {},
"outputs": [],
"source": [
"def get_nb_dir(ending, SUMM_DIR):\n",
"    \"\"\"Return the name of the Naive Bayes model directory trained on the most examples.\n",
"\n",
"    A candidate is a sub-directory of SUMM_DIR whose name ends with\n",
"    '_exs_' + ending, starts with an integer token before the first '_', and\n",
"    contains an entry called 'nb'.\n",
"\n",
"    Args:\n",
"        ending: Suffix identifying the model flavour.\n",
"        SUMM_DIR: Directory to scan.\n",
"\n",
"    Returns:\n",
"        The candidate's directory name with the largest leading integer, or None\n",
"        when there is no candidate.\n",
"    \"\"\"\n",
"    suffix = '_exs_' + ending\n",
"    candidates = []\n",
"    for entry in os.listdir(SUMM_DIR):  # `entry`, not `dir`: avoid shadowing the builtin\n",
"        entry_path = os.path.join(SUMM_DIR, entry)\n",
"        if not (os.path.isdir(entry_path) and entry.endswith(suffix)):\n",
"            continue\n",
"        if not os.path.exists(os.path.join(entry_path, 'nb')):\n",
"            continue\n",
"        try:\n",
"            example_count = int(entry.split('_')[0])\n",
"        except ValueError:\n",
"            # Leading token is not a number, so this is not a model directory.\n",
"            continue\n",
"        candidates.append((example_count, entry))\n",
"    if not candidates:\n",
"        return None\n",
"    # Return the real directory name rather than rebuilding it from the count, so\n",
"    # names such as '0100_exs_intro' still point at an existing directory.\n",
"    return max(candidates)[1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "272882d8",
"metadata": {},
"outputs": [],
"source": [
"def show_summary(doc_id, text, summary, model_name):\n",
"    \"\"\"Print the predicted summary, preceded by the full EHR if the user agrees.\n",
"\n",
"    Prompts on stdin; an empty reply, 'y' or 'yes' (any case) shows the full text.\n",
"    \"\"\"\n",
"    rule = '-' * 30\n",
"    footer = '-' * 80 + '\\n\\n'\n",
"\n",
"    reply = input('Show full EHR (DOC ID %s)? [DEFAULT=Yes]' % doc_id)\n",
"    wants_full_text = reply.lower() in ('y', 'yes', '')\n",
"    if wants_full_text:\n",
"        print('\\n\\n' + rule + 'Full EHR' + rule)\n",
"        print(text + '\\n')\n",
"        print(footer)\n",
"\n",
"    print(rule + 'Predicted Summary ' + model_name + rule)\n",
"    print(summary)\n",
"    print(footer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0bcc571c",
"metadata": {},
"outputs": [],
"source": [
"class tests:\n",
"    \"\"\"Base demo class; later cells subclass it under the same name to add one method each.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        # Start an EHRKit session with USERNAME/PASSWORD and call get_patients(3) on it.\n",
"        session = ehrkit.start_session(USERNAME, PASSWORD)\n",
"        self.ehrdb = session\n",
"        session.get_patients(3)"
]
},
{
"cell_type": "markdown",
"id": "5e665759",
"metadata": {},
"source": [
"**1.1** Count the total number of patients."
]
},
{
"cell_type": "markdown",
"id": "578dab47",
"metadata": {},
"source": [
"### Test T1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ac56588",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test1_1_count_patients(self):\n",
"        \"\"\"Print EHRKit's patient count, then run the equivalent raw SQL count.\"\"\"\n",
"        db = self.ehrdb\n",
"        kit_count = db.count_patients()\n",
"        print(\"Patient count: \", kit_count)\n",
"\n",
"        # Reference query; its result is computed but neither displayed nor compared.\n",
"        db.cur.execute(\"SELECT COUNT(*) FROM mimic.PATIENTS\")\n",
"        rows = db.cur.fetchall()\n",
"        test_count = int(rows[0][0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "709a054f",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test1_1_count_patients()"
]
},
{
"cell_type": "markdown",
"id": "f04dc8bc",
"metadata": {},
"source": [
"**1.3** Count the total number of sentences."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a77086e",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test1_3_note_info(self):\n",
"        \"\"\"Print one (patient id, note[0], len(note[1])) tuple per note of the loaded patients.\"\"\"\n",
"        self.ehrdb.get_note_events()\n",
"        print('output format: SUBJECT_ID, ROW_ID, NoteEvent length')\n",
"        note_info = []\n",
"        for patient in self.ehrdb.patients.values():\n",
"            for note in patient.note_events:\n",
"                note_info.append((patient.id, note[0], len(note[1])))\n",
"        print(note_info)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48b05b86",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test1_3_note_info()"
]
},
{
"cell_type": "markdown",
"id": "b4c4c08c",
"metadata": {},
"source": [
"**1.4** Print the record with the most sentences."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e053c31",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test1_4_longest_note(self):\n",
"        \"\"\"Print the patient id, note id and length reported by ehrdb.longest_NE().\"\"\"\n",
"        # longest_NE() works on the notes queued by get_notes_events(), so load them first.\n",
"        db = self.ehrdb\n",
"        db.get_note_events()\n",
"        longest = db.longest_NE()\n",
"        pid, rowid, doclen = longest\n",
"        print('patient id is:', pid, '\\nNoteEvent id is:', rowid, '\\nlength: ', doclen)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "add51bdb",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test1_4_longest_note()"
]
},
{
"cell_type": "markdown",
"id": "8bd66e85",
"metadata": {},
"source": [
"### Test T2"
]
},
{
"cell_type": "markdown",
"id": "a4e3b881",
"metadata": {},
"source": [
"**2.1** Display document given a document ID."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe0a7b21",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test2_1_print_note(self):\n",
"        \"\"\"Print a randomly chosen note via EHRKit, then fetch the same row with raw SQL.\"\"\"\n",
"        ### There are 2083180 patient records in NOTEEVENTS. ###\n",
"        # random.randint includes both endpoints, so the upper bound is NUM_DOCS itself;\n",
"        # NUM_DOCS + 1 could draw an ID one past the document count.\n",
"        record_id = random.randint(1, NUM_DOCS)\n",
"        kit_rec = self.ehrdb.get_document(record_id)\n",
"        print(\"Document with ID %d\\n: \" % record_id, kit_rec)\n",
"\n",
"        # Reference lookup straight from the table (fetched but not displayed).\n",
"        self.ehrdb.cur.execute(\"select TEXT from mimic.NOTEEVENTS where ROW_ID = %d\" % record_id)\n",
"        test_rec = self.ehrdb.cur.fetchall()[0][0]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dfe5c0cc",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test2_1_print_note()"
]
},
{
"cell_type": "markdown",
"id": "52c4142a",
"metadata": {},
"source": [
"**2.2** Count the number of documents associated with a patient, given the patient ID."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f6208b7",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test2_2_patient_info(self):\n",
"        \"\"\"Print the document IDs (and their count) EHRKit returns for a random patient ID.\"\"\"\n",
"        ### There are records from 46520 unique patients in MIMIC. ###\n",
"        # random.randint includes both endpoints, so NUM_PATIENTS is the inclusive bound.\n",
"        # NOTE(review): assumes SUBJECT_IDs of interest lie in 1..NUM_PATIENTS - confirm.\n",
"        patient_id = random.randint(1, NUM_PATIENTS)\n",
"        kit_ids = self.ehrdb.get_all_patient_document_ids(patient_id)\n",
"        print('Document IDs related to Patient %d: ' % patient_id, kit_ids)\n",
"        print(\"Number of docs related to Patient %d: \" % patient_id, len(kit_ids))\n",
"\n",
"        # Reference lookup straight from the table (fetched but not displayed).\n",
"        self.ehrdb.cur.execute(\"select ROW_ID from mimic.NOTEEVENTS where SUBJECT_ID = %d\" % patient_id)\n",
"        raw = self.ehrdb.cur.fetchall()\n",
"        test_ids = ehrkit.flatten(raw)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ae06b45",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test2_2_patient_info()"
]
},
{
"cell_type": "markdown",
"id": "f571db23",
"metadata": {},
"source": [
"**2.3** List all document IDs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bab71df9",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test2_3_doc_ids(self):\n",
"        \"\"\"Fetch every document ID via EHRKit and via raw SQL; print the first 30 SQL results.\"\"\"\n",
"        kit_ids = self.ehrdb.list_all_document_ids()\n",
"\n",
"        cursor = self.ehrdb.cur\n",
"        cursor.execute(\"select ROW_ID from mimic.NOTEEVENTS\")\n",
"        test_ids = ehrkit.flatten(cursor.fetchall())\n",
"        # Only a 30-item slice is shown.\n",
"        print('All document ids: (truncated)')\n",
"        print(test_ids[:30])\n",
"        print('...')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29d8fb2a",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test2_3_doc_ids()"
]
},
{
"cell_type": "markdown",
"id": "549c2a16",
"metadata": {},
"source": [
"**2.4** List all patient IDs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1896914",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test2_4_patient_ids(self):\n",
"        \"\"\"Fetch every patient ID via EHRKit and via raw SQL; print the first 30 SQL results.\"\"\"\n",
"        kit_ids = self.ehrdb.list_all_patient_ids()\n",
"\n",
"        cursor = self.ehrdb.cur\n",
"        cursor.execute(\"select SUBJECT_ID from mimic.PATIENTS\")\n",
"        test_ids = ehrkit.flatten(cursor.fetchall())\n",
"        # Only a 30-item slice is shown.\n",
"        print(\"All patient ids: (truncated)\")\n",
"        print(test_ids[:30])\n",
"        print('...')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "43026f9b",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test2_4_patient_ids()"
]
},
{
"cell_type": "markdown",
"id": "898ddf27",
"metadata": {},
"source": [
"**2.5** List all document IDs for a given admission date."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cf636ec2",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test2_5_docs_on_date(self):\n",
"        \"\"\"List document IDs whose CHARTDATE equals that of a randomly chosen note.\"\"\"\n",
"        ### Select random date from a date in the database. \n",
"        ### Dates are shifted to future but preserve time, weekday, and seasonality.\n",
"        # random.randint includes both endpoints, so the upper bound is NUM_DOCS itself;\n",
"        # NUM_DOCS + 1 could draw an ID with no row, making fetchall()[0] fail.\n",
"        random_id = random.randint(1, NUM_DOCS)\n",
"        self.ehrdb.cur.execute(\"select CHARTDATE from mimic.NOTEEVENTS where ROW_ID = %d\" % random_id)\n",
"        date = self.ehrdb.cur.fetchall()[0][0]\n",
"\n",
"        kit_ids = self.ehrdb.get_documents_d(date)\n",
"\n",
"        # Reference lookup straight from the table; these are the IDs printed below.\n",
"        self.ehrdb.cur.execute('select ROW_ID from mimic.NOTEEVENTS where CHARTDATE = \"%s\"' % date)\n",
"        raw = self.ehrdb.cur.fetchall()\n",
"        test_ids = ehrkit.flatten(raw)\n",
"        print(f\"Selected date: {date}\")\n",
"        print(f\"Test ids {test_ids[:30]} ...\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0da163eb",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test2_5_docs_on_date()"
]
},
{
"cell_type": "markdown",
"id": "378be2d9",
"metadata": {},
"source": [
"### Test T3 ###"
]
},
{
"cell_type": "markdown",
"id": "2b49343e",
"metadata": {},
"source": [
"**3.1** Extract all abbreviations from a document, given the document ID."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "177712a6",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test3_1_extract_abbreviations(self):\n",
"        \"\"\"Print the abbreviations EHRKit extracts from a randomly chosen document.\n",
"\n",
"        A reference set of all-caps tokens is also built from the document's\n",
"        sentences; it is not displayed.\n",
"        \"\"\"\n",
"        # Defines abbreviation as a string of capitalized letters\n",
"        # random.randint includes both endpoints, so the upper bound is NUM_DOCS itself.\n",
"        random_id = random.randint(1, NUM_DOCS)\n",
"        print(\"Collecting abbreviations for document %d...\" % random_id)\n",
"        kit_abbs = self.ehrdb.get_abbreviations(random_id)\n",
"\n",
"        sents = self.ehrdb.get_document_sents(random_id)\n",
"        # fullmatch on two or more capitals: the previous re.match(r'[A-Z]{2}') only\n",
"        # checked the first two characters, so tokens like 'ABc' slipped through.\n",
"        all_caps = re.compile(r'[A-Z]{2,}')\n",
"        test_abbs = set()\n",
"        for sent in sents:\n",
"            for word in ehrkit.word_tokenize(sent):\n",
"                if all_caps.fullmatch(word):\n",
"                    test_abbs.add(word)\n",
"\n",
"        print(kit_abbs)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "319d8a75",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test3_1_extract_abbreviations()"
]
},
{
"cell_type": "markdown",
"id": "27f0fce0",
"metadata": {},
"source": [
"**3.2** List all document IDs that include the keyword \"meningitis\"."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3620b640",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test3_2_docs_with_query(self):\n",
"        \"\"\"Print the first 30 document IDs EHRKit returns for the query 'meningitis'.\"\"\"\n",
"        query = \"meningitis\"\n",
"        print('Printing a list of all document ids including query like ', query)\n",
"        kit_ids = self.ehrdb.get_documents_q(query)\n",
"        # Extremely long list of DOC_IDs, so only a slice is shown.\n",
"        print(kit_ids[:30])\n",
"        print(\"...\")\n",
"\n",
"        # Reference lookup with a SQL LIKE pattern (fetched but not displayed).\n",
"        query = \"%\" + query + \"%\"\n",
"        cursor = self.ehrdb.cur\n",
"        cursor.execute(\"select ROW_ID from mimic.NOTEEVENTS where TEXT like '%s'\" % query)\n",
"        test_ids = ehrkit.flatten(cursor.fetchall())\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "07163b53",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test3_2_docs_with_query()"
]
},
{
"cell_type": "markdown",
"id": "c4deaa30",
"metadata": {},
"source": [
"**3.3** List all document IDs that include keywords \"Service: SURGERY\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "570b1ce7",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test3_3_query_docs(self):\n",
"        \"\"\"Print the first 30 document IDs EHRKit returns for the query 'Service: SURGERY'.\"\"\"\n",
"        ### Task 3.3 is the same as task 3.2 with a different query. ###\n",
"        query = \"Service: SURGERY\"\n",
"        print('Printing a list of all document ids including query like ', query)\n",
"        kit_ids = self.ehrdb.get_documents_q(query)\n",
"        # Extremely long list of DOC_IDs, so only a slice is shown.\n",
"        print(kit_ids[:30])\n",
"        print(\"...\")\n",
"\n",
"        # Reference lookup with a SQL LIKE pattern (fetched but not displayed).\n",
"        query = \"%\" + query + \"%\"\n",
"        cursor = self.ehrdb.cur\n",
"        cursor.execute(\"select ROW_ID from mimic.NOTEEVENTS where TEXT like '%s'\" % query)\n",
"        test_ids = ehrkit.flatten(cursor.fetchall())\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2923f60a",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test3_3_query_docs()"
]
},
{
"cell_type": "markdown",
"id": "b073b5b6",
"metadata": {},
"source": [
"**3.4** Given a document ID, show a numbered list of all sentences in that document."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e4eb5dc",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test3_4_doc_sentences(self):\n",
"        \"\"\"Print a random document's sentences: numbered via EHRKit, then as a raw list.\"\"\"\n",
"        # random.randint includes both endpoints, so the upper bound is NUM_DOCS itself;\n",
"        # NUM_DOCS + 1 could draw an ID with no row, making raw[0] fail below.\n",
"        doc_id = random.randint(1, NUM_DOCS)\n",
"        print('Kit function printing a numbered list of all sentences in document %d' % doc_id)\n",
"        # MIMIC EHRs are very messy and sentence tokenizaton often doesn't work\n",
"        kit_doc = self.ehrdb.get_document_sents(doc_id)\n",
"        ehrkit.numbered_print(kit_doc)\n",
"\n",
"        # Same document fetched with raw SQL and split with ehrkit.sent_tokenize.\n",
"        self.ehrdb.cur.execute(\"select TEXT from mimic.NOTEEVENTS where ROW_ID = %d \" % doc_id)\n",
"        raw = self.ehrdb.cur.fetchall()\n",
"        test_doc = ehrkit.sent_tokenize(raw[0][0])\n",
"        print(test_doc)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d09aeace",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test3_4_doc_sentences()"
]
},
{
"cell_type": "markdown",
"id": "8bba4817",
"metadata": {},
"source": [
"**3.5** Count the number of prescriptions for each unique medication."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f87aaac",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test3_7_medications(self):\n",
"        \"\"\"Count prescriptions per drug via EHRKit and via raw SQL; print the first 30 raw DRUG values.\"\"\"\n",
"        kit_meds = self.ehrdb.count_all_prescriptions()\n",
"\n",
"        # Reference tally built from the DRUG column (built but not displayed).\n",
"        test_meds = {}\n",
"        self.ehrdb.cur.execute(\"select DRUG from mimic.PRESCRIPTIONS\")\n",
"        meds_list = ehrkit.flatten(self.ehrdb.cur.fetchall())\n",
"        for drug in meds_list:\n",
"            test_meds[drug] = test_meds.get(drug, 0) + 1\n",
"\n",
"        print(meds_list[:30])\n",
"        print(\"...\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5464edb5",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test3_7_medications()"
]
},
{
"cell_type": "markdown",
"id": "0ee9f298",
"metadata": {},
"source": [
"### Test T5 ###"
]
},
{
"cell_type": "markdown",
"id": "99a030b2",
"metadata": {},
"source": [
"**5.4** Count how many patients are labeled as \"male\" or \"female\"."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a05ce6b",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test5_4_count_gender(self):\n",
"        \"\"\"Pick 'M' or 'F' at random and print the raw SQL patient count for that gender.\"\"\"\n",
"        gender = random.choice(['M', 'F'])\n",
"        kit_count = self.ehrdb.count_gender(gender)\n",
"\n",
"        # The printed count comes from this raw query, not from kit_count.\n",
"        sql = \"SELECT COUNT(*) FROM mimic.PATIENTS WHERE GENDER = '%s'\" % gender\n",
"        self.ehrdb.cur.execute(sql)\n",
"        rows = self.ehrdb.cur.fetchall()\n",
"        test_count = rows[0][0]\n",
"        print('Gender:', gender, '\\tCount:', str(test_count))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a1459fe",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test5_4_count_gender()"
]
},
{
"cell_type": "markdown",
"id": "5bf4ca7a",
"metadata": {},
"source": [
"### Test T7 ###"
]
},
{
"cell_type": "markdown",
"id": "936f87bc",
"metadata": {},
"source": [
"**7.1** Create an extractive summary of an EHR with a Naive Bayes model trained on PubMed articles."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40764971",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test7_1_naive_bayes(self):\n",
"        \"\"\"Summarize an EHR extractively with a Gaussian Naive Bayes sentence classifier.\n",
"\n",
"        Interactive: prompts for the document, the model flavour and, when no fitted\n",
"        model directory is found, whether to run the training script first. Calls\n",
"        sys.exit() on invalid input or when no model is available.\n",
"        \"\"\"\n",
"        import subprocess\n",
"        from pubmed_naive_bayes import classify_nb\n",
"        from get_pubmed_nb_data import build_vecs\n",
"        from sklearn.naive_bayes import GaussianNB\n",
"\n",
"        doc_id, text = select_ehr(self.ehrdb)\n",
"        body_type = input(\n",
"            'Use Naive Bayes model trained from whole body sections or just their body introductions?\\n\\t'\n",
"            '[w=whole body, j=just intro, DEFAULT=just intro]: ')\n",
"        if body_type == 'w':\n",
"            ending = 'body'\n",
"        elif body_type in ['j', '']:\n",
"            ending = 'intro'\n",
"        else:\n",
"            sys.exit(\"Error: Must input 'w' or 'j.'\")\n",
"        SUMM_DIR = os.path.abspath(os.path.join(os.path.dirname('EHRKit'), '..', 'summarization', 'pubmed_summarization'))\n",
"        best_dir_name = get_nb_dir(ending, SUMM_DIR)\n",
"        if not best_dir_name:\n",
"            message = ('No Naive Bayes models of this type have been fit. '\n",
"                       'Would you like to do so now?\\n\\t[DEFAULT=Yes] ')\n",
"            response = input(message)\n",
"            if response.lower() not in ['y', 'yes', '']:\n",
"                sys.exit('Exiting.')\n",
"            # Run the training script with this kernel's interpreter. An argument list\n",
"            # (no shell) keeps script paths that contain spaces intact.\n",
"            subprocess.run([sys.executable, os.path.join(SUMM_DIR, 'pubmed_naive_bayes.py')])\n",
"            # Bug fix: the retry used to omit SUMM_DIR, which raised TypeError.\n",
"            best_dir_name = get_nb_dir(ending, SUMM_DIR)\n",
"            if not best_dir_name:\n",
"                sys.exit('Exiting.')\n",
"\n",
"        # Fits model to data\n",
"        NB_DIR = os.path.join(SUMM_DIR, best_dir_name, 'nb')\n",
"        with open(os.path.join(NB_DIR, 'feature_vecs.json'), 'r') as f:\n",
"            data = json.load(f)\n",
"        xtrain, ytrain = data['train_features'], data['train_outputs']\n",
"        gnb = GaussianNB()\n",
"        gnb.fit(xtrain, ytrain)\n",
"\n",
"        # Evaluates on model\n",
"        tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')\n",
"        feature_vecs, _ = build_vecs(text, None, tokenizer)\n",
"        PCT_SUM = 0.3  # passed to classify_nb; presumably the fraction of sentences kept - confirm\n",
"        preds = classify_nb(feature_vecs, PCT_SUM, gnb)\n",
"        sents = tokenizer.tokenize(text)\n",
"        # Join the selected sentences with a space so adjacent ones do not run together;\n",
"        # zip() also tolerates preds and sents differing in length.\n",
"        summary = ' '.join(sent for pred, sent in zip(preds, sents) if pred == 1)\n",
"\n",
"        show_summary(doc_id, text, summary, 'Naive Bayes')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8e05bc5",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test7_1_naive_bayes()"
]
},
{
"cell_type": "markdown",
"id": "fa265304",
"metadata": {},
"source": [
"**7.2** Generate an abstractive summary of an EHR with the pre-trained DistilBART model from Hugging Face."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "992ce1f2",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test7_2_distilbart_summary(self):\n",
"        \"\"\"Summarize a long EHR with the distilbart-cnn-12-6 model and report word counts.\"\"\"\n",
"        # Distilbart for summarization. Trained on CNN/ Daily Mail (~4x longer summaries than XSum)\n",
"        model_name = 'sshleifer/distilbart-cnn-12-6'\n",
"        doc_id, text = select_ehr(self.ehrdb, requires_long=True)\n",
"        summary = self.ehrdb.summarize_huggingface(text, model_name)\n",
"\n",
"        show_summary(doc_id, text, summary, model_name)\n",
"        # Same two report lines as before, produced from (label, body) pairs.\n",
"        for label, body in (('Full EHR', text), ('%s Summary' % model_name, summary)):\n",
"            print('Number of Words in %s: %d' % (label, len(body.split())))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3bd82bbd",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test7_2_distilbart_summary()"
]
},
{
"cell_type": "markdown",
"id": "1f678ea5",
"metadata": {},
"source": [
"**7.3** Generate an abstractive summary of an EHR with the pre-trained T5 model from Hugging Face."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27653251",
"metadata": {},
"outputs": [],
"source": [
"class tests(tests):\n",
"    def test7_3_t5_summary(self):\n",
"        \"\"\"Summarize a long EHR with the t5-small model and report word counts.\"\"\"\n",
"        # T5 for summarization. Trained on CNN/ Daily Mail\n",
"        model_name = 't5-small'\n",
"        doc_id, text = select_ehr(self.ehrdb, requires_long=True)\n",
"        summary = self.ehrdb.summarize_huggingface(text, model_name)\n",
"\n",
"        show_summary(doc_id, text, summary, model_name)\n",
"        # Same two report lines as before, produced from (label, body) pairs.\n",
"        for label, body in (('Full EHR', text), ('%s Summary' % model_name, summary)):\n",
"            print('Number of Words in %s: %d' % (label, len(body.split())))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e60e48d7",
"metadata": {},
"outputs": [],
"source": [
"test = tests()\n",
"test.test7_3_t5_summary()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}