"""Rewrite SQL queries by remapping table-qualified columns and rebuilding joins.

Every ``TABLE."COLUMN"`` reference in the SELECT and WHERE parts of a query is
replaced through ``MAP_WITH_MIMICSQL``; the FROM clause is then regenerated as
a chain of INNER JOINs derived from ``SCHEMA``.

NOTE(review): judging by the imported names this converts MIMICSQL queries to
the MIMICSQL* schema -- confirm against ``build_mimicsqlstar_db.schema_mimic``.
"""
import json
import os
import re
from collections import OrderedDict
from itertools import repeat

import networkx as nx
import numpy as np
import pandas as pd

from build_mimicsqlstar_db.schema_mimic import SCHEMA, MAP_WITH_MIMICSQL

# A table-qualified column reference such as DEMOGRAPHIC."SUBJECT_ID".
# Compiled once from a raw string so that ``\d`` is a regex class rather than
# an invalid string escape.
_QUALIFIED_COLUMN = re.compile(r'[A-Z_]+[.]"[\dA-Z_]+"')


class SQL2SQL:
    """Translate a flat ``SELECT ... FROM ... [WHERE ...]`` query between schemas."""

    def __init__(self):
        """Build the undirected table graph used to discover join paths.

        ``SCHEMA`` is read as ``{table: [connection, ...]}`` where
        ``connection[0]`` is a neighbouring table (and, as used by
        :meth:`translate`, ``connection[1]`` is the column to join on).
        """
        self.schema_graph = nx.Graph()
        for table, connections in SCHEMA.items():
            for connection in connections:
                self.schema_graph.add_edge(table, connection[0])
        self.inner_join_template = 'INNER JOIN {} ON {} = {}'

    def find_table(self, block):
        """Return the table name of every qualified column found in *block*.

        Names are returned in order of appearance and may repeat.
        """
        return [col.split('.')[0] for col in _QUALIFIED_COLUMN.findall(block)]

    def from_caluse(self, new_select):
        """Return the table of the last qualified column in *new_select*.

        The (misspelled) method name is kept for backward compatibility;
        ``from_clause`` is provided as an alias.

        Raises:
            ValueError: if *new_select* contains no qualified column, in which
                case there is nothing to derive a FROM table from.
        """
        cols = _QUALIFIED_COLUMN.findall(new_select)
        if not cols:
            raise ValueError(
                'cannot derive a FROM table: no TABLE."COLUMN" reference in '
                '{!r}'.format(new_select)
            )
        return cols[-1].split('.')[0]

    from_clause = from_caluse

    def cols_clause(self, block):
        """Return *block* with each qualified column replaced via the mapping.

        The substitution is done in a single pass, so text produced by one
        replacement is never rewritten again by a later one, and a reference
        is only replaced where it matched as a whole.

        Raises:
            KeyError: if a table or column is missing from
                ``MAP_WITH_MIMICSQL``.
        """
        def _mapped(match):
            table, column = match.group(0).split('.')
            return MAP_WITH_MIMICSQL[table][column.replace('"', '')]

        return _QUALIFIED_COLUMN.sub(_mapped, block)

    @staticmethod
    def _join_key(left, right):
        """Return the join column recorded in ``SCHEMA`` between two tables.

        The graph is undirected, so a path may cross an edge in the direction
        opposite to the one ``SCHEMA`` lists; both directions are tried.
        """
        for owner, other in ((left, right), (right, left)):
            for connection in SCHEMA.get(owner, ()):
                if connection[0] == other:
                    return connection[1]
        raise KeyError(
            'no join key between {!r} and {!r} in SCHEMA'.format(left, right)
        )

    def translate(self, sql):
        """Translate *sql* and return the rewritten query string.

        The query must contain exactly one ``FROM`` before an optional single
        ``WHERE`` (upper-case keywords). The original FROM table is ignored:
        the new FROM table is taken from the remapped SELECT columns, and every
        other table referenced is attached with INNER JOINs along the shortest
        path in the schema graph.

        Raises:
            ValueError: if the query does not have the expected shape, or the
                SELECT block has no qualified column.
            KeyError: if a column cannot be mapped or two adjacent tables have
                no join key.
            networkx.NetworkXNoPath, networkx.NodeNotFound: if a referenced
                table cannot be reached from the FROM table.
        """
        where_parts = sql.split('WHERE')
        if len(where_parts) > 2:
            raise ValueError(
                'expected at most one WHERE keyword in {!r}'.format(sql)
            )
        head = where_parts[0]
        where_block = where_parts[1] if len(where_parts) == 2 else None

        from_parts = head.split('FROM')
        if len(from_parts) != 2:
            raise ValueError(
                'expected exactly one FROM keyword before WHERE in '
                '{!r}'.format(sql)
            )
        select_block = from_parts[0]

        new_select = self.cols_clause(select_block)
        from_table = self.from_caluse(new_select)
        targets = self.find_table(new_select)
        new_where = None
        if where_block is not None:
            new_where = self.cols_clause(where_block)
            targets = self.find_table(new_where) + targets

        inner_join_blocks = []
        # Tables already present in the FROM/JOIN chain; joining one twice
        # would make its columns ambiguous.
        joined = {from_table}
        # Ordered de-duplication keeps the emitted join order reproducible
        # (a plain set would make it depend on string hashing).
        for target in OrderedDict.fromkeys(targets):
            path = nx.shortest_path(
                self.schema_graph, source=from_table, target=target
            )
            for a, b in zip(path, path[1:]):
                if b in joined:
                    continue
                joined.add(b)
                key = self._join_key(a, b)
                inner_join_blocks.append(
                    self.inner_join_template.format(
                        b, f'{a}."{key}"', f'{b}."{key}"'
                    )
                )

        new_inner_join = ' '.join(inner_join_blocks)
        new_sql = f'{new_select}FROM {from_table} {new_inner_join}'
        if new_where is None:
            return new_sql.rstrip()
        # One separating space is always added before WHERE, so a query with
        # no joins keeps two spaces there.
        return f'{new_sql} WHERE{new_where}'


if __name__ == '__main__':
    convertor = SQL2SQL()
    sql = 'SELECT DEMOGRAPHIC."GENDER",DEMOGRAPHIC."INSURANCE" FROM DEMOGRAPHIC WHERE DEMOGRAPHIC."SUBJECT_ID" = "81923"'
    new_sql = convertor.translate(sql)
    print(sql)
    print(new_sql)