Switch to unified view

a b/tests/test_functions_no_output.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "markdown",
5
   "id": "accbacfc",
6
   "metadata": {},
7
   "source": [
8
    "## EHRKit Demo\n",
9
    "\n",
10
    "In this notebook, we demonstrate some of the basic functions that you can utilize from the EHRKit."
11
   ]
12
  },
13
  {
14
   "cell_type": "markdown",
15
   "id": "9364f9f4",
16
   "metadata": {},
17
   "source": [
18
    "### Preparation"
19
   ]
20
  },
21
  {
22
   "cell_type": "code",
23
   "execution_count": null,
24
   "id": "6b745701",
25
   "metadata": {},
26
   "outputs": [],
27
   "source": [
28
    "import unittest\n",
29
    "import random\n",
30
    "import sys, os\n",
31
    "import re\n",
32
    "import nltk\n",
33
    "import json"
34
   ]
35
  },
36
  {
37
   "cell_type": "code",
38
   "execution_count": null,
39
   "id": "dbcf7369",
40
   "metadata": {},
41
   "outputs": [],
42
   "source": [
43
    "sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath('EHRKit'))))\n",
44
    "sys.path.append(os.path.abspath(os.path.join(os.path.dirname(''), '..', 'allennlp')))\n",
45
    "sys.path.append(os.path.abspath(os.path.join(os.path.dirname(''), '..', 'summarization', 'pubmed_summarization')))"
46
   ]
47
  },
48
  {
49
   "cell_type": "code",
50
   "execution_count": null,
51
   "id": "ccaf4e2d",
52
   "metadata": {},
53
   "outputs": [],
54
   "source": [
55
    "from ehrkit import ehrkit"
56
   ]
57
  },
58
  {
59
   "cell_type": "code",
60
   "execution_count": null,
61
   "id": "13222d2c",
62
   "metadata": {},
63
   "outputs": [],
64
   "source": [
65
    "from getpass import getpass"
66
   ]
67
  },
68
  {
69
   "cell_type": "code",
70
   "execution_count": null,
71
   "id": "722a048e",
72
   "metadata": {},
73
   "outputs": [],
74
   "source": [
75
    "try: \n",
76
    "    from config import USERNAME, PASSWORD\n",
77
    "except:\n",
78
    "    print(\"Please put your username and password in config.py\")\n",
79
    "    USERNAME = input('DB_username?')\n",
80
    "    PASSWORD = getpass('DB_password?')"
81
   ]
82
  },
83
  {
84
   "cell_type": "code",
85
   "execution_count": null,
86
   "id": "827f0ad3",
87
   "metadata": {},
88
   "outputs": [],
89
   "source": [
90
    "DOC_ID = 1354526 # Temporary!!!\n",
91
    "\n",
92
    "# Number of documents in NOTEEVENTS.\n",
93
    "NUM_DOCS = 2083180\n",
94
    "\n",
95
    "# Number of patients in PATIENTS.\n",
96
    "NUM_PATIENTS = 46520\n",
97
    "\n",
98
    "# Number of diagnoses in DIAGNOSES_ICD.\n",
99
    "NUM_DIAGS = 823933"
100
   ]
101
  },
102
  {
103
   "cell_type": "code",
104
   "execution_count": null,
105
   "id": "1654c5cb",
106
   "metadata": {},
107
   "outputs": [],
108
   "source": [
109
    "def select_ehr(ehrdb, requires_long=False, recursing=False):\n",
110
    "    if recursing:\n",
111
    "        doc_id = ''\n",
112
    "    else:\n",
113
    "        doc_id = input(\"MIMIC Document ID [press Enter for random]: \")\n",
114
    "    if doc_id == '':\n",
115
    "        # Picks random document\n",
116
    "        ehrdb.cur.execute(\"SELECT ROW_ID FROM mimic.NOTEEVENTS ORDER BY RAND() LIMIT 1\")\n",
117
    "        doc_id = ehrdb.cur.fetchall()[0][0]\n",
118
    "        text = ehrdb.get_document(int(doc_id))\n",
119
    "        if len(text.split()) > 200 or not requires_long:\n",
120
    "            return doc_id, text\n",
121
    "        else:\n",
122
    "            return select_ehr(ehrdb, requires_long, True)\n",
123
    "    else:\n",
124
    "        # Get inputted document\n",
125
    "        try:\n",
126
    "            text = ehrdb.get_document(int(doc_id))\n",
127
    "            return doc_id, text\n",
128
    "        except:\n",
129
    "            message = 'Error: There is no document with ID \\'' + doc_id + '\\' in mimic.NOTEEVENTS'\n",
130
    "            sys.exit(message)"
131
   ]
132
  },
