|
# convert_sql2sparql_dataset.py
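# Converts MIMICSQL-style SQL datasets into SPARQL datasets over the MIMIC
# knowledge graphs: each SQL query is translated with SQL2SPARQL, optionally
# executed against both the relational DB and the KG to check that the two
# answer sets agree, and written back out in the same line-delimited JSON format.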
import json
import os
import argparse
from collections import Counter

import pandas as pd
from rdflib import Graph

from mimicsql.evaluation.utils import query
from sql2sparql import SQL2SPARQL, split_entity
from evaluation_sparql import isequal
from build_mimicsparql_kg.build_complex_kg_from_mimicsqlstar_db import clean_text

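# Tokenize a SPARQL query string: split_entity() separates entity mentions,
# the join/split below inserts a space before each '^^' datatype marker, and
# the final split() yields whitespace-delimited tokens.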
def sparql_tokenize(sparql):
    sparql = split_entity(sparql)
    sparql = ' ^^'.join(sparql.split('^^'))
    sparql_tok = ' '.join(sparql.split(' '))
    return sparql_tok.split()

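# Convert one dataset split (train/dev/test) from SQL to SPARQL.
# complex=True targets MIMICSQL* / mimic_sparqlstar (root entity 'subject_id');
# complex=False targets the simpler MIMICSQL / mimic_sparql (root 'hadm_id').
# With execution=True, every SQL/SPARQL pair is run against the database and
# the knowledge graph, and answer accuracy is reported.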
def convert_sql2sparql(complex=True, filename='train.json', dataset_type='natural', execution=True):
    if complex:
        savedir = f'./dataset/mimic_sparqlstar/{dataset_type}/'
        datadir = f'./dataset/mimicsqlstar/{dataset_type}/'

        sql2sparql = SQL2SPARQL(complex=complex, root='subject_id')

        if execution:
            print('LOAD ... mimicsqlstar.db')
            db_file = './build_mimicsqlstar_db/mimicsqlstar.db'
            model = query(db_file)
            print('DONE')

            print('LOAD KG ... mimic_kg')
            kg = Graph()
            kg.parse('./build_mimicsparql_kg/mimic_sparqlstar_kg.xml', format='xml', publicID='/')
            print('DONE')

    else:
        print('This dataset is simple')
        savedir = f'./dataset/mimic_sparql/{dataset_type}/'
        datadir = f'./dataset/mimicsql/mimicsql_{dataset_type}/'

        sql2sparql = SQL2SPARQL(complex=complex, root='hadm_id')

        if execution:
            print('LOAD ... mimic.db')
            db_file = './mimicsql/evaluation/mimic_db/mimic.db'
            model = query(db_file)
            print('DONE')

            print('LOAD KG ... mimic_sparql_kg')
            kg = Graph()
            kg.parse('./build_mimicsparql_kg/mimic_sparql_kg.xml', format='xml', publicID='/')
            print('DONE')

    data = []
    with open(os.path.join(datadir, filename)) as json_file:
        for line in json_file:
            data.append(json.loads(line))

    df = pd.DataFrame(data)

    correct = 0
    sparqls = []
    for i, sql in enumerate(df['sql']):
        sql = sql.lower()
        sql_answer = []
        sparql_answer = []

        print("-" * 50)
        print(i, sql)

        if execution:
            # execute the original SQL and collect its answer tuples
            sql_res = model.execute_sql(sql).fetchall()
            for res in sql_res:
                val = '|'
                temp = []
                for t in res:
                    val += str(t) + '|\t\t|'
                    temp.append(str(t))
                print(val[:-1])
                sql_answer.append(tuple(temp))
            print()

        sparql = sql2sparql.convert(sql)
        sparql = clean_text(sparql)

        print(i, sparql)
        if execution:
            # execute the converted SPARQL and collect its answer tuples
            sparql_res = kg.query(sparql)
            for res in sparql_res:
                val = '|'
                temp = []
                for t in res:
                    val += str(t.toPython()) + '|\t\t|'
                    temp.append(str(t.toPython()))
                print(val[:-1])
                sparql_answer.append(tuple(temp))

            # the conversion counts as correct only if both answer sets match
            print(sql_answer, sparql_answer, isequal(sql_answer, sparql_answer))
            if isequal(sql_answer, sparql_answer):
                correct += 1
            else:
                print("[incorrect]")
            print()

        sparql = sparql.lower()
        sparql_tok = sparql_tokenize(sparql)
        sparqls.append({'sql': sparql, 'sql_tok': sparql_tok})

    if execution:
        print(f'[SQL2SPARQL] filename: {filename}, Answer Accuracy: {correct/len(df):.4f}')

    sparql_data = []
    for d, sparql_d in zip(data, sparqls):
        d['sql'] = sparql_d['sql']
        d['sql_tok'] = sparql_d['sql_tok']
        sparql_data.append(d)

    save_filename = os.path.join(savedir, filename)
    with open(save_filename, 'w') as json_file:
        for dic in sparql_data:
            json.dump(dic, json_file)
            json_file.write('\n')

    print(f"Write to {save_filename}")

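# Build a whitespace-separated 'vocab' file of token counts over the question
# tokens and the converted SPARQL tokens of train.json.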
def build_vocab(complex=True, dataset_type='natural'):
    if complex:
        datadir = f'./dataset/mimic_sparqlstar/{dataset_type}'
    else:
        datadir = f'./dataset/mimic_sparql/{dataset_type}'

    filenames = ['train.json']
    counter = Counter()
    for filename in filenames:
        with open(os.path.join(datadir, filename)) as json_file:
            for line in json_file:
                dic = json.loads(line)
                counter.update(dic['question_refine_tok'])
                counter.update(dic['sql_tok'])

    with open(os.path.join(datadir, 'vocab'), 'w') as f:
        for k, v in counter.most_common():
            # skip empty or whitespace-only tokens
            if len(k.split()) == 0:
                continue
            if k == ' ':
                continue
            f.write(f'{k} {v}\n')

    print(f'vocab built: {len(counter)}')

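# Example invocation (assuming the dataset and KG paths above exist relative
# to the repository root):
#   python convert_sql2sparql_dataset.py --complex True --dataset_type natural --execution True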
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='mimicsql to mimic-sparql')
    parser.add_argument('--complex', default=False, type=lambda x: (str(x).lower() == 'true'))
    parser.add_argument('--dataset_type', type=str, default='natural', choices=['natural', 'template'])
    parser.add_argument('--execution', default=False, type=lambda x: (str(x).lower() == 'true'))
    args = parser.parse_args()

    execution = args.execution
    dataset_type = args.dataset_type
    complex = args.complex

    filenames = ['train.json', 'dev.json', 'test.json']
    for filename in filenames:
        convert_sql2sparql(complex=complex, filename=filename, dataset_type=dataset_type, execution=execution)
    build_vocab(complex=complex, dataset_type=dataset_type)