Diff of /sql2sparql.py [000000] .. [ab27bc]

Switch to unified view

a b/sql2sparql.py
1
import re
2
import networkx as nx
3
from build_mimicsparql_kg.kg_complex_schema import SCHEMA_DTYPE, KG_SCHEMA
4
from build_mimicsparql_kg.kg_simple_schema import SIMPLE_KG_SCHEMA, SIMPLE_SCHEMA_DTYPE
5
from collections import OrderedDict
6
from itertools import repeat
7
8
9
def cond_syntax_fix(sparql):
10
    match = re.findall('"?[^>]*"?\^\^', sparql)
11
    for m in match:
12
        sparql = sparql.replace(m, ' "' + m.split('^^')[0].strip().replace('"', '') + '"^^')
13
    return sparql
14
15
16
def split_entity(sparql):
17
    match = re.findall('</[a-z_]+/[\d.]+>',sparql)
18
    for m in match:
19
        val = re.findall('[\d.]+',m)[0]
20
21
        f, a = re.sub(val,' ', m).split()
22
        repla = ' '.join([f, val, a])
23
        sparql = re.sub(m, repla, sparql)
24
    return sparql
25
26
27
def join_entity(sparql):
28
    match = re.findall('</[a-z_]+/ [\d\.]+ >[\. ]', sparql)
29
    for m in match:
30
        sparql = re.sub(m, ''.join(m.split() + [' ']), sparql)
31
    return sparql
32
33
34
def clean_text(val):
35
    if type(val) == str:
36
        val = val.replace("\\", ' ')
37
    return val
38
39
40
def value2entity(sparql):
41
    where = sparql.split('where')[-1].split('filter')[0]
42
    ms = re.findall('</[a-z_\d]+> [a-z\d.]+', where)
43
    for m in ms:
44
        rel, val = m.split()
45
        val = val.replace('.','')
46
        entity_val = f'{rel} {rel[:-1]}/{val}>.'
47
        sparql = sparql.replace(m,entity_val)
48
    return sparql
49
50
51
def sparql_postprocessing(sparql):
52
    sparql = clean_text(sparql.lower())
53
    sparql = sparql.replace(' <stop>', '')
54
    sparql = sparql.replace('/xmlschema#', '/XMLSchema#')
55
    sparql = sparql.replace(' ^^<http://www', '^^<http://www')
56
    sparql = cond_syntax_fix(sparql)
57
    sparql = value2entity(sparql)
58
    return sparql
59
60
61
class SQL2SPARQL:
62
    def __init__(self, complex, root='subject_id'):
63
        self.schema = KG_SCHEMA if complex else SIMPLE_KG_SCHEMA
64
        self.schema_type = SCHEMA_DTYPE if complex else SIMPLE_SCHEMA_DTYPE
65
        self.schema_type = {k.lower(): v for k, v in self.schema_type.items()}
66
        self.schema_graph = nx.DiGraph()
67
        for k, vs in self.schema.items():
68
            for v in vs:
69
                self.schema_graph.add_edge(k.lower(), v.lower())
70
        self.agg_func = ['count', 'max', 'min', 'avg']
71
        self.sel_p = re.compile('"[^"]*"')
72
        self.cond_p = re.compile('"[^"]*"|[=><]+')
73
        self.sparql_agg_template = "select ( {AGG} ( {DISTINCT} ?{COL} ) as ?agg ) "
74
        self.sparql_select_template = "select"
75
        self.root = root
76
        self.duplicate_columns = ['short_title', 'long_title', 'icd9_code']
77
78
    def _replace_dulicate_column(self, sql):
79
        for col in self.duplicate_columns:
80
            tokens = [token.split('.') for token in re.findall(f'[a-z]+."{col}"', sql)]
81
            for table, _ in tokens:
82
                sql = re.sub(f'{table}."{col}"', f'"{table}_{col}"', sql)
83
        return sql
84
85
    def get_max_hop(self, sql):
86
        sql = self._replace_dulicate_column(sql)
87
        remain, where_part = sql.split(' where ')
88
        select_part, remain = remain.split(' from ')
89
90
        distinct_term = 'distinct' if 'distinct' in select_part else ''
91
        agg_f = self._get_agg_func(select_part)
92
        select_cols = self._get_select_col(select_part)
93
        select_term = self._make_select_term(select_cols, agg_f, distinct_term)
94
95
        conds = self._get_conds(where_part)
96
        triples, max_length = self._get_sparql_where_triples(select_cols, conds)
97
        return max_length
98
99
    def convert(self, sql):
100
        sql = self._replace_dulicate_column(sql)
101
        sparql = self._parse_sql(sql)
102
        return sparql
103
104
    def _parse_sql(self, sql):
105
        remain, where_part = sql.split(' where ')
106
        select_part, remain = remain.split(' from ')
107
108
        distinct_term = 'distinct' if 'distinct' in select_part else ''
109
        agg_f = self._get_agg_func(select_part)
110
        select_cols = self._get_select_col(select_part)