133
  {
134
   "cell_type": "code",
135
   "execution_count": null,
136
   "id": "43f13fc3",
137
   "metadata": {},
138
   "outputs": [],
139
   "source": [
140
    "def get_nb_dir(ending, SUMM_DIR):\n",
141
    "    # Gets path of Naive Bayes model trained on most examples\n",
142
    "    dir_nums = []\n",
143
    "    for dir in os.listdir(SUMM_DIR):\n",
144
    "        if os.path.isdir(os.path.join(SUMM_DIR, dir)) and dir.endswith('_exs_' + ending):\n",
145
    "            if os.path.exists(os.path.join(SUMM_DIR, dir, 'nb')):  \n",
146
    "                try:\n",
147
    "                    dir_nums.append(int(dir.split('_')[0]))\n",
148
    "                except:\n",
149
    "                    continue\n",
150
    "    if len(dir_nums) > 0:\n",
151
    "        best_dir_name = str(max(dir_nums)) + '_exs_' + ending\n",
152
    "        return best_dir_name\n",
153
    "    else:\n",
154
    "        return None"
155
   ]
156
  },
157
  {
158
   "cell_type": "code",
159
   "execution_count": null,
160
   "id": "272882d8",
161
   "metadata": {},
162
   "outputs": [],
163
   "source": [
164
    "def show_summary(doc_id, text, summary, model_name):\n",
165
    "    x = input('Show full EHR (DOC ID %s)? [DEFAULT=Yes]' % doc_id)\n",
166
    "    # x = ''\n",
167
    "    if x.lower() in ['y', 'yes', '']:\n",
168
    "        print('\\n\\n' + '-'*30 + 'Full EHR' + '-'*30)\n",
169
    "        print(text + '\\n')\n",
170
    "        print('-'*80 + '\\n\\n')\n",
171
    "\n",
172
    "    print('-'*30 + 'Predicted Summary ' + model_name + '-'*30)\n",
173
    "    print(summary)\n",
174
    "    print('-'*80 + '\\n\\n')"
175
   ]
176
  },
177
  {
178
   "cell_type": "code",
179
   "execution_count": null,
180
   "id": "0bcc571c",
181
   "metadata": {},
182
   "outputs": [],
183
   "source": [
184
    "class tests():\n",
185
    "    def __init__(self):\n",
186
    "        self.ehrdb = ehrkit.start_session(USERNAME, PASSWORD)\n",
187
    "        self.ehrdb.get_patients(3)\n",
188
    "    "
189
   ]
190
  },
191
  {
192
   "cell_type": "markdown",
193
   "id": "5e665759",
194
   "metadata": {},
195
   "source": [
196
    "**1.1** Count the total number of patients."
197
   ]
198
  },
199
  {
200
   "cell_type": "markdown",
201
   "id": "578dab47",
202
   "metadata": {},
203
   "source": [
204
    "### Test T1"
205
   ]
206
  },
207
  {
208
   "cell_type": "code",
209
   "execution_count": null,
210
   "id": "7ac56588",
211
   "metadata": {},
212
   "outputs": [],
213
   "source": [
214
    "class tests(tests):\n",
215
    "    def test1_1_count_patients(self):\n",
216
    "            kit_count = self.ehrdb.count_patients()\n",
217
    "            print(\"Patient count: \", kit_count)\n",
218
    "\n",
219
    "            self.ehrdb.cur.execute(\"SELECT COUNT(*) FROM mimic.PATIENTS\")\n",
220
    "            raw = self.ehrdb.cur.fetchall()\n",
221
    "            test_count = int(raw[0][0])"
222
   ]
223
  },
224
  {
225
   "cell_type": "code",
226
   "execution_count": null,
227
   "id": "709a054f",
228
   "metadata": {},
229
   "outputs": [],
230
   "source": [
231
    "test = tests()\n",
232
    "test.test1_1_count_patients()"
233
   ]
234
  },
235
  {
236
   "cell_type": "markdown",
237
   "id": "f04dc8bc",
238
   "metadata": {},
239
   "source": [
240
    "**1.3** Count the total number of sentences."
241
   ]
242
  },
