Diff of /sql2sparql.py [000000] .. [ab27bc]

Switch to side-by-side view

--- a
+++ b/sql2sparql.py
@@ -0,0 +1,240 @@
+import re
+import networkx as nx
+from build_mimicsparql_kg.kg_complex_schema import SCHEMA_DTYPE, KG_SCHEMA
+from build_mimicsparql_kg.kg_simple_schema import SIMPLE_KG_SCHEMA, SIMPLE_SCHEMA_DTYPE
+from collections import OrderedDict
+from itertools import repeat
+
+
+def cond_syntax_fix(sparql):
+    match = re.findall('"?[^>]*"?\^\^', sparql)
+    for m in match:
+        sparql = sparql.replace(m, ' "' + m.split('^^')[0].strip().replace('"', '') + '"^^')
+    return sparql
+
+
+def split_entity(sparql):
+    match = re.findall('</[a-z_]+/[\d.]+>',sparql)
+    for m in match:
+        val = re.findall('[\d.]+',m)[0]
+
+        f, a = re.sub(val,' ', m).split()
+        repla = ' '.join([f, val, a])
+        sparql = re.sub(m, repla, sparql)
+    return sparql
+
+
+def join_entity(sparql):
+    match = re.findall('</[a-z_]+/ [\d\.]+ >[\. ]', sparql)
+    for m in match:
+        sparql = re.sub(m, ''.join(m.split() + [' ']), sparql)
+    return sparql
+
+
+def clean_text(val):
+    if type(val) == str:
+        val = val.replace("\\", ' ')
+    return val
+
+
+def value2entity(sparql):
+    where = sparql.split('where')[-1].split('filter')[0]
+    ms = re.findall('</[a-z_\d]+> [a-z\d.]+', where)
+    for m in ms:
+        rel, val = m.split()
+        val = val.replace('.','')
+        entity_val = f'{rel} {rel[:-1]}/{val}>.'
+        sparql = sparql.replace(m,entity_val)
+    return sparql
+
+
+def sparql_postprocessing(sparql):
+    sparql = clean_text(sparql.lower())
+    sparql = sparql.replace(' <stop>', '')
+    sparql = sparql.replace('/xmlschema#', '/XMLSchema#')
+    sparql = sparql.replace(' ^^<http://www', '^^<http://www')
+    sparql = cond_syntax_fix(sparql)
+    sparql = value2entity(sparql)
+    return sparql
+
+
+class SQL2SPARQL:
+    def __init__(self, complex, root='subject_id'):
+        self.schema = KG_SCHEMA if complex else SIMPLE_KG_SCHEMA
+        self.schema_type = SCHEMA_DTYPE if complex else SIMPLE_SCHEMA_DTYPE
+        self.schema_type = {k.lower(): v for k, v in self.schema_type.items()}
+        self.schema_graph = nx.DiGraph()
+        for k, vs in self.schema.items():
+            for v in vs:
+                self.schema_graph.add_edge(k.lower(), v.lower())
+        self.agg_func = ['count', 'max', 'min', 'avg']
+        self.sel_p = re.compile('"[^"]*"')
+        self.cond_p = re.compile('"[^"]*"|[=><]+')
+        self.sparql_agg_template = "select ( {AGG} ( {DISTINCT} ?{COL} ) as ?agg ) "
+        self.sparql_select_template = "select"
+        self.root = root
+        self.duplicate_columns = ['short_title', 'long_title', 'icd9_code']
+
+    def _replace_dulicate_column(self, sql):
+        for col in self.duplicate_columns:
+            tokens = [token.split('.') for token in re.findall(f'[a-z]+."{col}"', sql)]
+            for table, _ in tokens:
+                sql = re.sub(f'{table}."{col}"', f'"{table}_{col}"', sql)
+        return sql
+
+    def get_max_hop(self, sql):
+        sql = self._replace_dulicate_column(sql)
+        remain, where_part = sql.split(' where ')
+        select_part, remain = remain.split(' from ')
+
+        distinct_term = 'distinct' if 'distinct' in select_part else ''
+        agg_f = self._get_agg_func(select_part)
+        select_cols = self._get_select_col(select_part)
+        select_term = self._make_select_term(select_cols, agg_f, distinct_term)
+
+        conds = self._get_conds(where_part)
+        triples, max_length = self._get_sparql_where_triples(select_cols, conds)
+        return max_length
+
+    def convert(self, sql):
+        sql = self._replace_dulicate_column(sql)
+        sparql = self._parse_sql(sql)
+        return sparql
+
+    def _parse_sql(self, sql):
+        remain, where_part = sql.split(' where ')
+        select_part, remain = remain.split(' from ')
+
+        distinct_term = 'distinct' if 'distinct' in select_part else ''
+        agg_f = self._get_agg_func(select_part)
+        select_cols = self._get_select_col(select_part)
+        select_term = self._make_select_term(select_cols, agg_f, distinct_term)
+
+        conds = self._get_conds(where_part)
+        where_term = self._make_where_term(select_cols, conds)
+        sql = f'{select_term} {where_term}'
+        return sql
+
+    def _get_agg_func(self, select_part):
+        for f in self.agg_func:
+            if f in select_part.split():
+                return f
+        return None
+
+    def _get_select_col(self, select_part):
+        return [col.replace('"', '') for col in re.findall(self.sel_p, select_part)]
+
+    def _get_conds(self, where_part):
+        tokens = re.findall(self.cond_p, where_part)
+        assert len(tokens) % 3 == 0
+        conds = [[tokens[i].replace('"', '')] + tokens[i+1:i+3] for i in range(0, len(tokens), 3)]
+        return conds
+
+    def _make_select_term(self, select_cols, agg_f, distinct):
+        term = self.sparql_select_template
+        if agg_f:
+            term = self.sparql_agg_template.format(AGG=agg_f, DISTINCT=distinct, COL=select_cols[0].lower())
+            return term
+
+        for col in select_cols:
+            term += f' ?{col.lower()}'
+
+        return term
+
+    def _make_where_term(self, sel_cols, conds):
+        triples, max_length = self._get_sparql_where_triples(sel_cols, conds)
+        term = f'where {{ {". ".join(triples)}. }}'
+        return term
+
+    def _path2triples(self, path):
+        path.reverse()
+        t = [f"?{path.pop().lower()}"]
+        triples = []
+        while len(path) > 0:
+            if len(t) == 0:
+                t.append(f"?{path[-1].lower()}")
+                path.pop()
+            elif len(t) == 1:
+                t.append(f"</{path[-1]}>")
+            elif len(t) == 2:
+                t.append(f"?{path[-1].lower()}")
+            elif len(t) == 3:
+                triples.append(t)
+                t = []
+            else:
+                print('error')
+        return triples
+
+    def _get_sparql_where_triples(self, sel_cols, conds):
+        max_length = 0
+        triples = []
+        filters = []
+        for sel_col in sel_cols:
+            for cond_col, op, val in conds:
+                root2sel = nx.shortest_path(self.schema_graph, source=self.root, target=sel_col)
+                sel_triples = self._path2triples(root2sel)
+
+                root2con = nx.shortest_path(self.schema_graph, source=self.root, target=cond_col)
+                if max_length < len(root2con):
+                    max_length = len(root2con)
+
+                if max_length < len(root2sel):
+                    max_length = len(root2sel)
+
+                if len(root2con) > 1:
+                    con_triples = self._path2triples(root2con)
+                    if op == '=':
+                        con_triples = [self._fill_cond_value(t, cond_col, val) for t in con_triples]
+                    else:
+                        filters.append(f"""filter( ?{cond_col.lower()} {op} {val.replace('"', '')} )""")
+                    a = [' '.join(t) for t in sel_triples]
+                    b = [' '.join(t) for t in con_triples]
+
+                else:
+                    if op == '=':
+                        sel_triples = [self._fill_cond_value(t, cond_col, val) for t in sel_triples]
+                    else:
+                        raise Exception()
+                    a = [' '.join(t) for t in sel_triples]
+                    b = []
+
+                b = [t for t in b if b not in a]
+
+                triples += list(OrderedDict(zip(a + b, repeat(None))))
+
+        return list(OrderedDict(zip(triples, repeat(None)))) + filters, max_length
+
+    def _fill_cond_value(self, t, cond_col, val):
+        sub = t[0]
+        rel = t[1]
+        cond = t[2]
+
+        if cond.replace('?', '') == cond_col:
+            cond_name = cond.replace('?', '')
+            if self.schema_type[cond_name] == 'entity':
+                cond = f"""</{cond_name}/{val.replace('"', '')}>"""
+            else:
+                cond = f'{val}^^<{self.schema_type[cond_name]}>'
+
+        elif sub.replace('?', '') == cond_col:
+            sub_name = sub.replace('?', '')
+            if self.schema_type[sub_name] == 'entity':
+                sub = f"""</{sub_name}/{val.replace('"', '')}>"""
+            else:
+                raise Exception()
+
+        return sub, rel, cond
+
+
+if __name__ == '__main__':
+    sql2sparql = SQL2SPARQL(True)
+
+    sql = """SELECT MIN ( ADMISSIONS."AGE" ) 
+                FROM ADMISSIONS
+                WHERE ADMISSIONS."DIAGNOSIS" = "S/P FALL" AND ADMISSIONS."ADMITYEAR" >= "2119"
+    """
+
+    sql = sql2sparql._replace_dulicate_column(sql)
+    print(sql)
+    result_sparql = sql2sparql._parse_sql(sql)
+    print(result_sparql)