111
        select_term = self._make_select_term(select_cols, agg_f, distinct_term)
112
113
        conds = self._get_conds(where_part)
114
        where_term = self._make_where_term(select_cols, conds)
115
        sql = f'{select_term} {where_term}'
116
        return sql
117
118
    def _get_agg_func(self, select_part):
119
        for f in self.agg_func:
120
            if f in select_part.split():
121
                return f
122
        return None
123
124
    def _get_select_col(self, select_part):
125
        return [col.replace('"', '') for col in re.findall(self.sel_p, select_part)]
126
127
    def _get_conds(self, where_part):
128
        tokens = re.findall(self.cond_p, where_part)
129
        assert len(tokens) % 3 == 0
130
        conds = [[tokens[i].replace('"', '')] + tokens[i+1:i+3] for i in range(0, len(tokens), 3)]
131
        return conds
132
133
    def _make_select_term(self, select_cols, agg_f, distinct):
134
        term = self.sparql_select_template
135
        if agg_f:
136
            term = self.sparql_agg_template.format(AGG=agg_f, DISTINCT=distinct, COL=select_cols[0].lower())
137
            return term
138
139
        for col in select_cols:
140
            term += f' ?{col.lower()}'
141
142
        return term
143
144
    def _make_where_term(self, sel_cols, conds):
145
        triples, max_length = self._get_sparql_where_triples(sel_cols, conds)
146
        term = f'where {{ {". ".join(triples)}. }}'
147
        return term
148
149
    def _path2triples(self, path):
150
        path.reverse()
151
        t = [f"?{path.pop().lower()}"]
152
        triples = []
153
        while len(path) > 0:
154
            if len(t) == 0:
155
                t.append(f"?{path[-1].lower()}")
156
                path.pop()
157
            elif len(t) == 1:
158
                t.append(f"</{path[-1]}>")
159
            elif len(t) == 2:
160
                t.append(f"?{path[-1].lower()}")
161
            elif len(t) == 3:
162
                triples.append(t)
163
                t = []
164
            else:
165
                print('error')
166
        return triples
167
168
    def _get_sparql_where_triples(self, sel_cols, conds):
169
        max_length = 0
170
        triples = []
171
        filters = []
172
        for sel_col in sel_cols:
173
            for cond_col, op, val in conds:
174
                root2sel = nx.shortest_path(self.schema_graph, source=self.root, target=sel_col)
175
                sel_triples = self._path2triples(root2sel)
176
177
                root2con = nx.shortest_path(self.schema_graph, source=self.root, target=cond_col)
178
                if max_length < len(root2con):
179
                    max_length = len(root2con)
180
181
                if max_length < len(root2sel):
182
                    max_length = len(root2sel)
183
184
                if len(root2con) > 1:
185
                    con_triples = self._path2triples(root2con)
186
                    if op == '=':
187
                        con_triples = [self._fill_cond_value(t, cond_col, val) for t in con_triples]
188
                    else:
189
                        filters.append(f"""filter( ?{cond_col.lower()} {op} {val.replace('"', '')} )""")
190
                    a = [' '.join(t) for t in sel_triples]
191
                    b = [' '.join(t) for t in con_triples]
192
193
                else:
194
                    if op == '=':
195
                        sel_triples = [self._fill_cond_value(t, cond_col, val) for t in sel_triples]
196
                    else:
197
                        raise Exception()
198
                    a = [' '.join(t) for t in sel_triples]
199
                    b = []
200
201
                b = [t for t in b if b not in a]
202
203
                triples += list(OrderedDict(zip(a + b, repeat(None))))
204
205
        return list(OrderedDict(zip(triples, repeat(None)))) + filters, max_length
206
207
    def _fill_cond_value(self, t, cond_col, val):
208
        sub = t[0]
209
        rel = t[1]
210
        cond = t[2]
211
212
        if cond.replace('?', '') == cond_col:
213
            cond_name = cond.replace('?', '')
214
            if self.schema_type[cond_name] == 'entity':
215
                cond = f"""</{cond_name}/{val.replace('"', '')}>"""
216
            else:
217
                cond = f'{val}^^<{self.schema_type[cond_name]}>'
218
219
        elif sub.replace('?', '') == cond_col:
220
            sub_name = sub.replace('?', '')
221
            if self.schema_type[sub_name] == 'entity':
222
                sub = f"""</{sub_name}/{val.replace('"', '')}>"""
223
            else:
224
                raise Exception()
225
226
        return sub, rel, cond
227
228
229
if __name__ == '__main__':
230
    sql2sparql = SQL2SPARQL(True)
231
232
    sql = """SELECT MIN ( ADMISSIONS."AGE" ) 
233
                FROM ADMISSIONS
234
                WHERE ADMISSIONS."DIAGNOSIS" = "S/P FALL" AND ADMISSIONS."ADMITYEAR" >= "2119"
235
    """
236
237
    sql = sql2sparql._replace_dulicate_column(sql)
238
    print(sql)
239
    result_sparql = sql2sparql._parse_sql(sql)
240
    print(result_sparql)