"""Translate a restricted SQL dialect into SPARQL and post-process SPARQL text.

The module has two parts:

* free functions that clean up / normalise SPARQL strings
  (``sparql_postprocessing`` and its helpers), and
* ``SQL2SPARQL``, which walks the knowledge-graph schema imported from
  ``build_mimicsparql_kg`` to turn ``select ... from ... where ...`` queries
  into SPARQL triple patterns.

All SQL handed to ``SQL2SPARQL`` is expected to be lowercase with keywords
separated by single spaces; the parsing below matches lowercase keywords
(`` where ``, `` from ``, ``distinct``, ``count`` ...) literally.
"""
import re
from collections import OrderedDict
from itertools import repeat  # noqa: F401 -- no longer used here; kept so the module's importable names are unchanged

import networkx as nx

from build_mimicsparql_kg.kg_complex_schema import SCHEMA_DTYPE, KG_SCHEMA
from build_mimicsparql_kg.kg_simple_schema import SIMPLE_KG_SCHEMA, SIMPLE_SCHEMA_DTYPE


# Text that precedes a ``^^<datatype>`` suffix (optionally quoted).
_TYPED_LITERAL_RE = re.compile(r'"?[^>]*"?\^\^')
# ``</name/123.4>`` -- an entity IRI with a numeric value.
_JOINED_ENTITY_RE = re.compile(r'(</[a-z_]+/)([\d.]+)(>)')
# ``</name/ 123.4 >`` followed by ``.`` or a space -- the split form of the above.
_SPLIT_ENTITY_RE = re.compile(r'(</[a-z_]+/) ([\d.]+) >([. ])')
# ``</relation> value`` -- a relation followed by a bare (unquoted) token.
_VALUE_TRIPLE_RE = re.compile(r'</[a-z_\d]+> [a-z\d.]+')


def cond_syntax_fix(sparql):
    """Re-quote the lexical part of every ``...^^<datatype>`` literal.

    For each span matched before a ``^^`` marker, existing double quotes are
    stripped, surrounding whitespace is trimmed, and the value is rewritten
    as `` "value"^^`` (note the leading space).

    Args:
        sparql: SPARQL query text.

    Returns:
        The query text with typed literals consistently quoted.
    """
    for span in _TYPED_LITERAL_RE.findall(sparql):
        literal = span.split('^^')[0].strip().replace('"', '')
        sparql = sparql.replace(span, f' "{literal}"^^')
    return sparql


def split_entity(sparql):
    """Rewrite ``</name/123>`` as ``</name/ 123 >`` everywhere in ``sparql``.

    A single group-based substitution is used so that the matched text is
    never re-interpreted as a regular expression (a ``.`` inside the value
    would otherwise match any character).
    """
    return _JOINED_ENTITY_RE.sub(r'\1 \2 \3', sparql)


def join_entity(sparql):
    """Inverse of ``split_entity``: rewrite ``</name/ 123 >`` as ``</name/123>``.

    The character that follows the closing ``>`` (``.`` or a space) is
    preserved and a single space is emitted after it, i.e. ``>.`` becomes
    ``>. `` and ``> `` stays ``> ``.
    """
    def _rejoin(match):
        prefix, value, tail = match.groups()
        # tail is '.' or ' '; strip() drops the space case so we emit exactly one.
        return f'{prefix}{value}>{tail.strip()} '

    return _SPLIT_ENTITY_RE.sub(_rejoin, sparql)


def clean_text(val):
    """Replace backslashes with spaces in ``val``; non-strings pass through unchanged."""
    if isinstance(val, str):
        val = val.replace("\\", ' ')
    return val


def value2entity(sparql):
    """Turn bare object values in the WHERE body into entity IRIs.

    Only the text after the last ``where`` and before the first ``filter`` is
    scanned. Every ``</rel> value`` found there is rewritten to
    ``</rel> </rel/value>.`` with all ``.`` characters removed from ``value``
    (this also drops a trailing triple terminator that the pattern captured).
    The replacement itself is applied to the whole query string.
    """
    where = sparql.split('where')[-1].split('filter')[0]
    for span in _VALUE_TRIPLE_RE.findall(where):
        rel, val = span.split()
        val = val.replace('.', '')
        sparql = sparql.replace(span, f'{rel} {rel[:-1]}/{val}>.')
    return sparql


def sparql_postprocessing(sparql):
    """Normalise a SPARQL string: lowercase, clean tokens, fix literals and entities.

    Steps, in order: lowercase and strip backslashes, drop the `` <stop>``
    marker, restore the ``XMLSchema#`` casing that lowercasing destroyed,
    remove the space before ``^^<http://www``, re-quote typed literals, and
    convert bare values to entity IRIs.
    """
    sparql = clean_text(sparql.lower())
    sparql = sparql.replace(' <stop>', '')
    sparql = sparql.replace('/xmlschema#', '/XMLSchema#')
    sparql = sparql.replace(' ^^<http://www', '^^<http://www')
    sparql = cond_syntax_fix(sparql)
    sparql = value2entity(sparql)
    return sparql