243
  {
244
   "cell_type": "code",
245
   "execution_count": null,
246
   "id": "6a77086e",
247
   "metadata": {},
248
   "outputs": [],
249
   "source": [
250
    "class tests(tests):\n",
251
    "    def test1_3_note_info(self):\n",
252
    "        self.ehrdb.get_note_events()\n",
253
    "        print('output format: SUBJECT_ID, ROW_ID, NoteEvent length')\n",
254
    "        lens = [(patient.id, note[0], len(note[1])) for patient in self.ehrdb.patients.values() for note in patient.note_events]\n",
255
    "        print(lens)\n",
256
    "        "
257
   ]
258
  },
259
  {
260
   "cell_type": "code",
261
   "execution_count": null,
262
   "id": "48b05b86",
263
   "metadata": {},
264
   "outputs": [],
265
   "source": [
266
    "test = tests()\n",
267
    "test.test1_3_note_info()"
268
   ]
269
  },
270
  {
271
   "cell_type": "markdown",
272
   "id": "b4c4c08c",
273
   "metadata": {},
274
   "source": [
275
    "**1.4** Print the record with the most sentences."
276
   ]
277
  },
278
  {
279
   "cell_type": "code",
280
   "execution_count": null,
281
   "id": "9e053c31",
282
   "metadata": {},
283
   "outputs": [],
284
   "source": [
285
    "class tests(tests):\n",
286
    "    def test1_4_longest_note(self):\n",
287
    "            # Gets longest note among the patient notes queued by get_note_events()\n",
288
    "            self.ehrdb.get_note_events()\n",
289
    "            pid, rowid, doclen = self.ehrdb.longest_NE()\n",
290
    "            print('patient id is:', pid, '\\nNoteEvent id is:', rowid, '\\nlength: ', doclen)\n"
291
   ]
292
  },
293
  {
294
   "cell_type": "code",
295
   "execution_count": null,
296
   "id": "add51bdb",
297
   "metadata": {},
298
   "outputs": [],
299
   "source": [
300
    "test = tests()\n",
301
    "test.test1_4_longest_note()"
302
   ]
303
  },
304
  {
305
   "cell_type": "markdown",
306
   "id": "8bd66e85",
307
   "metadata": {},
308
   "source": [
309
    "### Test 2"
310
   ]
311
  },
312
  {
313
   "cell_type": "markdown",
314
   "id": "a4e3b881",
315
   "metadata": {},
316
   "source": [
317
    "**2.1** Display document given a document ID."
318
   ]
319
  },
320
  {
321
   "cell_type": "code",
322
   "execution_count": null,
323
   "id": "fe0a7b21",
324
   "metadata": {},
325
   "outputs": [],
326
   "source": [
327
    "class tests(tests):\n",
328
    "    def test2_1_print_note(self):\n",
329
    "            ### There are 2083180 patient records in NOTEEVENTS. ###\n",
330
    "            record_id = random.randint(1, NUM_DOCS + 1)\n",
331
    "            kit_rec = self.ehrdb.get_document(record_id)\n",
332
    "            print(\"Document with ID %d\\n: \" % record_id, kit_rec)\n",
333
    "\n",
334
    "            self.ehrdb.cur.execute(\"select TEXT from mimic.NOTEEVENTS where ROW_ID = %d\" % record_id)\n",
335
    "            test_rec = self.ehrdb.cur.fetchall()[0][0]\n"
336
   ]
337
  },
338
  {
339
   "cell_type": "code",
340
   "execution_count": null,
341
   "id": "dfe5c0cc",
342
   "metadata": {},
343
   "outputs": [],
344
   "source": [
345
    "test = tests()\n",
346
    "test.test2_1_print_note()"
347
   ]
348
  },
349
  {
350
   "cell_type": "markdown",
351
   "id": "52c4142a",
352
   "metadata": {},
353
   "source": [
354
    "**2.2** Count the number of documents associated with a given patient given patient ID."
355
   ]
356
  },
357
  {
358
   "cell_type": "code",
359
   "execution_count": null,
360
   "id": "8f6208b7",
361
   "metadata": {},
362
   "outputs": [],
363
   "source": [
364
    "class tests(tests):\n",
365
    "    def test2_2_patient_info(self):\n",
366
    "        ### There are records from 46520 unique patients in MIMIC. ###\n",
367
    "        patient_id = random.randint(1, NUM_PATIENTS + 1)\n",
368
    "        kit_ids = self.ehrdb.get_all_patient_document_ids(patient_id)\n",
369
    "        print('Document IDs related to Patient %d: ' % patient_id, kit_ids)\n",
370
    "        print(\"Number of docs related to Patient %d: \" % patient_id, len(kit_ids))\n",
371
    "\n",
372
    "        self.ehrdb.cur.execute(\"select ROW_ID from mimic.NOTEEVENTS where SUBJECT_ID = %d\" % patient_id)\n",
373
    "        raw = self.ehrdb.cur.fetchall()\n",
374
    "        test_ids = ehrkit.flatten(raw)"
375
   ]
376
  },
