a b/tools/tabtools.py
1
import pandas as pd
2
import jsonlines
3
import json
4
import re
5
import sqlite3
6
import sys
7
import Levenshtein
8
# MIMIC-III tables (EHRSQL release) that db_loader can open. Each table is
# stored as "<data_dir>/<TABLE NAME IN UPPER CASE>.csv".
_MIMIC_III_TABLES = (
    "admissions",
    "chartevents",
    "cost",
    "d_icd_diagnoses",
    "d_icd_procedures",
    "d_items",
    "d_labitems",
    "diagnoses_icd",
    "icustays",
    "inputevents_cv",
    "labevents",
    "microbiologyevents",
    "outputevents",
    "patients",
    "prescriptions",
    "procedures_icd",
    "transfers",
)
_DEFAULT_MIMIC_III_DIR = "<YOUR_DATASET_PATH>/ehrsql/mimic_iii"


def db_loader(target_ehr, data_dir=_DEFAULT_MIMIC_III_DIR):
    """Load one MIMIC-III table from its CSV file.

    Args:
        target_ehr: Lower-case table name, e.g. ``"admissions"``; must be one of
            ``_MIMIC_III_TABLES``.
        data_dir: Directory containing the table CSV files. Defaults to the
            EHRSQL MIMIC-III directory.

    Returns:
        pandas.DataFrame with the table contents as parsed by ``pd.read_csv``.

    Raises:
        KeyError: if ``target_ehr`` is not a known table.
    """
    if target_ehr not in _MIMIC_III_TABLES:
        raise KeyError(
            "Unknown table {}. Available tables are {}.".format(
                target_ehr, ", ".join(_MIMIC_III_TABLES)
            )
        )
    # Every table (outputevents included) lives in the same directory and is
    # named after the upper-cased table name.
    return pd.read_csv("{}/{}.csv".format(data_dir, target_ehr.upper()))
31
# def get_column_names(self, target_db):
32
#     return ', '.join(data.columns.tolist())
33
34
# Comparison operators understood by data_filter, mapped to the pandas Series
# method that implements each one. Dict order is the matching order: the
# two-character operators must be tried before their one-character prefixes.
_FILTER_COMPARISONS = {
    ">=": "ge",
    "<=": "le",
    ">": "gt",
    "<": "lt",
    "=": "eq",
}


def _strip_filter_quotes(text):
    """Trim whitespace and one pair of surrounding single or double quotes."""
    text = text.strip()
    if text[:1] in ("'", '"'):
        quote = text[0]
        text = text[1:]
        if text.endswith(quote):
            text = text[:-1]
    return text


def _parse_filter_command(command):
    """Split one condition into ``(column, operator, operand)``.

    ``operand`` is a string for comparison operators, a list of strings for
    ``in``, and None for ``max``/``min``. Returns None when *command* matches
    none of the supported forms.
    """
    for symbol in _FILTER_COMPARISONS:
        if symbol in command:
            column, _, operand = command.partition(symbol)
            return column.strip(), symbol, _strip_filter_quotes(operand)
    if " in " in command:
        column, _, operand = command.partition(" in ")
        items = [item.strip() for item in operand.strip().strip("[]").split(",")]
        return column.strip(), "in", [item.strip("'").strip('"') for item in items]
    for func in ("max", "min"):
        _, found, rest = command.partition(func + "(")
        if found:
            return rest.split(")")[0].strip(), func, None
    return None


def _coerce_filter_value(series, raw):
    """Convert *raw* to the type of the first non-null value of *series*.

    Returns *raw* unchanged if the column is entirely null or the conversion
    fails (e.g. a non-numeric string for a numeric column).
    """
    non_null = series.dropna()
    if non_null.empty:
        return raw
    try:
        return type(non_null.iloc[0])(raw)
    except (TypeError, ValueError, OverflowError):
        return raw


