|
a |
|
b/sql2sparql.py |
|
|
1 |
import re |
|
|
2 |
import networkx as nx |
|
|
3 |
from build_mimicsparql_kg.kg_complex_schema import SCHEMA_DTYPE, KG_SCHEMA |
|
|
4 |
from build_mimicsparql_kg.kg_simple_schema import SIMPLE_KG_SCHEMA, SIMPLE_SCHEMA_DTYPE |
|
|
5 |
from collections import OrderedDict |
|
|
6 |
from itertools import repeat |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
def cond_syntax_fix(sparql):
    """Normalize the quoting of typed literals in a SPARQL string.

    Finds every value that precedes a ``^^`` datatype marker (the text
    between the previous ``>`` and the ``^^``), strips any existing quotes
    and surrounding whitespace, and rewrites it as ``"value"^^`` so the
    typed literal is consistently quoted.
    """
    # Raw string so the escaped regex metacharacter '^' is passed through
    # verbatim (a plain string would emit an invalid-escape warning).
    for m in re.findall(r'"?[^>]*"?\^\^', sparql):
        value = m.split('^^')[0].strip().replace('"', '')
        sparql = sparql.replace(m, ' "' + value + '"^^')
    return sparql
|
|
14 |
|
|
|
15 |
|
|
|
16 |
def split_entity(sparql): |
|
|
17 |
match = re.findall('</[a-z_]+/[\d.]+>',sparql) |
|
|
18 |
for m in match: |
|
|
19 |
val = re.findall('[\d.]+',m)[0] |
|
|
20 |
|
|
|
21 |
f, a = re.sub(val,' ', m).split() |
|
|
22 |
repla = ' '.join([f, val, a]) |
|
|
23 |
sparql = re.sub(m, repla, sparql) |
|
|
24 |
return sparql |
|
|
25 |
|
|
|
26 |
|
|
|
27 |
def join_entity(sparql):
    """Re-join spaced entity tokens: ``</t/ 123 >`` -> ``</t/123>``.

    Inverse of :func:`split_entity`. The match must be followed by '.' or
    a space (kept as part of the match); a single trailing space is
    re-appended after joining.
    """
    for m in re.findall(r'</[a-z_]+/ [\d\.]+ >[\. ]', sparql):
        # Literal replacement instead of re.sub: the matched text contains
        # '.' which would otherwise be interpreted as a regex wildcard.
        sparql = sparql.replace(m, ''.join(m.split() + [' ']))
    return sparql
|
|
32 |
|
|
|
33 |
|
|
|
34 |
def clean_text(val):
    """Replace backslashes with spaces in strings; pass non-strings through.

    Backslashes would otherwise act as escape characters inside the
    generated SPARQL.
    """
    # isinstance instead of `type(val) == str`: idiomatic, and also covers
    # str subclasses.
    if isinstance(val, str):
        val = val.replace("\\", ' ')
    return val
|
|
38 |
|
|
|
39 |
|
|
|
40 |
def value2entity(sparql):
    """Rewrite plain object values in the where clause as entity URIs.

    ``?x </hadm_id> 2345.`` becomes ``?x </hadm_id> </hadm_id/2345>.`` —
    i.e. the relation name is reused as the entity prefix. Only patterns
    found between 'where' and 'filter' are rewritten (though the
    replacement is applied over the whole string).
    """
    where = sparql.split('where')[-1].split('filter')[0]
    for m in re.findall(r'</[a-z_\d]+> [a-z\d.]+', where):
        rel, val = m.split()
        val = val.replace('.', '')  # drop the statement-terminating dot
        entity_val = f'{rel} {rel[:-1]}/{val}>.'
        sparql = sparql.replace(m, entity_val)
    return sparql
|
|
49 |
|
|
|
50 |
|
|
|
51 |
def sparql_postprocessing(sparql):
    """Clean up a generated SPARQL string.

    Lowercases and de-escapes the text, removes ``<stop>`` markers,
    restores XMLSchema casing, tightens ``^^`` datatype spacing, fixes
    literal quoting, and converts plain condition values into entity URIs.
    """
    sparql = clean_text(sparql.lower())
    # Ordered literal substitutions (order matters: lowercasing above
    # produces the '/xmlschema#' form that is restored here).
    for old, new in ((' <stop>', ''),
                     ('/xmlschema#', '/XMLSchema#'),
                     (' ^^<http://www', '^^<http://www')):
        sparql = sparql.replace(old, new)
    return value2entity(cond_syntax_fix(sparql))