377
  {
378
   "cell_type": "code",
379
   "execution_count": null,
380
   "id": "7ae06b45",
381
   "metadata": {},
382
   "outputs": [],
383
   "source": [
384
    "test = tests()\n",
385
    "test.test2_2_patient_info()"
386
   ]
387
  },
388
  {
389
   "cell_type": "markdown",
390
   "id": "f571db23",
391
   "metadata": {},
392
   "source": [
393
    "**2.3** List all document IDs."
394
   ]
395
  },
396
  {
397
   "cell_type": "code",
398
   "execution_count": null,
399
   "id": "bab71df9",
400
   "metadata": {},
401
   "outputs": [],
402
   "source": [
403
    "class tests(tests):\n",
404
    "    def test2_3_doc_ids(self):\n",
405
    "            kit_ids = self.ehrdb.list_all_document_ids()\n",
406
    "\n",
407
    "            self.ehrdb.cur.execute(\"select ROW_ID from mimic.NOTEEVENTS\")\n",
408
    "            raw = self.ehrdb.cur.fetchall()\n",
409
    "            test_ids = ehrkit.flatten(raw)\n",
410
    "            print('All document ids: (truncated)')\n",
411
    "            print(test_ids[:30])\n",
412
    "            print('...')\n"
413
   ]
414
  },
415
  {
416
   "cell_type": "code",
417
   "execution_count": null,
418
   "id": "29d8fb2a",
419
   "metadata": {},
420
   "outputs": [],
421
   "source": [
422
    "test = tests()\n",
423
    "test.test2_3_doc_ids()"
424
   ]
425
  },
426
  {
427
   "cell_type": "markdown",
428
   "id": "549c2a16",
429
   "metadata": {},
430
   "source": [
431
    "**2.4** List all patient IDs."
432
   ]
433
  },
434
  {
435
   "cell_type": "code",
436
   "execution_count": null,
437
   "id": "c1896914",
438
   "metadata": {},
439
   "outputs": [],
440
   "source": [
441
    "class tests(tests):\n",
442
    "    def test2_4_patient_ids(self):\n",
443
    "            kit_ids = self.ehrdb.list_all_patient_ids()\n",
444
    "\n",
445
    "            self.ehrdb.cur.execute(\"select SUBJECT_ID from mimic.PATIENTS\")\n",
446
    "            raw = self.ehrdb.cur.fetchall()\n",
447
    "            test_ids = ehrkit.flatten(raw)\n",
448
    "            print(\"All patient ids: (truncated)\")\n",
449
    "            print(test_ids[:30])\n",
450
    "            print('...')\n"
451
   ]
452
  },
453
  {
454
   "cell_type": "code",
455
   "execution_count": null,
456
   "id": "43026f9b",
457
   "metadata": {},
458
   "outputs": [],
459
   "source": [
460
    "test = tests()\n",
461
    "test.test2_4_patient_ids()"
462
   ]
463
  },
464
  {
465
   "cell_type": "markdown",
466
   "id": "898ddf27",
467
   "metadata": {},
468
   "source": [
469
    "**2.5** List all document IDs for a given admission date."
470
   ]
471
  },
472
  {
473
   "cell_type": "code",
474
   "execution_count": null,
475
   "id": "cf636ec2",
476
   "metadata": {},
477
   "outputs": [],
478
   "source": [
479
    "class tests(tests):\n",
480
    "    def test2_5_docs_on_date(self):\n",
481
    "        ### Select random date from a date in the database. \n",
482
    "        ### Dates are shifted to future but preserve time, weekday, and seasonality.\n",
483
    "        random_id = random.randint(1, NUM_DOCS + 1)\n",
484
    "        self.ehrdb.cur.execute(\"select CHARTDATE from mimic.NOTEEVENTS where ROW_ID = %d\" % random_id)\n",
485
    "        date = self.ehrdb.cur.fetchall()[0][0]\n",
486
    "\n",
487
    "        kit_ids = self.ehrdb.get_documents_d(date)\n",
488
    "\n",
489
    "        self.ehrdb.cur.execute(\"select ROW_ID from mimic.NOTEEVENTS where CHARTDATE = \\\"%s\\\"\" % date)\n",
490
    "        raw = self.ehrdb.cur.fetchall()\n",
491
    "        test_ids = ehrkit.flatten(raw)\n",
492
    "        print(f\"Selected date: {date}\")\n",
493
    "        print(f\"Test ids {test_ids[:30]} ...\")"
494
   ]
495
  },
