|
a |
|
b/tools/tabtools.py |
|
|
1 |
import pandas as pd |
|
|
2 |
import jsonlines |
|
|
3 |
import json |
|
|
4 |
import re |
|
|
5 |
import sqlite3 |
|
|
6 |
import sys |
|
|
7 |
import Levenshtein |
|
|
8 |
def db_loader(target_ehr):
    """Load one MIMIC-III table from its CSV export into a DataFrame.

    Args:
        target_ehr: Lower-case table name, e.g. ``"admissions"``.

    Returns:
        The ``pandas.DataFrame`` that ``pd.read_csv`` produces for the table.

    Raises:
        KeyError: If ``target_ehr`` is not one of the known table names. The
            message lists the available tables.
    """
    ehr_dict = {
        "admissions": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/ADMISSIONS.csv",
        "chartevents": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/CHARTEVENTS.csv",
        "cost": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/COST.csv",
        "d_icd_diagnoses": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/D_ICD_DIAGNOSES.csv",
        "d_icd_procedures": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/D_ICD_PROCEDURES.csv",
        "d_items": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/D_ITEMS.csv",
        "d_labitems": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/D_LABITEMS.csv",
        "diagnoses_icd": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/DIAGNOSES_ICD.csv",
        "icustays": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/ICUSTAYS.csv",
        "inputevents_cv": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/INPUTEVENTS_CV.csv",
        "labevents": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/LABEVENTS.csv",
        "microbiologyevents": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/MICROBIOLOGYEVENTS.csv",
        # Previously pointed at "<YOUR_DATASET_PATH>/mimic_iii/..." (no
        # "ehrsql" segment), unlike every other table in this mapping.
        "outputevents": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/OUTPUTEVENTS.csv",
        "patients": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/PATIENTS.csv",
        "prescriptions": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/PRESCRIPTIONS.csv",
        "procedures_icd": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/PROCEDURES_ICD.csv",
        "transfers": "<YOUR_DATASET_PATH>/ehrsql/mimic_iii/TRANSFERS.csv",
    }
    try:
        csv_path = ehr_dict[target_ehr]
    except KeyError:
        # Still a KeyError for existing callers, but now says what is valid.
        raise KeyError(
            "Unknown table {!r}. Available tables are {}.".format(
                target_ehr, ', '.join(ehr_dict)
            )
        ) from None
    return pd.read_csv(csv_path)
|
|
31 |
# def get_column_names(self, target_db): |
|
|
32 |
# return ', '.join(data.columns.tolist()) |
|
|
33 |
|
|
|
34 |
# Comparison operators in detection order: two-character operators come first
# so that "A>=1" is not mistaken for "A>" followed by "=1".
_FILTER_COMPARATORS = {
    '>=': lambda column, value: column >= value,
    '<=': lambda column, value: column <= value,
    '>': lambda column, value: column > value,
    '<': lambda column, value: column < value,
    '=': lambda column, value: column == value,
}

_FILTER_COLUMN_ERROR = "The filtering query {} is incorrect. Please modify the column name or use LoadDB to read another table. The column names in the current DB are {}."
_FILTER_SYNTAX_ERROR = "The filtering query {} is incorrect. There is syntax error in the command. Please modify the condition or use LoadDB to read another table."
_FILTER_NO_VALUE_ERROR = "The filtering query {} is incorrect. There is no {} value in the column. Five example values in the column are {}. Please check if you get the correct {} value."


def _require_column(data, column_name, command):
    """Raise the 'incorrect column' error unless ``column_name`` is in ``data``."""
    if column_name not in data.columns:
        columns = ', '.join(data.columns.tolist())
        raise Exception(_FILTER_COLUMN_ERROR.format(command, columns))


def _strip_quotes(text):
    """Remove one pair of matching single or double quotes around ``text``."""
    if len(text) >= 2 and text[0] in ("'", '"') and text[-1] == text[0]:
        return text[1:-1]
    return text


def _coerce_like_column(reference, column_name, text):
    """Cast ``text`` to the Python type of the column's first non-null value.

    Falls back to the unchanged string when the column has no usable value or
    the cast fails.
    """
    try:
        # Positional access: label-based ``[0]`` breaks once row 0 is gone.
        exemplar = reference[column_name].dropna().iloc[:1].tolist()[0]
        return type(exemplar)(text)
    except (IndexError, TypeError, ValueError):
        return text


