Diff of /sql2sql.py [000000] .. [ab27bc]

Switch to unified view

a b/sql2sql.py
1
import json
2
import os
3
import re
4
import numpy as np
5
import pandas as pd
6
import networkx as nx
7
from build_mimicsqlstar_db.schema_mimic import SCHEMA, MAP_WITH_MIMICSQL
8
from collections import OrderedDict
9
from itertools import repeat
10
11
12
class SQL2SQL:
13
    def __init__(self):
14
        self.schema_graph = nx.Graph()
15
        for k, vs in SCHEMA.items():
16
            for v in vs:
17
                self.schema_graph.add_edge(k, v[0])
18
        self.inner_join_template = 'INNER JOIN {} ON {} = {}'
19
20
    def find_table(self, block):
21
        cols = re.findall('[A-Z_]+[.]"[\dA-Z_]+"', block)
22
        return [c.split('.')[0] for c in cols]
23
24
    def from_caluse(self, new_select):
25
        cols = re.findall('[A-Z_]+[.]"[\dA-Z_]+"', new_select)
26
        for col in cols:
27
            t, c = col.split('.')
28
            return t
29
30
    def cols_clause(self, block):
31
        cols = re.findall('[A-Z_]+[.]"[\dA-Z_]+"', block)
32
        #print(cols)
33
        for col in cols:
34
            t, c = col.split('.')
35
            new_tc = MAP_WITH_MIMICSQL[t][c.replace('"', '')]
36
            #print(new_tc)
37
            block = block.replace(col, new_tc)
38
            # print(col, new_tc)
39
        return block
40
41
    def translate(self, sql):
42
        remain, where_block = sql.split('WHERE')
43
        select_block, remain = remain.split('FROM')
44
45
        new_select = self.cols_clause(select_block)
46
        select_tables = self.find_table(new_select)
47
        from_table = self.from_caluse(new_select)
48
        new_where = self.cols_clause(where_block)
49
        where_tables = self.find_table(new_where)
50
        # print(new_select)
51
        # print(from_table)
52
        # print(new_where)
53
        # print(where_tables)
54
        inner_join_blocks = list()
55
        tables = list(set(where_tables + select_tables))
56
        for wt in tables:
57
            path = nx.shortest_path(self.schema_graph, source=from_table.replace(' ', ''), target=wt.replace(' ', ''))
58
            for i in range(len(path) - 1):
59
                a = path[i]
60
                b = path[i+1]
61
                key = [con[1] for con in SCHEMA[a] if con[0] == b][0]
62
                inner_join_block = self.inner_join_template.format(b, f'{a}."{key}"', f'{b}."{key}"')
63
                #print(inner_join_block)
64
                inner_join_blocks.append(inner_join_block)
65
66
        inner_join_blocks = list(OrderedDict(zip(inner_join_blocks, repeat(None))))
67
        new_inner_join = ' '.join(list(inner_join_blocks))
68
        #print(new_inner_join)
69
70
        temp = 'FROM '.join([new_select, from_table])
71
        temp = ' '.join([temp, new_inner_join])
72
        temp = ' '.join(temp.split(' ')) + ' '
73
        new_sql = 'WHERE'.join([temp, new_where])
74
        return new_sql
75
76
77
if __name__ == '__main__':
78
    convertor = SQL2SQL()
79
    sql = 'SELECT DEMOGRAPHIC."GENDER",DEMOGRAPHIC."INSURANCE" FROM DEMOGRAPHIC WHERE DEMOGRAPHIC."SUBJECT_ID" = "81923"'
80
    new_sql = convertor.translate(sql)
81
    print(sql)
82
    print(new_sql)