496
  {
497
   "cell_type": "code",
498
   "execution_count": null,
499
   "id": "0da163eb",
500
   "metadata": {},
501
   "outputs": [],
502
   "source": [
503
    "test = tests()\n",
504
    "test.test2_5_docs_on_date()"
505
   ]
506
  },
507
  {
508
   "cell_type": "markdown",
509
   "id": "378be2d9",
510
   "metadata": {},
511
   "source": [
512
    "### Test 3 ###"
513
   ]
514
  },
515
  {
516
   "cell_type": "markdown",
517
   "id": "2b49343e",
518
   "metadata": {},
519
   "source": [
520
    "**3.1** Extract all abbreviations from a document, given the document ID."
521
   ]
522
  },
523
  {
524
   "cell_type": "code",
525
   "execution_count": null,
526
   "id": "177712a6",
527
   "metadata": {},
528
   "outputs": [],
529
   "source": [
530
    "class tests(tests):\n",
531
    "    def test3_1_extract_abbreviations(self):\n",
532
    "        # Defines abbreviation as a string of capitalized letters\n",
533
    "        random_id = random.randint(1, NUM_DOCS + 1)\n",
534
    "        print(\"Collecting abbreviations for document %d...\" % random_id)\n",
535
    "        kit_abbs = self.ehrdb.get_abbreviations(random_id)\n",
536
    "\n",
537
    "        sents = self.ehrdb.get_document_sents(random_id)\n",
538
    "        test_abbs = set()\n",
539
    "        for sent in sents:\n",
540
    "            for word in ehrkit.word_tokenize(sent):\n",
541
    "                print(word)\n",
542
    "                pattern = r'[A-Z]{2}'  # Only selects words in ALL CAPS\n",
543
    "                if re.match(pattern, word):\n",
544
    "                    test_abbs.add(word)\n",
545
    "\n",
546
    "        print(kit_abbs)\n"
547
   ]
548
  },
549
  {
550
   "cell_type": "code",
551
   "execution_count": null,
552
   "id": "319d8a75",
553
   "metadata": {},
554
   "outputs": [],
555
   "source": [
556
    "test = tests()\n",
557
    "test.test3_1_extract_abbreviations()"
558
   ]
559
  },
560
  {
561
   "cell_type": "markdown",
562
   "id": "27f0fce0",
563
   "metadata": {},
564
   "source": [
565
    "**3.2** List all document IDs that include keywork \"meningitis\""
566
   ]
567
  },
568
  {
569
   "cell_type": "code",
570
   "execution_count": null,
571
   "id": "3620b640",
572
   "metadata": {},
573
   "outputs": [],
574
   "source": [
575
    "class tests(tests):\n",
576
    "    def test3_2_docs_with_query(self):\n",
577
    "        query = \"meningitis\"\n",
578
    "        print('Printing a list of all document ids including query like ', query)\n",
579
    "        kit_ids = self.ehrdb.get_documents_q(query)\n",
580
    "        print(kit_ids[:30])  # Extremely long list of DOC_IDs\n",
581
    "        print(\"...\")\n",
582
    "\n",
583
    "        query = \"%\"+query+\"%\"\n",
584
    "        self.ehrdb.cur.execute(\"select ROW_ID from mimic.NOTEEVENTS where TEXT like \\'%s\\'\" % query)\n",
585
    "        raw = self.ehrdb.cur.fetchall()\n",
586
    "        test_ids = ehrkit.flatten(raw)\n"
587
   ]
588
  },
589
  {
590
   "cell_type": "code",
591
   "execution_count": null,
592
   "id": "07163b53",
593
   "metadata": {},
594
   "outputs": [],
595
   "source": [
596
    "test = tests()\n",
597
    "test.test3_2_docs_with_query()"
598
   ]
599
  },
600
  {
601
   "cell_type": "markdown",
602
   "id": "c4deaa30",
603
   "metadata": {},
604
   "source": [
605
    "**3.3** List all document IDs that include keywords \"Service: SURGERY\""
606
   ]
607
  },