def data_filter(data, argument):
    """Apply a ``||``-separated chain of filter conditions to a DataFrame.

    Supported conditions, applied left to right:

    * ``COLUMN>=VALUE``, ``COLUMN<=VALUE``, ``COLUMN>VALUE``, ``COLUMN<VALUE``,
      ``COLUMN=VALUE`` (VALUE may be wrapped in single or double quotes);
    * ``COLUMN in [V1, V2, ...]``;
    * ``max(COLUMN)`` / ``min(COLUMN)``, keeping the rows holding the extreme.

    Whitespace around the column name and the value is ignored. Values are
    cast to the type found in the column of the *unfiltered* input so that
    conditions later in the chain are typed the same way as the first one.

    Args:
        data: The ``pandas.DataFrame`` to filter.
        argument: The filter expression string.

    Returns:
        The filtered DataFrame. As soon as a condition leaves no rows, the
        empty frame is returned without evaluating the remaining conditions.

    Raises:
        Exception: If a condition names a column that does not exist, has an
            empty value or malformed ``max(``/``min(`` call, or is an equality
            on a value that does not occur in the column (the message then
            lists the five closest existing values by Levenshtein distance).
    """
    backup_data = data
    for command in argument.split('||'):
        op_symbol = next((op for op in _FILTER_COMPARATORS if op in command), None)
        value = None

        if op_symbol is not None:
            # partition() splits on the first operator only, so a value that
            # itself contains the operator is kept intact.
            column_name, _, raw_value = command.partition(op_symbol)
            column_name = column_name.strip()
            raw_value = raw_value.strip()
            _require_column(data, column_name, command)
            if raw_value == '':
                raise Exception(_FILTER_SYNTAX_ERROR.format(command))
            value = _coerce_like_column(backup_data, column_name, _strip_quotes(raw_value))
            try:
                data = data[_FILTER_COMPARATORS[op_symbol](data[column_name], value)]
            except (TypeError, ValueError):
                # Best-effort, as before: a value that cannot be compared with
                # the column leaves the rows unchanged instead of crashing.
                pass
        elif ' in ' in command:
            column_name, _, raw_value = command.partition(' in ')
            column_name = column_name.strip()
            _require_column(data, column_name, command)
            members = [
                item.strip().strip("'").strip('"')
                for item in raw_value.strip().strip("[]").split(',')
            ]
            members = [_coerce_like_column(backup_data, column_name, item) for item in members]
            data = data[data[column_name].isin(members)]
        else:
            aggregate = next((name for name in ('max', 'min') if name in command), None)
            if aggregate is None:
                # Unrecognised (or empty) condition: skipped, as before.
                continue
            _, opening, remainder = command.partition(aggregate + '(')
            if not opening:
                raise Exception(_FILTER_SYNTAX_ERROR.format(command))
            column_name = remainder.split(')')[0].strip()
            _require_column(data, column_name, command)
            try:
                column = data[column_name]
                extreme = column.max() if aggregate == 'max' else column.min()
                data = data[column == extreme]
            except (TypeError, ValueError):
                pass

        if len(data) == 0:
            if op_symbol == '=':
                # Suggest close matches only when the value is absent from the
                # original column, not merely removed by an earlier condition.
                column_values = list(set(backup_data[column_name].tolist()))
                if value not in column_values:
                    distances = {
                        candidate: Levenshtein.distance(str(candidate), str(value))
                        for candidate in column_values
                    }
                    closest = sorted(distances.items(), key=lambda pair: pair[1])[:5]
                    examples = ', '.join(str(candidate) for candidate, _ in closest)
                    raise Exception(
                        _FILTER_NO_VALUE_ERROR.format(command, value, examples, column_name)
                    )
            return data
    return data
|
|
140 |
|
|
|
141 |
def _numbers_or_text(values):
    """Return ``values`` as floats, or as strings if any value is not numeric."""
    try:
        return [float(item) for item in values]
    except (TypeError, ValueError):
        return [str(item) for item in values]


