[2d4573]: / extractiveSummarization / scripts / noteevents_lexrank.py

Download this file

110 lines (91 with data), 3.6 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
import random
import sys, os
import re
import shutil
import time
import argparse
from getpass import getpass
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
#print(sys.path)
from ehrkit import ehrkit
from ehrkit.summarizers import Lexrank
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# print(sys.path)
# Command-line interface: output directory plus the sizes/threshold used by
# the two summarisation passes further down in this script.
parser = argparse.ArgumentParser()
parser.add_argument('--saveto', action='store', metavar='dir_path', type=str, help='Directory path to store produced summaries', required=True)
# Defaults are declared on the arguments themselves rather than applied with
# `value or default`, so an explicit 0 / 0.0 given on the command line is
# honoured instead of being silently replaced by the default.
parser.add_argument('--ntrain', action='store', type=int, default=100, help='First n number of patients to train on. default 100')
parser.add_argument('--ntest', action='store', type=int, default=20, help='First n number of patients to produce summaries for. default 20.')
parser.add_argument('--threshold', action='store', type=float, default=0.1, help="default 0.1")
args = parser.parse_args()

saveto_dir = args.saveto
ntrain = args.ntrain
ntest = args.ntest
threshold = args.threshold

# Fail fast if the output directory is missing. Passing a string to
# sys.exit() writes it to stderr and exits with status 1, so callers and
# shell scripts can detect the failure (a bare sys.exit() reports success).
if not os.path.isdir(saveto_dir):
    sys.exit('The saveto directory specified does not exist')
# Number of patients in PATIENTS.
# NOTE(review): not referenced anywhere else in this script as shown.
NUM_PATIENTS = 46520

# Ask for credentials *before* starting the clock: the "Runtime" figures
# printed later are measured from `start`, and should not include however
# long the user takes to type at the prompts.
username = input("User?")
password = getpass("Password?")

start = time.time()
ehrdb = ehrkit.start_session(username, password)
# Do not keep the plaintext password bound at module level longer than needed.
del password

# Populate the session with patients and their note events; `ntrain` caps how
# many patients are requested (per the --ntrain help text: "First n number of
# patients to train on").
ehrdb.get_patients(ntrain)
ehrdb.get_note_events()
#SUMMARY BY NOTE
# Summarise every note of a patient on its own, concatenate the per-note
# summaries, and write one "<patient_id>.sum" file per patient.
print("----Summaries by note----")
new_dir = "script_summary_bynote"
new_dir_path = os.path.join(saveto_dir, new_dir)
# Always start from an empty output directory so stale summaries from an
# earlier run cannot be mistaken for fresh ones.
if os.path.exists(new_dir_path):
    shutil.rmtree(new_dir_path)
os.mkdir(new_dir_path)

# Each element of patient.note_events is a (note_id, note) pair; the model is
# built from the note contents only, one document per note.
allnotes = [note[1] for patient in ehrdb.patients.values() for note in patient.note_events]
lxr = Lexrank(allnotes, threshold=threshold)

for i, patient in enumerate(ehrdb.patients.values()):
    # Check the limit *before* doing any work so that exactly `ntest`
    # patients are summarised (testing `i == ntest` after the write produced
    # ntest + 1 files).
    if i >= ntest:
        break

    patient_id = patient.id
    notewise_sum = []
    for note_id, note in patient.note_events:
        # Summary size scales with the note length: a fixed 2 for very short
        # notes, a tenth of the length for medium ones, a twentieth for long
        # ones.
        if len(note) < 10:
            summary_len = 2
        elif len(note) > 100:
            summary_len = len(note) // 20
        else:
            summary_len = len(note) // 10
        # NOTE(review): the whole-history pass in this script joins
        # `get_summary(...)[0]`, whereas here the return value is extended
        # directly -- confirm which shape Lexrank.get_summary really returns.
        note_summary = lxr.get_summary(note, summary_size=summary_len)
        notewise_sum.extend(note_summary)

    joined_summary = "\n".join(notewise_sum)
    summary_path = os.path.join(new_dir_path, str(patient_id) + ".sum")
    # Named `summary_file` rather than `sum` to avoid shadowing the builtin.
    with open(summary_path, 'w') as summary_file:
        summary_file.write(joined_summary)

end1 = time.time()
print("Runtime " + str(end1 - start))
#SUMMARY BY ENTIRE HISTORY
# Treat all notes of a patient as one document, summarise it as a whole, and
# write one "<patient_id>.sum" file per patient.
print("----Summaries by entire history----")
# NOTE(review): "byetirehistory" looks like a typo for "byentirehistory"; the
# name is kept unchanged because other tooling may read this path.
new_dir = "script_summary_byetirehistory"
new_dir_path = os.path.join(saveto_dir, new_dir)
# Always start from an empty output directory.
if os.path.exists(new_dir_path):
    shutil.rmtree(new_dir_path)
os.mkdir(new_dir_path)

# Build one flat document per patient. Each note event is a (note_id, note)
# pair, so only the note content is appended; extending with the pair itself
# would mix note ids (and nested note lists) into the document.
allnotes_bypatient = {}
for patient in ehrdb.patients.values():
    patient_notes = []
    for _note_id, note in patient.note_events:
        patient_notes.extend(note)
    allnotes_bypatient[patient.id] = patient_notes

lxr2 = Lexrank(list(allnotes_bypatient.values()), threshold=threshold)

for i, patient_id in enumerate(allnotes_bypatient):
    # Check the limit *before* doing any work so that exactly `ntest`
    # patients are summarised (testing `i == ntest` after the write produced
    # ntest + 1 files).
    if i >= ntest:
        break

    history = allnotes_bypatient[patient_id]
    # A quarter of the document length, but never fewer than one entry.
    summary_len = max(1, len(history) // 4)
    summary = lxr2.get_summary(history, summary_size=summary_len)
    # Only the first element of the returned value is written out.
    joined_summary = "\n".join(summary[0])
    summary_path = os.path.join(new_dir_path, str(patient_id) + ".sum")
    # Named `summary_file` rather than `sum` to avoid shadowing the builtin.
    with open(summary_path, 'w') as summary_file:
        summary_file.write(joined_summary)

end2 = time.time()
print("Runtime " + str(end2 - end1))
print("------------------------------")
print("total runtime " + str(end2 - start))