608
  {
609
   "cell_type": "code",
610
   "execution_count": null,
611
   "id": "570b1ce7",
612
   "metadata": {},
613
   "outputs": [],
614
   "source": [
615
    "class tests(tests):\n",
616
    "    def test3_3_query_docs(self):\n",
617
    "            ### Task 3.3 is the same as task 3.2 with a different query. ###\n",
618
    "            query = \"Service: SURGERY\"\n",
619
    "            print('Printing a list of all document ids including query like ', query)\n",
620
    "            kit_ids = self.ehrdb.get_documents_q(query)\n",
621
    "            print(kit_ids[:30])  # Extremely long list of DOC_IDs\n",
622
    "            print(\"...\")\n",
623
    "\n",
624
    "            query = \"%\"+query+\"%\"\n",
625
    "            self.ehrdb.cur.execute(\"select ROW_ID from mimic.NOTEEVENTS where TEXT like \\'%s\\'\" % query)\n",
626
    "            raw = self.ehrdb.cur.fetchall()\n",
627
    "            test_ids = ehrkit.flatten(raw)\n"
628
   ]
629
  },
630
  {
631
   "cell_type": "code",
632
   "execution_count": null,
633
   "id": "2923f60a",
634
   "metadata": {},
635
   "outputs": [],
636
   "source": [
637
    "test = tests()\n",
638
    "test.test3_3_query_docs()"
639
   ]
640
  },
641
  {
642
   "cell_type": "markdown",
643
   "id": "b073b5b6",
644
   "metadata": {},
645
   "source": [
646
    "**3.4** Given a document ID, show a numbered list of all sentences in that document."
647
   ]
648
  },
649
  {
650
   "cell_type": "code",
651
   "execution_count": null,
652
   "id": "2e4eb5dc",
653
   "metadata": {},
654
   "outputs": [],
655
   "source": [
656
    "class tests(tests):\n",
657
    "    def test3_4_doc_sentences(self):\n",
658
    "        doc_id = random.randint(1, NUM_DOCS + 1)\n",
659
    "        print('Kit function printing a numbered list of all sentences in document %d' % doc_id)\n",
660
    "        # MIMIC EHRs are very messy and sentence tokenizaton often doesn't work\n",
661
    "        kit_doc = self.ehrdb.get_document_sents(doc_id)\n",
662
    "        ehrkit.numbered_print(kit_doc)\n",
663
    "\n",
664
    "        self.ehrdb.cur.execute(\"select TEXT from mimic.NOTEEVENTS where ROW_ID = %d \" % doc_id)\n",
665
    "        raw = self.ehrdb.cur.fetchall()\n",
666
    "        test_doc = ehrkit.sent_tokenize(raw[0][0])\n",
667
    "        print(test_doc)\n"
668
   ]
669
  },
670
  {
671
   "cell_type": "code",
672
   "execution_count": null,
673
   "id": "d09aeace",
674
   "metadata": {},
675
   "outputs": [],
676
   "source": [
677
    "test = tests()\n",
678
    "test.test3_4_doc_sentences()"
679
   ]
680
  },
681
  {
682
   "cell_type": "markdown",
683
   "id": "8bba4817",
684
   "metadata": {},
685
   "source": [
686
    "**3.5** Count the number of prescriptions for each unique medication."
687
   ]
688
  },
689
  {
690
   "cell_type": "code",
691
   "execution_count": null,
692
   "id": "2f87aaac",
693
   "metadata": {},
694
   "outputs": [],
695
   "source": [
696
    "class tests(tests):\n",
697
    "    def test3_7_medications(self):\n",
698
    "            kit_meds = self.ehrdb.count_all_prescriptions()\n",
699
    "\n",
700
    "            test_meds = {}\n",
701
    "            self.ehrdb.cur.execute(\"select DRUG from mimic.PRESCRIPTIONS\")\n",
702
    "            raw = self.ehrdb.cur.fetchall()\n",
703
    "            meds_list = ehrkit.flatten(raw)\n",
704
    "            for med in meds_list:\n",
705
    "                if med in test_meds:\n",
706
    "                    test_meds[med] += 1\n",
707
    "                else:\n",
708
    "                    test_meds[med] = 1\n",
709
    "\n",
710
    "            print(meds_list[:30])\n",
711
    "            print(\"...\")\n"
712
   ]
713
  },
714
  {
715
   "cell_type": "code",
716
   "execution_count": null,
717
   "id": "5464edb5",
718
   "metadata": {},
719
   "outputs": [],
720
   "source": [
721
    "test = tests()\n",
722
    "test.test3_7_medications()"
723
   ]
724
  },
725
  {
726
   "cell_type": "markdown",
727
   "id": "0ee9f298",
728
   "metadata": {},
729
   "source": [
730
    "### Test T5 ###"
731
   ]
732
  },
