|
a |
|
b/sql2sql.py |
|
|
1 |
import json |
|
|
2 |
import os |
|
|
3 |
import re |
|
|
4 |
import numpy as np |
|
|
5 |
import pandas as pd |
|
|
6 |
import networkx as nx |
|
|
7 |
from build_mimicsqlstar_db.schema_mimic import SCHEMA, MAP_WITH_MIMICSQL |
|
|
8 |
from collections import OrderedDict |
|
|
9 |
from itertools import repeat |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
class SQL2SQL:
    """Translate MIMICSQL-schema SQL into the MIMICSQL* schema.

    Every qualified column reference of the form ``TABLE."COLUMN"`` is
    rewritten through ``MAP_WITH_MIMICSQL``. The FROM clause is then rebuilt
    as a chain of INNER JOINs along shortest paths in the table graph built
    from ``SCHEMA``.

    Only flat ``SELECT ... FROM ... [WHERE ...]`` queries are supported:
    - everything after the first ``WHERE`` is treated as the condition;
    - the original FROM clause text is discarded.
    """

    # Matches one qualified column reference, e.g. DEMOGRAPHIC."SUBJECT_ID".
    # Raw string: "\d" in a plain string literal is an invalid escape.
    _COL_PATTERN = re.compile(r'[A-Z_]+[.]"[\dA-Z_]+"')

    def __init__(self):
        # Undirected table graph. Each SCHEMA entry is read positionally:
        # [0] is a neighbouring table (an edge), and [1] is the column
        # equated on both sides of the generated ON clause.
        self.schema_graph = nx.Graph()
        # (table, neighbour) -> join column; the first SCHEMA entry wins.
        self._join_keys = {}
        for table, links in SCHEMA.items():
            for link in links:
                self.schema_graph.add_edge(table, link[0])
                self._join_keys.setdefault((table, link[0]), link[1])
        self.inner_join_template = 'INNER JOIN {} ON {} = {}'

    def find_table(self, block):
        """Return the table of every qualified column in *block*.

        Results follow order of appearance, with duplicates kept.
        """
        return [col.split('.')[0] for col in self._COL_PATTERN.findall(block)]

    def from_caluse(self, new_select):
        """Return the table of the first qualified column in *new_select*.

        Returns ``None`` when the clause contains no qualified column.
        (The misspelt name is kept for backward compatibility.)
        """
        match = self._COL_PATTERN.search(new_select)
        if match is None:
            return None
        return match.group(0).split('.')[0]

    def cols_clause(self, block):
        """Rewrite every qualified column in *block* via MAP_WITH_MIMICSQL.

        The rewrite is a single regex pass, so text produced by one
        replacement is never matched and remapped again. Sequential
        ``str.replace`` calls could remap repeated references, or a result
        that ends with another source reference.

        Raises:
            KeyError: if a table/column pair has no mapping.
        """
        def _remap(match):
            table, column = match.group(0).split('.')
            return MAP_WITH_MIMICSQL[table][column.replace('"', '')]

        return self._COL_PATTERN.sub(_remap, block)

    def _join_key(self, a, b):
        """Return the join column for the edge a-b, whichever side listed it."""
        # The graph is undirected, so a path may walk an edge in the opposite
        # direction from the one SCHEMA lists it under.
        if (a, b) in self._join_keys:
            return self._join_keys[(a, b)]
        return self._join_keys[(b, a)]

    def _inner_joins(self, from_table, tables):
        """Return INNER JOIN clauses reaching every table in *tables*.

        Each table is joined at most once, and only after the table it
        joins to.
        """
        targets = [t for t in tables if t != from_table]
        if not targets:
            return []
        # A single BFS tree gives each table exactly one parent. Separate
        # shortest_path calls may route through the same table from
        # different parents, which would join it twice.
        paths = nx.single_source_shortest_path(self.schema_graph, from_table)
        joined = {from_table}
        clauses = []
        for target in targets:
            if target not in self.schema_graph:
                raise nx.NodeNotFound(f'Table {target} is not in the schema graph')
            if target not in paths:
                raise nx.NetworkXNoPath(f'No join path from {from_table} to {target}')
            path = paths[target]
            for a, b in zip(path, path[1:]):
                if b in joined:
                    continue
                key = self._join_key(a, b)
                clauses.append(
                    self.inner_join_template.format(b, f'{a}."{key}"', f'{b}."{key}"')
                )
                joined.add(b)
        return clauses

    def translate(self, sql):
        """Translate one flat MIMICSQL query into the MIMICSQL* schema.

        Args:
            sql: query of the form ``SELECT ... FROM ... [WHERE ...]``.

        Returns:
            The rewritten query string.

        Raises:
            ValueError: if there is no FROM clause, or no qualified column
                from which to pick the FROM table.
            KeyError: if a referenced column has no mapping.
        """
        head, where_kw, where_block = sql.partition('WHERE')
        select_block, from_kw, _ = head.partition('FROM')
        if not from_kw:
            raise ValueError(f'No FROM clause in SQL: {sql!r}')

        new_select = self.cols_clause(select_block)
        new_where = self.cols_clause(where_block)
        select_tables = self.find_table(new_select)
        where_tables = self.find_table(new_where)

        from_table = self.from_caluse(new_select)
        if from_table is None:
            # e.g. SELECT COUNT(*): anchor the joins on a WHERE table instead.
            if not where_tables:
                raise ValueError(f'No qualified column to choose a FROM table: {sql!r}')
            from_table = where_tables[0]

        # dict.fromkeys dedups while keeping first-appearance order. A set
        # would make the join order vary between runs (hash randomisation).
        tables = list(dict.fromkeys(where_tables + select_tables))
        joins = self._inner_joins(from_table, tables)

        parts = [new_select.strip(), 'FROM', from_table]
        if joins:
            parts.append(' '.join(joins))
        new_sql = ' '.join(parts)
        if where_kw:
            new_sql = f'{new_sql} WHERE{new_where}'
        return new_sql
|
|
75 |
|
|
|
76 |
|
|
|
77 |
if __name__ == '__main__':
    # Demo: translate a single-table MIMICSQL query and show both versions.
    source_sql = ('SELECT DEMOGRAPHIC."GENDER",DEMOGRAPHIC."INSURANCE" '
                  'FROM DEMOGRAPHIC WHERE DEMOGRAPHIC."SUBJECT_ID" = "81923"')
    translated_sql = SQL2SQL().translate(source_sql)
    print(source_sql, translated_sql, sep='\n')