class SQL2SPARQL:
    """Convert single-block ``select ... from ... where ...`` SQL into SPARQL.

    The schema (a mapping of node -> child nodes) is loaded into a directed
    graph; each selected or conditioned column is reached from ``root`` via a
    shortest path, and every edge on that path becomes one triple pattern.
    """

    def __init__(self, complex, root='subject_id'):
        """Build the schema graph.

        Args:
            complex: truthy selects ``KG_SCHEMA`` / ``SCHEMA_DTYPE``, falsy
                selects ``SIMPLE_KG_SCHEMA`` / ``SIMPLE_SCHEMA_DTYPE``.
            root: graph node every path starts from.
        """
        self.schema = KG_SCHEMA if complex else SIMPLE_KG_SCHEMA
        schema_type = SCHEMA_DTYPE if complex else SIMPLE_SCHEMA_DTYPE
        # Column names are compared in lowercase throughout.
        self.schema_type = {k.lower(): v for k, v in schema_type.items()}
        self.schema_graph = nx.DiGraph()
        for parent, children in self.schema.items():
            for child in children:
                self.schema_graph.add_edge(parent.lower(), child.lower())
        self.agg_func = ['count', 'max', 'min', 'avg']
        self.sel_p = re.compile(r'"[^"]*"')
        self.cond_p = re.compile(r'"[^"]*"|[=><]+')
        self.sparql_agg_template = "select ( {AGG} ( {DISTINCT} ?{COL} ) as ?agg ) "
        self.sparql_select_template = "select"
        self.root = root
        # Column names that are qualified by table name to keep them distinct.
        self.duplicate_columns = ['short_title', 'long_title', 'icd9_code']

    def _replace_dulicate_column(self, sql):
        """Rewrite ``table."col"`` as ``"table_col"`` for each duplicate column.

        Only a literal ``.`` between a lowercase table name and the quoted
        column is matched; unqualified occurrences are left untouched.
        """
        for col in self.duplicate_columns:
            pattern = r'([a-z]+)\."' + re.escape(col) + '"'
            sql = re.sub(pattern, lambda m, col=col: f'"{m.group(1)}_{col}"', sql)
        return sql

    @staticmethod
    def _split_sql(sql):
        """Return ``(select_part, where_part)`` of ``sql``.

        Raises:
            ValueError: if ``sql`` does not contain exactly one `` where ``
                and the part before it exactly one `` from ``.
        """
        parts = sql.split(' where ')
        if len(parts) != 2:
            raise ValueError(f"expected exactly one ' where ' in SQL, got {len(parts) - 1}: {sql!r}")
        head, where_part = parts
        parts = head.split(' from ')
        if len(parts) != 2:
            raise ValueError(f"expected exactly one ' from ' before ' where ', got {len(parts) - 1}: {sql!r}")
        return parts[0], where_part

    def get_max_hop(self, sql):
        """Return the length (in nodes) of the longest root-to-column path used by ``sql``."""
        sql = self._replace_dulicate_column(sql)
        select_part, where_part = self._split_sql(sql)
        select_cols = self._get_select_col(select_part)
        conds = self._get_conds(where_part)
        _, max_length = self._get_sparql_where_triples(select_cols, conds)
        return max_length

    def convert(self, sql):
        """Convert ``sql`` to SPARQL after disambiguating duplicate column names."""
        sql = self._replace_dulicate_column(sql)
        sparql = self._parse_sql(sql)
        return sparql

    def _parse_sql(self, sql):
        """Build ``"<select term> <where term>"`` from an already-normalised SQL string."""
        select_part, where_part = self._split_sql(sql)

        distinct_term = 'distinct' if 'distinct' in select_part else ''
        agg_f = self._get_agg_func(select_part)
        select_cols = self._get_select_col(select_part)
        select_term = self._make_select_term(select_cols, agg_f, distinct_term)

        conds = self._get_conds(where_part)
        where_term = self._make_where_term(select_cols, conds)
        return f'{select_term} {where_term}'

    def _get_agg_func(self, select_part):
        """Return the first aggregate name (in ``self.agg_func`` order) that is a token of ``select_part``, else ``None``."""
        tokens = select_part.split()
        for f in self.agg_func:
            if f in tokens:
                return f
        return None

    def _get_select_col(self, select_part):
        """Return every double-quoted name in ``select_part`` with the quotes removed."""
        return [col.replace('"', '') for col in self.sel_p.findall(select_part)]

    def _get_conds(self, where_part):
        """Parse ``where_part`` into ``[column, operator, value]`` triples.

        The column has its quotes removed; the value keeps them.

        Raises:
            AssertionError: if the tokens do not group into triples. Raised
                explicitly (not via ``assert``) so the check survives ``-O``
                while keeping the exception type callers already see.
        """
        tokens = self.cond_p.findall(where_part)
        if len(tokens) % 3 != 0:
            raise AssertionError(f'cannot group condition tokens into triples: {tokens!r}')
        return [
            [tokens[i].replace('"', ''), tokens[i + 1], tokens[i + 2]]
            for i in range(0, len(tokens), 3)
        ]

    def _make_select_term(self, select_cols, agg_f, distinct):
        """Return the SPARQL ``select`` clause.

        With an aggregate, only the first selected column is used and
        ``distinct`` is placed inside the aggregate; without one, all columns
        are listed as variables and ``distinct`` is ignored.
        """
        if agg_f:
            return self.sparql_agg_template.format(
                AGG=agg_f, DISTINCT=distinct, COL=select_cols[0].lower())
        return self.sparql_select_template + ''.join(f' ?{col.lower()}' for col in select_cols)

    def _make_where_term(self, sel_cols, conds):
        """Return the SPARQL ``where { ... }`` block for the given columns and conditions."""
        triples, _ = self._get_sparql_where_triples(sel_cols, conds)
        return f'where {{ {". ".join(triples)}. }}'

    def _path2triples(self, path):
        """Turn a node path into ``[?subject, </object>, ?object]`` triples, one per edge.

        The relation of each triple is named after the edge's target node.
        ``path`` is not modified, so callers can still inspect it afterwards.
        """
        return [
            [f"?{subj.lower()}", f"</{obj}>", f"?{obj.lower()}"]
            for subj, obj in zip(path, path[1:])
        ]

    def _get_sparql_where_triples(self, sel_cols, conds):
        """Build the triple patterns and filters for every (select column, condition) pair.

        Args:
            sel_cols: selected column names.
            conds: ``[column, operator, value]`` triples from ``_get_conds``.

        Returns:
            ``(patterns, max_length)`` where ``patterns`` is the de-duplicated
            triple strings (first-occurrence order) followed by any
            ``filter(...)`` strings, and ``max_length`` is the node count of
            the longest root-to-column path seen.

        Raises:
            ValueError: if a condition on the root column uses an operator
                other than ``=``.
        """
        max_length = 0
        triples = []
        filters = []
        for sel_col in sel_cols:
            for cond_col, op, val in conds:
                root2sel = nx.shortest_path(self.schema_graph, source=self.root, target=sel_col)
                root2con = nx.shortest_path(self.schema_graph, source=self.root, target=cond_col)
                max_length = max(max_length, len(root2sel), len(root2con))

                sel_triples = self._path2triples(root2sel)
                if len(root2con) > 1:
                    con_triples = self._path2triples(root2con)
                    if op == '=':
                        con_triples = [self._fill_cond_value(t, cond_col, val) for t in con_triples]
                    else:
                        filters.append(f"""filter( ?{cond_col.lower()} {op} {val.replace('"', '')} )""")
                else:
                    # The condition is on the root itself: it has no triple of
                    # its own, so bind the value into the select path instead.
                    if op != '=':
                        raise ValueError(
                            f'only "=" is supported for a condition on the root column, got {op!r}')
                    sel_triples = [self._fill_cond_value(t, cond_col, val) for t in sel_triples]
                    con_triples = []

                triples.extend(' '.join(t) for t in sel_triples)
                triples.extend(' '.join(t) for t in con_triples)

        # One order-preserving dedupe over everything replaces the per-pair dedupe.
        return list(OrderedDict.fromkeys(triples)) + filters, max_length

    def _fill_cond_value(self, t, cond_col, val):
        """Bind ``val`` into triple ``t`` wherever its variable is ``cond_col``.

        If the object variable matches, it becomes an entity IRI
        (``</col/value>``) for ``'entity'`` columns or a typed literal
        (``value^^<datatype>``, quotes kept) otherwise. If instead the subject
        variable matches, it becomes an entity IRI. A triple that mentions
        ``cond_col`` in neither position is returned unchanged.

        Returns:
            ``(subject, relation, object)`` as a tuple.

        Raises:
            ValueError: if the subject matches but the column is not an entity.
        """
        sub, rel, cond = t
        bare_value = val.replace('"', '')

        cond_name = cond.replace('?', '')
        sub_name = sub.replace('?', '')
        if cond_name == cond_col:
            dtype = self.schema_type[cond_name]
            if dtype == 'entity':
                cond = f"</{cond_name}/{bare_value}>"
            else:
                cond = f'{val}^^<{dtype}>'
        elif sub_name == cond_col:
            if self.schema_type[sub_name] != 'entity':
                raise ValueError(f'subject column {sub_name!r} is not an entity and cannot hold a value')
            sub = f"</{sub_name}/{bare_value}>"

        return sub, rel, cond


if __name__ == '__main__':
    sql2sparql = SQL2SPARQL(True)

    sql = """SELECT MIN ( ADMISSIONS."AGE" )
    FROM ADMISSIONS
    WHERE ADMISSIONS."DIAGNOSIS" = "S/P FALL" AND ADMISSIONS."ADMITYEAR" >= "2119"
    """
    # The converter matches lowercase keywords separated by single spaces, so
    # lowercase the query and collapse its whitespace before converting.
    sql = ' '.join(sql.lower().split())

    sql = sql2sparql._replace_dulicate_column(sql)
    print(sql)
    result_sparql = sql2sparql._parse_sql(sql)
    print(result_sparql)