a b/mimicsql/evaluation/utils.py
1
import re
2
import csv
3
import pandas
4
import sqlite3
5
import random
6
import json
7
import itertools
8
9
class query(object):
    '''
    Translate an index-based query description into a SQL string.

    On construction the table names and per-table column names are read from
    the schema of a SQLite database. Calling the instance with a query dict
    maps the integer indices in that dict onto those names and returns the
    resulting SQL text.
    '''

    # SELECT expression for each ``sql_['sel']`` code; ``{}`` receives the
    # comma-separated column list.
    _SELECT_TEMPLATES = {
        0: '{}',
        1: 'COUNT ( DISTINCT {} )',
        2: 'MAX ( {} )',
        3: 'MIN ( {} )',
        4: 'AVG ( {} )',
    }

    def __init__(self, db_file):
        '''
        Open ``db_file`` and cache its schema.

        Sets ``db_meta``, ``db_tabs`` and ``db_head`` (see ``_load_db``) plus
        the ``agg_op`` and ``cond_op`` operator tables.
        '''
        self.db_meta, self.db_tabs, self.db_head = self._load_db(db_file)
        self.agg_op = ['', 'count', 'max', 'min', 'avg']
        self.cond_op = ['=', '>', '<', '>=', '<=']

    def __call__(self, sql_):
        '''
        Build ``SELECT <agg> FROM <tables> [WHERE <conditions>]`` from ``sql_``.

        ``sql_`` is a mapping with the keys:
          * ``'agg_col'``: items whose first two entries are
            (table index, column index) of a selected column;
          * ``'sel'``: aggregation code, a key of ``_SELECT_TEMPLATES`` (0-4);
          * ``'table'``: table indices; every table after the first is
            inner-joined to the first one on ``HADM_ID``;
          * ``'cond'``: items of (table index, column index, index into
            ``self.cond_op``, value).

        Table indices refer to ``self.db_tabs`` and column indices to
        ``self.db_head[table]``.

        Returns the SQL string. The WHERE clause is left out when ``'cond'``
        is empty, so the result is still a valid statement.

        Raises ``ValueError`` when ``sql_['sel']`` is not a known aggregation
        code.
        '''
        # Selected columns as a comma-separated list of TABLE.COLUMN.
        selected = []
        for itm in sql_['agg_col']:
            tab_name = self.db_tabs[itm[0]]
            col_name = self.db_head[tab_name][itm[1]]
            selected.append('.'.join([tab_name, col_name]))
        mm_agg_col = ','.join(selected)

        # Aggregation wrapper around the selected columns.
        template = self._SELECT_TEMPLATES.get(sql_['sel'])
        if template is None:
            raise ValueError(
                'unsupported aggregation code: {!r}'.format(sql_['sel']))
        mm_agg = template.format(mm_agg_col)

        # FROM clause: first table, then INNER JOINs against it on HADM_ID.
        tables = [self.db_tabs[k] for k in sql_['table']]
        mm_tab = [tables[0]]
        for other in tables[1:]:
            mm_tab.append('INNER JOIN')
            mm_tab.append(other)
            mm_tab.append('on')
            mm_tab.append('{}.{} = {}.{}'.format(
                tables[0], 'HADM_ID', other, 'HADM_ID'))

        # WHERE clause. NOTE(review): values are wrapped in double quotes
        # without escaping, so a value containing '"' yields broken SQL.
        conditions = []
        for itm in sql_['cond']:
            tab_name = self.db_tabs[itm[0]]
            col_name = self.db_head[tab_name][itm[1]]
            operator = self.cond_op[itm[2]]
            conditions.append('{}.{} {} {}'.format(
                tab_name, col_name, operator, '"' + str(itm[3]) + '"'))

        bb_query = 'SELECT {} FROM {}'.format(mm_agg, ' '.join(mm_tab))
        if conditions:
            bb_query += ' WHERE ' + ' AND '.join(conditions)
        return bb_query

    def _load_db(self, db_file):
        '''
        Connect to ``db_file`` and read its table schema.

        Stores a new connection and cursor on ``self.conn`` / ``self.cur``
        each time it is called (a previously stored connection is not closed
        here).

        Column information is parsed from the stored ``CREATE TABLE`` text and
        assumes one column per line between the first and the last line of the
        statement. On each such line the first whitespace-separated token is
        taken as the column name and the second as its type, exactly as
        written (quotes and trailing commas are kept).

        Returns a tuple ``(db_meta, db_tabs, db_head)``:
          * ``db_meta``: table name -> {column token: type token};
          * ``db_tabs``: table names in ``sqlite_master`` order;
          * ``db_head``: table name -> list of column tokens.
        '''
        self.conn = sqlite3.connect(db_file)
        self.cur = self.conn.cursor()
        self.cur.execute("select * from sqlite_master;")
        results = self.cur.fetchall()
        db_meta = {}
        db_tabs = []
        db_head = {}
        for tb in results:
            # Row layout: (type, name, tbl_name, rootpage, sql). Index, view
            # and trigger rows reuse their table's tbl_name and would clobber
            # its column list; automatic indexes also carry a NULL sql.
            if tb[0] != 'table' or tb[-1] is None:
                continue
            tab_name = tb[2]
            col_types = {}
            col_names = []
            for line in tb[-1].split('\n')[1:-1]:
                tokens = line.split()
                if len(tokens) < 2:
                    # Blank or single-token line: nothing to record.
                    continue
                col_types[tokens[0]] = tokens[1]
                col_names.append(tokens[0])
            db_meta[tab_name] = col_types
            db_tabs.append(tab_name)
            db_head[tab_name] = col_names

        return (db_meta, db_tabs, db_head)

    def execute_sql(self, sql_):
        '''Execute ``sql_`` on the stored cursor and return that cursor call's result.'''
        return self.cur.execute(sql_)
87
88
def get_value_pool_(db_file, model, samp_cond):
    '''
    Gather the values stored in each column referenced by ``samp_cond``.

    ``model`` must provide ``_load_db(db_file)``, returning
    ``(db_meta, db_tabs, db_head)``, and ``execute_sql(sql)``, whose result
    supports ``fetchall()``. Each item of ``samp_cond`` starts with a table
    index into ``db_tabs`` followed by a column index into that table's
    ``db_head`` entry.

    Returns one list per item of ``samp_cond`` holding that column's values
    with duplicates removed (dict-key order).
    '''
    _, tables, headers = model._load_db(db_file)

    def column_values(cond):
        # Resolve the indices to names, fetch the column, drop duplicates.
        table = tables[cond[0]]
        column = headers[table][cond[1]]
        statement = 'select {} from {}'.format(column, table)
        rows = model.execute_sql(statement).fetchall()
        return list(dict.fromkeys(row[0] for row in rows))

    return [column_values(cond) for cond in samp_cond]