Diff of /sql2sql.py [000000] .. [ab27bc]

Switch to side-by-side view

--- a
+++ b/sql2sql.py
@@ -0,0 +1,82 @@
+import json
+import os
+import re
+import numpy as np
+import pandas as pd
+import networkx as nx
+from build_mimicsqlstar_db.schema_mimic import SCHEMA, MAP_WITH_MIMICSQL
+from collections import OrderedDict
+from itertools import repeat
+
+
+class SQL2SQL:
+    def __init__(self):
+        self.schema_graph = nx.Graph()
+        for k, vs in SCHEMA.items():
+            for v in vs:
+                self.schema_graph.add_edge(k, v[0])
+        self.inner_join_template = 'INNER JOIN {} ON {} = {}'
+
+    def find_table(self, block):
+        cols = re.findall('[A-Z_]+[.]"[\dA-Z_]+"', block)
+        return [c.split('.')[0] for c in cols]
+
+    def from_caluse(self, new_select):
+        cols = re.findall('[A-Z_]+[.]"[\dA-Z_]+"', new_select)
+        for col in cols:
+            t, c = col.split('.')
+            return t
+
+    def cols_clause(self, block):
+        cols = re.findall('[A-Z_]+[.]"[\dA-Z_]+"', block)
+        #print(cols)
+        for col in cols:
+            t, c = col.split('.')
+            new_tc = MAP_WITH_MIMICSQL[t][c.replace('"', '')]
+            #print(new_tc)
+            block = block.replace(col, new_tc)
+            # print(col, new_tc)
+        return block
+
+    def translate(self, sql):
+        remain, where_block = sql.split('WHERE')
+        select_block, remain = remain.split('FROM')
+
+        new_select = self.cols_clause(select_block)
+        select_tables = self.find_table(new_select)
+        from_table = self.from_caluse(new_select)
+        new_where = self.cols_clause(where_block)
+        where_tables = self.find_table(new_where)
+        # print(new_select)
+        # print(from_table)
+        # print(new_where)
+        # print(where_tables)
+        inner_join_blocks = list()
+        tables = list(set(where_tables + select_tables))
+        for wt in tables:
+            path = nx.shortest_path(self.schema_graph, source=from_table.replace(' ', ''), target=wt.replace(' ', ''))
+            for i in range(len(path) - 1):
+                a = path[i]
+                b = path[i+1]
+                key = [con[1] for con in SCHEMA[a] if con[0] == b][0]
+                inner_join_block = self.inner_join_template.format(b, f'{a}."{key}"', f'{b}."{key}"')
+                #print(inner_join_block)
+                inner_join_blocks.append(inner_join_block)
+
+        inner_join_blocks = list(OrderedDict(zip(inner_join_blocks, repeat(None))))
+        new_inner_join = ' '.join(list(inner_join_blocks))
+        #print(new_inner_join)
+
+        temp = 'FROM '.join([new_select, from_table])
+        temp = ' '.join([temp, new_inner_join])
+        temp = ' '.join(temp.split(' ')) + ' '
+        new_sql = 'WHERE'.join([temp, new_where])
+        return new_sql
+
+
+if __name__ == '__main__':
+    convertor = SQL2SQL()
+    sql = 'SELECT DEMOGRAPHIC."GENDER",DEMOGRAPHIC."INSURANCE" FROM DEMOGRAPHIC WHERE DEMOGRAPHIC."SUBJECT_ID" = "81923"'
+    new_sql = convertor.translate(sql)
+    print(sql)
+    print(new_sql)