[ab27bc]: / sql2sql.py

Download this file

83 lines (71 with data), 2.9 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import json
import os
import re
import numpy as np
import pandas as pd
import networkx as nx
from build_mimicsqlstar_db.schema_mimic import SCHEMA, MAP_WITH_MIMICSQL
from collections import OrderedDict
from itertools import repeat
class SQL2SQL:
    """Rewrite a SQL query onto the schema described by ``SCHEMA``.

    Columns are expected to be written as ``TABLE."COLUMN"``.  Each such
    reference is replaced by the text found in
    ``MAP_WITH_MIMICSQL[TABLE][COLUMN]``, and the ``FROM`` clause is rebuilt
    as one base table plus the ``INNER JOIN`` chain needed to reach every
    table referenced by the translated ``SELECT`` and ``WHERE`` clauses.
    """

    # Matches a qualified column reference such as TABLE."COLUMN".  Group 1 is
    # the table name, group 2 the column name without its quotes.  A raw
    # string is used so ``\d`` is a regex escape, not an invalid str escape.
    _COLUMN_RE = re.compile(r'([A-Z_]+)[.]"([\dA-Z_]+)"')

    def __init__(self):
        # Undirected graph of table relationships: every entry ``v`` listed
        # under table ``k`` in SCHEMA contributes an edge ``k -- v[0]``.
        self.schema_graph = nx.Graph()
        for k, vs in SCHEMA.items():
            for v in vs:
                self.schema_graph.add_edge(k, v[0])
        # Filled with (joined table, left column, right column).
        self.inner_join_template = 'INNER JOIN {} ON {} = {}'

    def find_table(self, block: str) -> list:
        """Return the table name of every ``TABLE."COLUMN"`` in *block*.

        Names are returned in order of appearance and may repeat; the list is
        empty when *block* contains no qualified column.
        """
        return [match.group(1) for match in self._COLUMN_RE.finditer(block)]

    def from_caluse(self, new_select: str) -> str:
        """Return the table used as the base of the rebuilt ``FROM`` clause.

        This is the table of the *last* qualified column in *new_select*.

        Raises:
            ValueError: if *new_select* contains no ``TABLE."COLUMN"``
                reference, so no base table can be chosen.
        """
        tables = self.find_table(new_select)
        if not tables:
            raise ValueError(
                'no TABLE."COLUMN" reference found in: {!r}'.format(new_select)
            )
        return tables[-1]

    # Correctly spelled alias; the misspelled name stays for existing callers.
    from_clause = from_caluse

    def cols_clause(self, block: str) -> str:
        """Return *block* with every qualified column translated.

        Each ``TABLE."COLUMN"`` is replaced by
        ``MAP_WITH_MIMICSQL[TABLE][COLUMN]``.  The substitution is done in a
        single pass so that text produced by one replacement is never matched
        and rewritten again (which could happen with repeated columns or when
        a replacement contains another source column as a substring).

        Raises:
            KeyError: if a table or column has no entry in the mapping.
        """
        def _mapped(match):
            return MAP_WITH_MIMICSQL[match.group(1)][match.group(2)]

        return self._COLUMN_RE.sub(_mapped, block)

    @staticmethod
    def _join_key(a, b):
        """Return the join column recorded in SCHEMA for the edge ``a -- b``."""
        # The graph is undirected, so a path may traverse an edge that SCHEMA
        # only declares under ``b``.  Look under ``a`` first, then ``b``.  The
        # reverse lookup assumes the key column has the same name in both
        # tables, which the join template already relies on.
        for left, right in ((a, b), (b, a)):
            for con in SCHEMA.get(left, ()):
                if con[0] == right:
                    return con[1]
        raise KeyError('no join key between {!r} and {!r}'.format(a, b))

    def translate(self, sql: str) -> str:
        """Translate *sql* and return the rewritten query string.

        *sql* must have the shape ``<select> FROM <tables> WHERE <condition>``
        with exactly one ``WHERE`` and exactly one ``FROM`` before it.  The
        original table list after ``FROM`` is discarded and rebuilt from the
        translated columns.

        Raises:
            ValueError: if *sql* does not have the expected shape, or if the
                ``SELECT`` part contains no qualified column.
            KeyError: if a column has no mapping.
        """
        where_parts = sql.split('WHERE')
        if len(where_parts) != 2:
            raise ValueError(
                'expected exactly one WHERE in query: {!r}'.format(sql)
            )
        head, where_block = where_parts

        select_parts = head.split('FROM')
        if len(select_parts) != 2:
            raise ValueError(
                'expected exactly one FROM before WHERE in query: {!r}'.format(sql)
            )
        select_block = select_parts[0]

        new_select = self.cols_clause(select_block)
        from_table = self.from_caluse(new_select)
        new_where = self.cols_clause(where_block)

        # De-duplicate while keeping first-seen order, so the join order (and
        # therefore the output) is the same on every run.
        tables = list(OrderedDict.fromkeys(
            self.find_table(new_where) + self.find_table(new_select)
        ))

        # Ordered set of join clauses; a clause shared by several paths is
        # emitted once, at its first position.
        joins = OrderedDict()
        for target in tables:
            path = nx.shortest_path(
                self.schema_graph, source=from_table, target=target
            )
            for a, b in zip(path, path[1:]):
                key = self._join_key(a, b)
                clause = self.inner_join_template.format(
                    b, f'{a}."{key}"', f'{b}."{key}"'
                )
                joins[clause] = None
        new_inner_join = ' '.join(joins)

        # Spacing is kept exactly as before: one space after the base table,
        # the join chain (possibly empty), one more space, then WHERE.
        return (
            new_select + 'FROM ' + from_table
            + ' ' + new_inner_join
            + ' ' + 'WHERE' + new_where
        )
if __name__ == '__main__':
    # Small demo: translate one sample query and show it before and after.
    sample_query = 'SELECT DEMOGRAPHIC."GENDER",DEMOGRAPHIC."INSURANCE" FROM DEMOGRAPHIC WHERE DEMOGRAPHIC."SUBJECT_ID" = "81923"'
    translated_query = SQL2SQL().translate(sample_query)
    # Both lines are printed only after the translation has succeeded.
    for query in (sample_query, translated_query):
        print(query)