|
a |
|
b/mimicsql/evaluation/utils.py |
|
|
1 |
import re |
|
|
2 |
import csv |
|
|
3 |
import pandas |
|
|
4 |
import sqlite3 |
|
|
5 |
import random |
|
|
6 |
import json |
|
|
7 |
import itertools |
|
|
8 |
|
|
|
9 |
class query(object):
    """Render structured query specifications as SQL text for a SQLite file.

    The schema of the database is read once at construction time.  Query
    specifications refer to tables by their position in ``db_tabs`` and to
    columns by their position in ``db_head[table]``.
    """

    # SELECT-expression template for every supported ``sel`` code.
    _AGG_TEMPLATES = {
        0: '{}',
        1: 'COUNT ( DISTINCT {} )',
        2: 'MAX ( {} )',
        3: 'MIN ( {} )',
        4: 'AVG ( {} )',
    }

    def __init__(self, db_file):
        """Open ``db_file`` and cache its table/column layout."""
        self.db_meta, self.db_tabs, self.db_head = self._load_db(db_file)
        self.agg_op = ['', 'count', 'max', 'min', 'avg']
        self.cond_op = ['=', '>', '<', '>=', '<=']

    def __call__(self, sql_):
        """Build ``SELECT <agg> FROM <tables> WHERE <conditions>``.

        ``sql_`` is a mapping with these keys:

        * ``'agg_col'`` -- items whose first two entries are a table index
          and a column index; they form the selected column list.
        * ``'sel'`` -- aggregation code: 0 plain columns, 1 COUNT DISTINCT,
          2 MAX, 3 MIN, 4 AVG.
        * ``'table'`` -- table indices; the first one is the base table and
          every further table is INNER JOINed to it on ``HADM_ID``.
        * ``'cond'`` -- items of (table index, column index, index into
          ``self.cond_op``, value); the conditions are joined with AND.

        Returns the SQL text.  Raises ``ValueError`` when ``sql_['sel']`` is
        not one of the supported aggregation codes.

        NOTE: condition values are interpolated verbatim between double
        quotes, so a value that itself contains ``"`` yields broken SQL.
        """
        # Selected columns, rendered as table.column and comma separated.
        mm_agg_col = []
        for itm in sql_['agg_col']:
            tt = self.db_tabs[itm[0]]
            hh = self.db_head[tt][itm[1]]
            mm_agg_col.append('.'.join([tt, hh]))
        mm_agg_col = ','.join(mm_agg_col)

        # Aggregation wrapped around the selected columns.  An unknown code
        # is reported explicitly instead of leaving the variable unbound.
        try:
            template = self._AGG_TEMPLATES[sql_['sel']]
        except (KeyError, TypeError):
            raise ValueError(
                'unsupported aggregation code: {!r}'.format(sql_['sel']))
        mm_agg = template.format(mm_agg_col)

        # FROM clause: base table plus one join per additional table.
        tbtb = [self.db_tabs[k] for k in sql_['table']]
        mm_tab = [tbtb[0]]
        for other in tbtb[1:]:
            mm_tab.append('INNER JOIN')
            mm_tab.append(other)
            mm_tab.append('on')
            mm_tab.append(
                '{}.{} = {}.{}'.format(tbtb[0], 'HADM_ID', other, 'HADM_ID'))

        # WHERE clause: every value is rendered as a double-quoted string.
        mm_cond = []
        for itm in sql_['cond']:
            tt = self.db_tabs[itm[0]]
            cc = self.db_head[tt][itm[1]]
            oo = self.cond_op[itm[2]]
            ff = itm[3]
            mm_cond.append(
                '{}.{} {} {}'.format(tt, cc, oo, '"' + str(ff) + '"'))
        mm_cond = ' AND '.join(mm_cond)

        return 'SELECT {} FROM {} WHERE {}'.format(
            mm_agg, ' '.join(mm_tab), mm_cond)

    def _load_db(self, db_file):
        """Connect to ``db_file`` and parse its table definitions.

        Stores the connection and a cursor on ``self.conn`` / ``self.cur``
        and returns ``(db_meta, db_tabs, db_head)``:

        * ``db_meta`` -- ``{table: {column token: type token}}``
        * ``db_tabs`` -- table names in the order the schema lists them
        * ``db_head`` -- ``{table: [column token, ...]}`` in declaration order

        Column and type tokens are the first two whitespace-separated words
        of each line of the stored ``CREATE TABLE`` text, excluding its first
        and last line; they are kept exactly as written there.
        """
        self.conn = sqlite3.connect(db_file)
        self.cur = self.conn.cursor()
        # Only table rows are read.  Index rows share their table's name and
        # automatic indexes carry NULL sql, so they must not be parsed.
        self.cur.execute(
            "select tbl_name, sql from sqlite_master where type = 'table';")
        results = self.cur.fetchall()
        db_meta = {}
        db_tabs = []
        db_head = {}
        for tab_name, create_sql in results:
            db_meta[tab_name] = {}
            db_tabs.append(tab_name)
            columns = []
            # The first and last lines are the CREATE header and the
            # closing parenthesis; the lines between describe the columns.
            for line in (create_sql or '').split('\n')[1:-1]:
                tokens = line.split()
                db_meta[tab_name][tokens[0]] = tokens[1]
                columns.append(tokens[0])
            db_head[tab_name] = columns

        return (db_meta, db_tabs, db_head)

    def execute_sql(self, sql_):
        """Run ``sql_`` on the cached cursor and return that cursor."""
        return self.cur.execute(sql_)
|
|
87 |
|
|
|
88 |
def get_value_pool_(db_file, model, samp_cond):
    """Gather the distinct stored values for each sampled condition column.

    ``model`` must provide ``_load_db(db_file)``, returning a
    ``(db_meta, db_tabs, db_head)`` triple, and ``execute_sql(sql)``,
    returning an object with ``fetchall()``.  Each item of ``samp_cond``
    starts with a table index and a column index into that schema.

    Returns one list per item of ``samp_cond`` holding the values of the
    addressed column without duplicates, in order of first appearance.
    """
    _, tables, headers = model._load_db(db_file)

    def column_values(cond):
        # Resolve the indices to names, then read the whole column.
        table = tables[cond[0]]
        column = headers[table][cond[1]]
        statement = 'select {} from {}'.format(column, table)
        rows = model.execute_sql(statement).fetchall()
        # dict keys drop repeats while keeping the first-seen order.
        return list(dict.fromkeys(row[0] for row in rows))

    return [column_values(cond) for cond in samp_cond]