733
  {
734
   "cell_type": "markdown",
735
   "id": "99a030b2",
736
   "metadata": {},
737
   "source": [
738
    "**5.4** Count how many patients are labeled as \"male\" or \"female\"."
739
   ]
740
  },
741
  {
742
   "cell_type": "code",
743
   "execution_count": null,
744
   "id": "1a05ce6b",
745
   "metadata": {},
746
   "outputs": [],
747
   "source": [
748
    "class tests(tests):\n",
749
    "    def test5_4_count_gender(self):\n",
750
    "        gender = random.choice(['M', 'F'])\n",
751
    "        kit_count = self.ehrdb.count_gender(gender)\n",
752
    "\n",
753
    "        self.ehrdb.cur.execute('SELECT COUNT(*) FROM mimic.PATIENTS WHERE GENDER = \\'%s\\'' % gender)\n",
754
    "        raw = self.ehrdb.cur.fetchall()\n",
755
    "        test_count = raw[0][0]\n",
756
    "        print('Gender:', gender, '\\tCount:', str(test_count))\n"
757
   ]
758
  },
759
  {
760
   "cell_type": "code",
761
   "execution_count": null,
762
   "id": "7a1459fe",
763
   "metadata": {},
764
   "outputs": [],
765
   "source": [
766
    "test = tests()\n",
767
    "test.test5_4_count_gender()"
768
   ]
769
  },
770
  {
771
   "cell_type": "markdown",
772
   "id": "5bf4ca7a",
773
   "metadata": {},
774
   "source": [
775
    "### Test T7 ###"
776
   ]
777
  },
778
  {
779
   "cell_type": "markdown",
780
   "id": "936f87bc",
781
   "metadata": {},
782
   "source": [
783
    "**7.1** Creates extractive summary of an EHR with Naive Bayes Algorithm trained on PubMed articles."
784
   ]
785
  },
786
  {
787
   "cell_type": "code",
788
   "execution_count": null,
789
   "id": "40764971",
790
   "metadata": {},
791
   "outputs": [],
792
   "source": [
793
    "class tests(tests):\n",
794
    "    def test7_1_naive_bayes(self):\n",
795
    "        from pubmed_naive_bayes import classify_nb\n",
796
    "        from get_pubmed_nb_data import build_vecs\n",
797
    "        from sklearn.naive_bayes import GaussianNB\n",
798
    "\n",
799
    "        doc_id, text = select_ehr(self.ehrdb)\n",
800
    "        body_type = input('Use Naive Bayes model trained from whole body sections or just their body introductions?\\n\\t'\\\n",
801
    "                        '[w=whole body, j=just intro, DEFAULT=just intro]: ')\n",
802
    "        if body_type == 'w':\n",
803
    "            ending = 'body'\n",
804
    "        elif body_type in ['j', '']:\n",
805
    "            ending = 'intro'\n",
806
    "        else:\n",
807
    "            sys.exit('Error: Must input \\'w\\' or \\'j.\\'')\n",
808
    "        SUMM_DIR = os.path.abspath(os.path.join(os.path.dirname('EHRKit'), '..', 'summarization', 'pubmed_summarization'))\n",
809
    "        best_dir_name = get_nb_dir(ending, SUMM_DIR)\n",
810
    "        if not best_dir_name:\n",
811
    "            message = 'No Naive Bayes models of this type have been fit. '\\\n",
812
    "                        'Would you like to do so now?\\n\\t[DEFAULT=Yes] '\n",
813
    "            response = input(message)\n",
814
    "            if response.lower() in ['y', 'yes', '']:\n",
815
    "                command = 'python ' + os.path.abspath(os.path.join(os.path.dirname('EHRKit'), '..', 'summarization', 'pubmed_summarization', 'pubmed_naive_bayes.py'))\n",
816
    "                os.system(command)\n",
817
    "                best_dir_name = get_nb_dir(ending)\n",
818
    "            if response.lower() not in ['y', 'yes', ''] or not best_dir_name:\n",
819
    "                sys.exit('Exiting.')\n",
820
    "\n",
821
    "        # Fits model to data        \n",
822
    "        NB_DIR = os.path.join(SUMM_DIR, best_dir_name, 'nb')\n",
823
    "        with open(os.path.join(NB_DIR, 'feature_vecs.json'), 'r') as f:\n",
824
    "            data = json.load(f)\n",
825
    "        xtrain, ytrain = data['train_features'], data['train_outputs']\n",
826
    "        gnb = GaussianNB()\n",
827
    "        gnb.fit(xtrain, ytrain)\n",
828
    "\n",
829
    "        # Evaluates on model\n",
830
    "        tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')\n",
831
    "        feature_vecs, _ = build_vecs(text, None, tokenizer)\n",
832
    "        PCT_SUM = 0.3\n",
833
    "        preds = classify_nb(feature_vecs, PCT_SUM, gnb)\n",
834
    "        sents = tokenizer.tokenize(text)\n",
835
    "        summary = ''\n",
836
    "        for i in range(len(preds)):\n",
837
    "            if preds[i] == 1:\n",
838
    "                summary += sents[i]\n",
839
    "\n",
840
    "        show_summary(doc_id, text, summary, 'Naive Bayes')"
841
   ]
842
  },