def data_filter(data, argument):
    """Filter a DataFrame by one or more ``||``-separated conditions.

    Supported condition forms (whitespace around column names and values is
    ignored; empty conditions are skipped):

    * ``COL>=v``, ``COL<=v``, ``COL>v``, ``COL<v``, ``COL=v``: compare the
      column with ``v``. ``v`` may be quoted and is converted to the type of the
      column's first non-null value in the unfiltered *data*.
    * ``COL in [v1, v2, ...]``: keep rows whose value is in the list.
    * ``max(COL)`` / ``min(COL)``: keep rows holding the column's extreme value
      among the rows that survived the earlier conditions.

    Conditions are applied left to right, each to the previous result.

    Args:
        data: DataFrame to filter (not modified).
        argument: The condition string.

    Returns:
        The filtered DataFrame. If a condition leaves no rows, the empty result
        is returned immediately and later conditions are not applied.

    Raises:
        Exception: with a message meant for the calling agent when a column is
            unknown, a condition cannot be parsed or compared, or an ``=``
            condition matches nothing because the value does not occur in the
            column (the message then lists the five closest column values).
    """
    backup_data = data
    columns = data.columns.tolist()
    for raw_command in argument.split("||"):
        command = raw_command.strip()
        if not command:
            continue

        parsed = _parse_filter_command(command)
        if parsed is None:
            raise Exception("The filtering query {} is incorrect. There is syntax error in the command. Please modify the condition or use LoadDB to read another table.".format(raw_command))
        column_name, op, operand = parsed
        if column_name not in columns:
            raise Exception("The filtering query {} is incorrect. Please modify the column name or use LoadDB to read another table. The column names in the current DB are {}.".format(raw_command, ', '.join(columns)))
        if op in _FILTER_COMPARISONS and operand == '':
            raise Exception("The filtering query {} is incorrect. There is syntax error in the command. Please modify the condition or use LoadDB to read another table.".format(raw_command))

        # Types are inferred from the unfiltered table so that earlier filters
        # (which may have emptied or re-indexed the data) cannot affect them.
        reference = backup_data[column_name]
        current = data[column_name]
        value = None
        try:
            if op == "in":
                mask = current.isin([_coerce_filter_value(reference, item) for item in operand])
            elif op == "max":
                mask = current == current.max()
            elif op == "min":
                mask = current == current.min()
            else:
                value = _coerce_filter_value(reference, operand)
                mask = getattr(current, _FILTER_COMPARISONS[op])(value)
        except TypeError as err:
            # e.g. ordering a numeric column against a non-numeric string.
            raise Exception("The filtering query {} is incorrect. The value cannot be compared with the values in column {}. Please modify the condition or use LoadDB to read another table.".format(raw_command, column_name)) from err
        data = data[mask]

        if len(data) == 0:
            column_values = list(set(backup_data[column_name].tolist()))
            if op == "=" and value not in column_values:
                # Suggest the five values closest (by edit distance) to the requested one.
                target = str(value)
                closest = sorted(column_values, key=lambda cv: Levenshtein.distance(str(cv), target))[:5]
                examples = ', '.join(str(cv) for cv in closest)
                raise Exception("The filtering query {} is incorrect. There is no {} value in the column. Five example values in the column are {}. Please check if you get the correct {} value.".format(raw_command, value, examples, column_name))
            return data
    return data
140
141
def _get_value_column_error(data, column):
    """Build the agent-facing error for a column that is not in *data*."""
    column_values = ', '.join(data.columns.tolist())
    return Exception("The column name {} is incorrect. Please check the column name and make necessary changes. The columns in this table include {}.".format(column, column_values))


def _get_value_empty_error(operation, column):
    """Build the agent-facing error for aggregating a column with no rows."""
    return Exception("The operation {} cannot be applied to column {} because it has no values. Please check the filtering conditions.".format(operation, column))


def _get_value_floats(values, operation, column):
    """Convert every value to float, or raise an agent-facing error."""
    try:
        return [float(v) for v in values]
    except (TypeError, ValueError) as err:
        raise Exception("The operation {} cannot be applied to column {} because it contains non-numeric values. Please check the column name and make necessary changes.".format(operation, column)) from err