|
|
59 |
|
|
|
60 |
|
|
|
61 |
class SQL2SPARQL:
    """Convert simple SQL queries into SPARQL over a MIMIC knowledge graph.

    The KG schema is loaded into a directed graph whose nodes are
    lowercased column names. Each SELECT/WHERE column is connected to the
    ``root`` column (default ``subject_id``) via a shortest path, which is
    unrolled into SPARQL triples of the form ``?parent </child> ?child``.
    """

    def __init__(self, complex, root='subject_id'):
        # NOTE: `complex` shadows the builtin of the same name, but it is
        # kept so existing positional/keyword callers are unaffected.
        self.schema = KG_SCHEMA if complex else SIMPLE_KG_SCHEMA
        self.schema_type = SCHEMA_DTYPE if complex else SIMPLE_SCHEMA_DTYPE
        self.schema_type = {k.lower(): v for k, v in self.schema_type.items()}
        # Directed schema graph: one edge k -> v per child column v of k.
        self.schema_graph = nx.DiGraph()
        for k, vs in self.schema.items():
            for v in vs:
                self.schema_graph.add_edge(k.lower(), v.lower())
        self.agg_func = ['count', 'max', 'min', 'avg']
        # Quoted column names in a SELECT clause, e.g. "age".
        self.sel_p = re.compile('"[^"]*"')
        # Quoted tokens or comparison operators in a WHERE clause.
        self.cond_p = re.compile('"[^"]*"|[=><]+')
        self.sparql_agg_template = "select ( {AGG} ( {DISTINCT} ?{COL} ) as ?agg ) "
        self.sparql_select_template = "select"
        self.root = root
        # Columns occurring in several tables; disambiguated by prefixing
        # the table name (see _replace_dulicate_column).
        self.duplicate_columns = ['short_title', 'long_title', 'icd9_code']

    def _replace_dulicate_column(self, sql):
        """Rewrite ``table."col"`` as ``"table_col"`` for ambiguous columns.

        (Method name keeps the original 'dulicate' typo so external
        callers — including the demo in __main__ — keep working.)
        """
        for col in self.duplicate_columns:
            tokens = [token.split('.') for token in re.findall(f'[a-z]+."{col}"', sql)]
            for table, _ in tokens:
                sql = re.sub(f'{table}."{col}"', f'"{table}_{col}"', sql)
        return sql

    def get_max_hop(self, sql):
        """Return the longest root-to-column schema path used by `sql`."""
        sql = self._replace_dulicate_column(sql)
        remain, where_part = sql.split(' where ')
        select_part, remain = remain.split(' from ')

        distinct_term = 'distinct' if 'distinct' in select_part else ''
        agg_f = self._get_agg_func(select_part)
        select_cols = self._get_select_col(select_part)
        select_term = self._make_select_term(select_cols, agg_f, distinct_term)

        conds = self._get_conds(where_part)
        triples, max_length = self._get_sparql_where_triples(select_cols, conds)
        return max_length

    def convert(self, sql):
        """Translate a lowercase SQL string into a SPARQL string."""
        sql = self._replace_dulicate_column(sql)
        sparql = self._parse_sql(sql)
        return sparql

    def _parse_sql(self, sql):
        """Split `sql` into SELECT/WHERE parts and assemble the SPARQL query."""
        remain, where_part = sql.split(' where ')
        select_part, remain = remain.split(' from ')

        distinct_term = 'distinct' if 'distinct' in select_part else ''
        agg_f = self._get_agg_func(select_part)
        select_cols = self._get_select_col(select_part)
        select_term = self._make_select_term(select_cols, agg_f, distinct_term)

        conds = self._get_conds(where_part)
        where_term = self._make_where_term(select_cols, conds)
        return f'{select_term} {where_term}'

    def _get_agg_func(self, select_part):
        """Return the first aggregation keyword in the SELECT clause, or None."""
        for f in self.agg_func:
            if f in select_part.split():
                return f
        return None

    def _get_select_col(self, select_part):
        """Return the quoted column names of the SELECT clause, unquoted."""
        return [col.replace('"', '') for col in re.findall(self.sel_p, select_part)]

    def _get_conds(self, where_part):
        """Parse the WHERE clause into [column, operator, value] triples."""
        tokens = re.findall(self.cond_p, where_part)
        # Conditions must come in (col, op, val) groups of three.
        assert len(tokens) % 3 == 0
        conds = [[tokens[i].replace('"', '')] + tokens[i + 1:i + 3] for i in range(0, len(tokens), 3)]
        return conds

    def _make_select_term(self, select_cols, agg_f, distinct):
        """Build the SPARQL select clause (aggregate form or plain ?vars)."""
        term = self.sparql_select_template
        if agg_f:
            # Aggregate queries wrap only the first selected column.
            term = self.sparql_agg_template.format(AGG=agg_f, DISTINCT=distinct, COL=select_cols[0].lower())
            return term

        for col in select_cols:
            term += f' ?{col.lower()}'

        return term

    def _make_where_term(self, sel_cols, conds):
        """Build the SPARQL where clause from triples and filter terms."""
        triples, max_length = self._get_sparql_where_triples(sel_cols, conds)
        term = f'where {{ {". ".join(triples)}. }}'
        return term

    def _path2triples(self, path):
        """Unroll a schema path [root, a, b, ...] into SPARQL triples.

        Produces ``?x </y> ?y`` for each consecutive pair; the relation
        name equals the object column name. NOTE: consumes `path` in
        place (reverse + pop). A single-node path yields [].
        """
        path.reverse()
        t = [f"?{path.pop().lower()}"]
        triples = []
        while len(path) > 0:
            if len(t) == 0:
                # Start the next triple with the previous object as subject.
                t.append(f"?{path[-1].lower()}")
                path.pop()
            elif len(t) == 1:
                t.append(f"</{path[-1]}>")
            elif len(t) == 2:
                t.append(f"?{path[-1].lower()}")
            elif len(t) == 3:
                triples.append(t)
                t = []
            else:
                print('error')  # unreachable: len(t) cycles through 0..3
        return triples

    def _get_sparql_where_triples(self, sel_cols, conds):
        """Build deduplicated where-clause triples plus filter expressions.

        Returns ``(triples_and_filters, max_path_length)`` where the first
        element is the ordered, deduplicated list of triple strings
        followed by any ``filter( ... )`` terms.
        """
        max_length = 0
        triples = []
        filters = []
        for sel_col in sel_cols:
            for cond_col, op, val in conds:
                root2sel = nx.shortest_path(self.schema_graph, source=self.root, target=sel_col)
                sel_triples = self._path2triples(root2sel)

                root2con = nx.shortest_path(self.schema_graph, source=self.root, target=cond_col)
                if max_length < len(root2con):
                    max_length = len(root2con)

                if max_length < len(root2sel):
                    max_length = len(root2sel)

                if len(root2con) > 1:
                    con_triples = self._path2triples(root2con)
                    if op == '=':
                        con_triples = [self._fill_cond_value(t, cond_col, val) for t in con_triples]
                    else:
                        # Non-equality comparisons become SPARQL filters.
                        filters.append(f"""filter( ?{cond_col.lower()} {op} {val.replace('"', '')} )""")
                    a = [' '.join(t) for t in sel_triples]
                    b = [' '.join(t) for t in con_triples]
                else:
                    # Condition on the root itself: inline it into the
                    # select path; only '=' is representable there.
                    if op == '=':
                        sel_triples = [self._fill_cond_value(t, cond_col, val) for t in sel_triples]
                    else:
                        raise Exception()
                    a = [' '.join(t) for t in sel_triples]
                    b = []

                # BUG FIX: original read `if b not in a`, which is always
                # true (a list is never an element of a list of strings);
                # intended: drop condition triples already on the select path.
                b = [t for t in b if t not in a]
                # OrderedDict keeps first occurrence -> ordered dedup.
                triples += list(OrderedDict(zip(a + b, repeat(None))))

        return list(OrderedDict(zip(triples, repeat(None)))) + filters, max_length

    def _fill_cond_value(self, t, cond_col, val):
        """Substitute the condition value for `cond_col` into triple `t`.

        Entity-typed columns become ``</col/value>`` URIs; other columns
        become typed literals ``value^^<dtype>``. A non-entity column in
        subject position is an error.
        """
        sub, rel, cond = t

        if cond.replace('?', '') == cond_col:
            cond_name = cond.replace('?', '')
            if self.schema_type[cond_name] == 'entity':
                cond = f"""</{cond_name}/{val.replace('"', '')}>"""
            else:
                cond = f'{val}^^<{self.schema_type[cond_name]}>'

        elif sub.replace('?', '') == cond_col:
            sub_name = sub.replace('?', '')
            if self.schema_type[sub_name] == 'entity':
                sub = f"""</{sub_name}/{val.replace('"', '')}>"""
            else:
                raise Exception()

        return sub, rel, cond
|
|
227 |
|
|
|
228 |
|
|
|
229 |
# Manual smoke test: disambiguate duplicate columns, then convert one SQL
# query to SPARQL and print both forms.
if __name__ == '__main__':
    sql2sparql = SQL2SPARQL(True)

    sql = """SELECT MIN ( ADMISSIONS."AGE" )
FROM ADMISSIONS
WHERE ADMISSIONS."DIAGNOSIS" = "S/P FALL" AND ADMISSIONS."ADMITYEAR" >= "2119"
"""

    sql = sql2sparql._replace_dulicate_column(sql)
    print(sql)
    # NOTE(review): _parse_sql splits on the lowercase token ' where ',
    # but this demo SQL is uppercase — presumably queries are lowercased
    # upstream in real use. TODO confirm this demo runs end-to-end.
    result_sparql = sql2sparql._parse_sql(sql)
    print(result_sparql)