def get_value(data, argument):
    """Read a column, or an aggregate of a column, from a DataFrame.

    ``argument`` is either a bare column name (optionally wrapped in ``[``,
    ``]`` and single quotes) or ``"COLUMN, OPERATION"`` where OPERATION
    contains one of ``mean``, ``max``, ``min``, ``sum`` or ``list``.

    Args:
        data: The ``pandas.DataFrame`` to read from.
        argument: The column / operation specification string.

    Returns:
        * bare column, single-row frame: the cell as a string;
        * bare column otherwise: the distinct values joined with ``", "`` in
          order of first appearance;
        * ``mean`` / ``sum``: a float;
        * ``max`` / ``min``: a float, or a string for non-numeric columns;
        * ``list``: a list with every value converted to a string.

    Raises:
        Exception: If the column does not exist or the aggregate cannot be
            computed on it (message lists the available columns), or if the
            operation is not one of the supported names.
    """
    commands = argument.split(', ')
    if len(commands) == 1:
        # Tolerate list-literal style input such as "['CHARTTIME']".
        column = argument.lstrip("['").rstrip("]'")
    else:
        column = commands[0]

    column_error = "The column name {} is incorrect. Please check the column name and make necessary changes. The columns in this table include {}."
    if column not in data.columns:
        raise Exception(column_error.format(column, ', '.join(data.columns.tolist())))

    if len(commands) == 1:
        if len(data) == 1:
            return str(data.iloc[0][column])
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        distinct = dict.fromkeys(data[column].tolist())
        return ', '.join(str(item) for item in distinct)

    operation = commands[-1]
    values = data[column].tolist()
    try:
        if 'mean' in operation:
            numbers = [float(item) for item in values]
            return sum(numbers) / len(numbers)
        if 'max' in operation:
            return max(_numbers_or_text(values))
        if 'min' in operation:
            return min(_numbers_or_text(values))
        if 'sum' in operation:
            return sum(float(item) for item in values)
        if 'list' in operation:
            return [str(item) for item in values]
    except (TypeError, ValueError, ZeroDivisionError) as err:
        # Non-numeric or empty column: reported with the same message as before.
        raise Exception(
            column_error.format(column, ', '.join(data.columns.tolist()))
        ) from err
    # Raised outside the try block so it is no longer masked by the
    # "column name is incorrect" message.
    raise Exception("The operation {} contains syntax errors. Please check the arguments.".format(operation))
|
|
191 |
|
|
|
192 |
def sql_interpreter(command):
    """Execute a SQL statement on the MIMIC-III SQLite database.

    Args:
        command: The SQL text to run. It is executed verbatim, so it must
            come from a trusted source.

    Returns:
        A list of row tuples as returned by ``fetchall()``.

    Raises:
        sqlite3.Error: If the database cannot be opened or the statement fails.
    """
    con = sqlite3.connect("<YOUR_DATASET_PATH>/ehrsql/mimic_iii/mimic_iii.db")
    try:
        cur = con.cursor()
        return cur.execute(command).fetchall()
    finally:
        # Close explicitly; the connection used to be left open on every call.
        # No commit is issued, matching the previous behaviour.
        con.close()
|
|
197 |
|
|
|
198 |
def date_calculator(argument):
    """Apply a SQLite date modifier to ``current_time`` and return the result.

    Evaluates ``select datetime(current_time, <argument>)`` on the MIMIC-III
    SQLite database.

    Args:
        argument: A SQLite date/time modifier such as ``'-1 year'``.

    Returns:
        The value SQLite's ``datetime()`` returns for that modifier.

    Raises:
        Exception: If the query fails or SQLite rejects the modifier (it
            returns NULL for modifiers it cannot parse).
    """
    error_message = "The date calculator {} is incorrect. Please check the syntax and make necessary changes. For the current date and time, please call Calendar('0 year').".format(argument)
    try:
        con = sqlite3.connect("<YOUR_DATASET_PATH>/ehrsql/mimic_iii/mimic_iii.db")
        try:
            cur = con.cursor()
            # Bind the modifier as a parameter instead of formatting it into
            # the SQL text, so quotes in the argument cannot alter the query.
            row = cur.execute("select datetime(current_time, ?)", (argument,)).fetchone()
        finally:
            con.close()
    except sqlite3.Error as err:
        raise Exception(error_message) from err
    results = row[0]
    if results is None:
        # NULL means SQLite did not understand the modifier.
        raise Exception(error_message)
    return results
|
|
207 |
|
|
|
208 |
if __name__ == "__main__":
    # Manual smoke test. The previous version instantiated `table_toolkits`,
    # which this file neither defines nor imports, so it now calls the
    # module-level functions directly and passes the loaded table explicitly.
    table = db_loader("microbiologyevents")
    print(table)
    # print(data_filter(table, "SPEC_TYPE_DESC=peripheral blood lymphocytes"))
    print(data_filter(table, "HADM_ID=107655"))
    print(data_filter(table, "SPEC_TYPE_DESC=peripheral blood lymphocytes"))
    print(get_value(table, 'CHARTTIME'))
    # results = sql_interpreter("select max(t1.c1) from ( select sum(cost.cost) as c1 from cost where cost.hadm_id in ( select diagnoses_icd.hadm_id from diagnoses_icd where diagnoses_icd.icd9_code = ( select d_icd_diagnoses.icd9_code from d_icd_diagnoses where d_icd_diagnoses.short_title = 'comp-oth vasc dev/graft' ) ) and datetime(cost.chargetime) >= datetime(current_time,'-1 year') group by cost.hadm_id ) as t1")
    # results = [result[0] for result in results]
    # if len(results) == 1:
    #     print(results[0])
    # else:
    #     print(results)
    # print(date_calculator('-1 year'))