def get_value(data, argument):
    """Read a column of *data*, optionally reducing it with an aggregate.

    ``argument`` is either a bare column name (optionally wrapped in ``[...]``
    and/or quotes) or ``"COLUMN, OPERATION"`` where OPERATION contains one of
    ``mean``, ``max``, ``min``, ``sum`` or ``list`` (checked in that order).

    Returns:
        * bare column, exactly one row: ``str`` of that row's value;
        * bare column, otherwise: the distinct values as strings joined with
          ``", "``, in order of first appearance;
        * ``mean``: float average; ``sum``: sum of the values as floats;
        * ``max`` / ``min``: float if every value converts to float, otherwise
          the extreme of the values compared as strings;
        * ``list``: list of every value as ``str`` (duplicates kept).

    Raises:
        Exception: with a message meant for the calling agent when the column
            is unknown, the operation is unsupported, the column holds
            non-numeric values for ``mean``/``sum``, or there are no values to
            compute ``mean``/``max``/``min`` from.
    """
    parts = [part.strip() for part in argument.split(",")]
    # Agents often wrap the column name, e.g. "['CHARTTIME']".
    column = parts[0].strip("[]'\" ")
    if column not in data.columns:
        raise _get_value_column_error(data, column)
    values = data[column].tolist()

    if len(parts) == 1:
        if len(data) == 1:
            return str(data.iloc[0][column])
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        return ', '.join(str(v) for v in dict.fromkeys(values))

    operation = parts[-1]
    if 'mean' in operation:
        numbers = _get_value_floats(values, operation, column)
        if not numbers:
            raise _get_value_empty_error(operation, column)
        return sum(numbers) / len(numbers)
    if 'max' in operation or 'min' in operation:
        if not values:
            raise _get_value_empty_error(operation, column)
        try:
            candidates = [float(v) for v in values]
        except (TypeError, ValueError):
            # Non-numeric column: compare the values as strings instead.
            candidates = [str(v) for v in values]
        return max(candidates) if 'max' in operation else min(candidates)
    if 'sum' in operation:
        return sum(_get_value_floats(values, operation, column))
    if 'list' in operation:
        return [str(v) for v in values]
    raise Exception("The operation {} contains syntax errors. Please check the arguments.".format(operation))
191
192
def sql_interpreter(command, db_path="<YOUR_DATASET_PATH>/ehrsql/mimic_iii/mimic_iii.db"):
    """Execute a SQL statement against the MIMIC-III SQLite database.

    Args:
        command: SQL text, executed as-is. It is not parameterised, so only pass
            SQL you are willing to run with full access to the database.
        db_path: Path of the SQLite database file.

    Returns:
        list of row tuples, as returned by ``fetchall()``.
    """
    con = sqlite3.connect(db_path)
    try:
        cur = con.cursor()
        return cur.execute(command).fetchall()
    finally:
        # sqlite3's context manager only commits/rolls back; close explicitly.
        con.close()
197
198
def date_calculator(argument):
    """Apply an SQLite date/time modifier (e.g. ``'-1 year'``) to ``current_time``.

    The query reads no table, so it runs on a throwaway in-memory connection.
    NOTE(review): per the SQLite date/time docs, ``current_time`` is the current
    UTC time of day and a time-only value is dated 2000-01-01, so results are
    relative to 2000-01-01. Confirm this is the intended "current" date.

    Args:
        argument: The modifier. It is bound as a query parameter, never
            formatted into the SQL text.

    Returns:
        The resulting ``'YYYY-MM-DD HH:MM:SS'`` string.

    Raises:
        Exception: agent-facing message if the modifier is rejected by SQLite
            or yields NULL (SQLite's result for an unrecognised modifier).
    """
    message = "The date calculator {} is incorrect. Please check the syntax and make necessary changes. For the current date and time, please call Calendar('0 year').".format(argument)
    con = sqlite3.connect(":memory:")
    try:
        row = con.execute("select datetime(current_time, ?)", (argument,)).fetchone()
    except sqlite3.Error as err:
        raise Exception(message) from err
    finally:
        con.close()
    if row is None or row[0] is None:
        raise Exception(message)
    return row[0]
207
208
if __name__ == "__main__":
    # Demo: load a table, narrow it down step by step, then read a column.
    # The tools are module-level functions; there is no toolkit object.
    microbiology = db_loader("microbiologyevents")
    print(microbiology)
    admission_rows = data_filter(microbiology, "HADM_ID=107655")
    print(admission_rows)
    specimen_rows = data_filter(admission_rows, "SPEC_TYPE_DESC=peripheral blood lymphocytes")
    print(specimen_rows)
    print(get_value(specimen_rows, 'CHARTTIME'))
    # Example use of the SQL and date tools:
    # results = sql_interpreter("select max(t1.c1) from ( select sum(cost.cost) as c1 from cost where cost.hadm_id in ( select diagnoses_icd.hadm_id from diagnoses_icd where diagnoses_icd.icd9_code = ( select d_icd_diagnoses.icd9_code from d_icd_diagnoses where d_icd_diagnoses.short_title = 'comp-oth vasc dev/graft' ) ) and datetime(cost.chargetime) >= datetime(current_time,'-1 year') group by cost.hadm_id ) as t1")
    # results = [result[0] for result in results]
    # print(results[0] if len(results) == 1 else results)
    # print(date_calculator('-1 year'))