843
  {
844
   "cell_type": "code",
845
   "execution_count": null,
846
   "id": "b8e05bc5",
847
   "metadata": {},
848
   "outputs": [],
849
   "source": [
850
    "test = tests()\n",
851
    "test.test7_1_naive_bayes()"
852
   ]
853
  },
854
  {
855
   "cell_type": "markdown",
856
   "id": "fa265304",
857
   "metadata": {},
858
   "source": [
859
    "**7.2** Generates abstractive summary of an EHR with pre-trained Distilbart model from Huggingface."
860
   ]
861
  },
862
  {
863
   "cell_type": "code",
864
   "execution_count": null,
865
   "id": "992ce1f2",
866
   "metadata": {},
867
   "outputs": [],
868
   "source": [
869
    "class tests(tests):\n",
870
    "    def test7_2_distilbart_summary(self):\n",
871
    "            # Distilbart for summarization. Trained on CNN/ Daily Mail (~4x longer summaries than XSum)\n",
872
    "            doc_id, text = select_ehr(self.ehrdb, requires_long=True)\n",
873
    "            model_name = 'sshleifer/distilbart-cnn-12-6'\n",
874
    "            summary = self.ehrdb.summarize_huggingface(text, model_name)\n",
875
    "\n",
876
    "            show_summary(doc_id, text, summary, model_name)\n",
877
    "            print('Number of Words in Full EHR: %d' % len(text.split()))\n",
878
    "            print('Number of Words in %s Summary: %d' % (model_name, len(summary.split())))\n"
879
   ]
880
  },
881
  {
882
   "cell_type": "code",
883
   "execution_count": null,
884
   "id": "3bd82bbd",
885
   "metadata": {},
886
   "outputs": [],
887
   "source": [
888
    "test = tests()\n",
889
    "test.test7_2_distilbart_summary()"
890
   ]
891
  },
892
  {
893
   "cell_type": "markdown",
894
   "id": "1f678ea5",
895
   "metadata": {},
896
   "source": [
897
    "**7.3** Generates abstractive summary of an EHR with pre-trained T5 model from Huggingface."
898
   ]
899
  },
900
  {
901
   "cell_type": "code",
902
   "execution_count": null,
903
   "id": "27653251",
904
   "metadata": {},
905
   "outputs": [],
906
   "source": [
907
    "class tests(tests):\n",
908
    "    def test7_3_t5_summary(self):\n",
909
    "            # T5 for summarization. Trained on CNN/ Daily Mail\n",
910
    "            doc_id, text = select_ehr(self.ehrdb, requires_long=True)\n",
911
    "            model_name = 't5-small'\n",
912
    "            summary = self.ehrdb.summarize_huggingface(text, model_name)\n",
913
    "\n",
914
    "            show_summary(doc_id, text, summary, model_name)\n",
915
    "            print('Number of Words in Full EHR: %d' % len(text.split()))\n",
916
    "            print('Number of Words in %s Summary: %d' % (model_name, len(summary.split())))"
917
   ]
918
  },
919
  {
920
   "cell_type": "code",
921
   "execution_count": null,
922
   "id": "e60e48d7",
923
   "metadata": {},
924
   "outputs": [],
925
   "source": [
926
    "test = tests()\n",
927
    "test.test7_3_t5_summary()"
928
   ]
929
  }
930
 ],
931
 "metadata": {
932
  "kernelspec": {
933
   "display_name": "Python 3 (ipykernel)",
934
   "language": "python",
935
   "name": "python3"
936
  },
937
  "language_info": {
938
   "codemirror_mode": {
939
    "name": "ipython",
940
    "version": 3
941
   },
942
   "file_extension": ".py",
943
   "mimetype": "text/x-python",
944
   "name": "python",
945
   "nbconvert_exporter": "python",
946
   "pygments_lexer": "ipython3",
947
   "version": "3.10.2"
948
  }
949
 },
950
 "nbformat": 4,
951
 "nbformat_minor": 5
952
}