|
a |
|
b/evaluation_sparql.py |
|
|
1 |
import sys |
|
|
2 |
sys.path.append('.') |
|
|
3 |
sys.path.append('..') |
|
|
4 |
import json |
|
|
5 |
import os |
|
|
6 |
import re |
|
|
7 |
import pandas as pd |
|
|
8 |
from rdflib import Graph |
|
|
9 |
from sql2sparql import SQL2SPARQL, sparql_postprocessing, join_entity |
|
|
10 |
from mimicsql.evaluation.utils import query |
|
|
11 |
from build_mimicsparql_kg.build_complex_kg_from_mimicsqlstar_db import clean_text |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
def split_triples(sparql):
    """Split a SPARQL string into its select clause and its list of triples.

    The query is cut on the literal separator ``' where '``. Braces are
    removed from the where clause, which is then cut on ``'. '``; every piece
    is whitespace-trimmed and empty pieces are dropped.

    When the separator does not occur exactly once, the query is printed and
    the first and last pieces of the split are used (for a query without
    ``' where '`` both are the whole string).

    Args:
        sparql: The query text.

    Returns:
        A ``(select_part, triples)`` tuple; ``triples`` is a list of str.
    """
    parts = sparql.split(' where ')
    if len(parts) != 2:
        # Unexpected shape: report it explicitly instead of relying on a
        # bare ``except`` around a failed tuple unpacking.
        print(sparql)
    select_part, where_part = parts[0], parts[-1]

    where_part = where_part.replace('{', '').replace('}', '')
    triples = [piece.strip() for piece in where_part.split('. ')]
    return select_part, [triple for triple in triples if triple]
|
|
24 |
|
|
|
25 |
|
|
|
26 |
def none2zero(answer):
    """Map missing answers to ``0.0`` and numeric strings to floats.

    Args:
        answer: A single answer cell of any type.

    Returns:
        ``0.0`` when ``answer`` is ``None`` or the text ``'none'`` in any
        letter case; the float value when ``answer`` is a string that
        ``float`` can parse; otherwise ``answer`` unchanged.
    """
    if answer is None:
        return 0.0

    # Exact type check kept on purpose: instances of str subclasses are
    # returned untouched, just like any other non-str value.
    if type(answer) is not str:
        return answer

    if answer.lower() == 'none':
        return 0.0

    try:
        return float(answer)
    except ValueError:
        # Not numeric text; hand the string back unchanged.
        return answer
|
|
42 |
|
|
|
43 |
|
|
|
44 |
def answer_normalization(answers):
    """Normalize every cell of every answer row with ``none2zero``.

    An empty answer list is treated as the single row ``(0.0,)``.

    Args:
        answers: A sequence of rows, each row an iterable of cells.

    Returns:
        A list of tuples holding the normalized cells.
    """
    rows = answers if len(answers) != 0 else [(0.0, )]
    normalized = []
    for row in rows:
        normalized.append(tuple(none2zero(cell) for cell in row))
    return normalized
|
|
48 |
|
|
|
49 |
|
|
|
50 |
def entity2value(entity):
    """Strip ``/<name>/`` prefixes from an entity path, keeping the value.

    When ``entity`` contains something shaped like ``/<name>/<value>``
    (lower-case letters, digits and, in the name only, underscores), every
    ``/<name>/`` segment is removed. Any other string is returned unchanged.

    Args:
        entity: The text to convert.

    Returns:
        The text without its ``/<name>/`` segments, or ``entity`` itself.
    """
    # Raw strings: ``\d`` inside a normal literal is an invalid escape.
    if re.search(r'/[a-z_\d]+/[a-z\d]+', entity) is None:
        return entity
    return re.sub(r'/[a-z_\d]+/', '', entity)
|
|
56 |
|
|
|
57 |
|
|
|
58 |
def replace_cond_val(sparql):
    """Mask condition values in a SPARQL query with the ``<COND>`` token.

    Three shapes are searched for inside the brace-delimited body of the
    query and rewritten in the full query text:

    * ``?var </rel> value``         -> ``?var </rel> <COND>``
    * ``<entity> </rel> ?var``      -> ``<COND> </rel> ?var``
    * ``filter( ?var op value )``   -> ``filter( ?var op <COND> )``

    The function is best-effort: when no brace-delimited body is found the
    error is printed and the input is returned unchanged; when an error
    occurs while rewriting, it is printed and the query is returned with the
    replacements made up to that point.

    Args:
        sparql: The query text.

    Returns:
        The query text with condition values masked.
    """
    try:
        where_part = re.findall(r'{[^{^}].*}', sparql)[0]
    except Exception as e:
        print(e)
        return sparql

    where_part = where_part.replace('{', '').replace('}', '').strip()

    try:
        # Variable on the left, condition value on the right.
        ent_rel_cond = re.findall(
            r'\?[a-z_\d]+ </[a-z_\d]+> [^?][^.^]+', where_part)
        for m in ent_rel_cond:
            ent, rel = m.split()[:2]
            sparql = sparql.replace(m, ' '.join([ent, rel, '<COND>']))

        # Condition entity on the left, variable on the right.
        cond_rel_ent = re.findall(
            r'<[^?^ ]+> </[a-z_\d]+> \?[a-z_\d]+', where_part)
        for m in cond_rel_ent:
            _, rel, ent = m.split()
            sparql = sparql.replace(m, ' '.join(['<COND>', rel, ent]))

        # Comparison inside a filter clause. The unpacking requires exactly
        # five whitespace-separated tokens; anything else raises and ends
        # the rewriting via the handler below.
        filter_cond = re.findall(
            r'filter\( \?[a-z_\d]+ [<=>]+ [^?]+ \)', where_part)
        for m in filter_cond:
            ft, var, op, _, closing = m.split()
            sparql = sparql.replace(
                m, ' '.join([ft, var, op, '<COND>', closing]))

    except Exception as e:
        print(e)

    return sparql
|
|
91 |
|
|
|
92 |
|
|
|
93 |
def isequal(sql_answer, sparql_answer):  # list of tuple
    """Tell whether a SQL result and a SPARQL result hold the same rows.

    SQL rows containing the text ``'None'`` are discarded and the remaining
    string cells are lower-cased and passed through ``clean_text``. SPARQL
    string cells are passed through ``entity2value``. The two results are
    compared as sets of rows, first as they are and then once more after
    ``answer_normalization`` has been applied to both.

    Args:
        sql_answer: List of row tuples from the SQL side.
        sparql_answer: List of row tuples from the SPARQL side.

    Returns:
        True when either comparison succeeds, False otherwise.
    """
    kept_rows = []
    for row in sql_answer:
        if 'None' in row:
            continue
        kept_rows.append(row)

    sql_rows = []
    for row in kept_rows:
        cells = [clean_text(cell.lower()) if type(cell) == str else cell
                 for cell in row]
        sql_rows.append(tuple(cells))

    sparql_rows = []
    for row in sparql_answer:
        cells = [entity2value(cell) if type(cell) == str else cell
                 for cell in row]
        sparql_rows.append(tuple(cells))

    if set(sql_rows) == set(sparql_rows):
        return True

    # Second chance: compare again after value normalization.
    sparql_rows = answer_normalization(sparql_rows)
    sql_rows = answer_normalization(sql_rows)
    return set(sql_rows) == set(sparql_rows)
|
|
109 |
|
|
|
110 |
|
|
|
111 |
def check_no_cond_val(sparql):
    """Return True when the query contains no recognisable condition.

    A condition is any of: a ``^^<http://`` typed-literal marker, an entity
    of the form ``</name/digits>``, a double-quoted run of lower-case
    letters, digits and spaces, or the word ``filter``.

    Args:
        sparql: The query text.

    Returns:
        True when none of the patterns occurs in ``sparql``, else False.
    """
    # Raw strings: ``\^`` and ``\d`` are invalid escapes in normal literals.
    condition_patterns = (
        r'\^\^<http://',        # typed value
        r'</[a-z_\d]+/[\d]+>',  # entity
        r'"[a-z\d ]+"',         # quoted value
        r'filter',              # filter clause
    )
    return not any(
        re.search(pattern, sparql) for pattern in condition_patterns)
|
|
121 |
|
|
|
122 |
|
|
|
123 |
def n_inner_join(x):
    """Count the non-overlapping occurrences of ``inner join`` in ``x``."""
    occurrences = 0
    for _ in re.finditer('inner join', x):
        occurrences += 1
    return occurrences
|
|
125 |
|
|
|
126 |
|
|
|
127 |
def _read_json_lines(path):
    """Parse a file holding one JSON document per line into a list."""
    with open(path) as json_file:
        return [json.loads(line) for line in json_file]


def _print_and_collect(rows, to_text):
    """Print each result row and return the rows as tuples of strings.

    ``to_text`` converts a single cell to the string that is both printed
    and stored.
    """
    collected = []
    for row in rows:
        cells = [to_text(cell) for cell in row]
        line = '|' + ''.join(cell + '|\t\t|' for cell in cells)
        print(line[:-1])
        collected.append(tuple(cells))
    return collected


def compare_sql_and_spqral_pred():
    """Evaluate predicted SPARQL queries against gold SPARQL and SQL answers.

    Reads the gold questions (``test.json``) and the model output
    (``output.json``) from hard-coded paths, executes every SQL query on the
    database opened through ``query`` and every gold/predicted SPARQL query
    on the knowledge graph loaded with rdflib, and prints:

    * answer accuracy of the gold and of the predicted SPARQL,
    * logic-form accuracy (exact token match),
    * condition-invariant logic-form accuracy (after ``replace_cond_val``),
    * both of the above ignoring the order of the where-clause triples.

    Per-query flags are written to ``./ablation_results_<output>.csv``
    together with the number of inner joins and the hop count of the SQL.
    """
    datadir = '../TREQS/mimicsql_data/mimicsql_natural/'
    filename = 'test.json'
    outputdir = '../TREQS/evaluation/generated_sql/'
    output_filename = 'output.json'

    converter = SQL2SPARQL()
    print('LOAD output.json')
    outputs = _read_json_lines(os.path.join(outputdir, output_filename))
    sparql_preds = [dic['sql_pred'] for dic in outputs]
    sparql_golds = [dic['sql_gold'] for dic in outputs]
    print('DONE')

    data = _read_json_lines(os.path.join(datadir, filename))
    df = pd.DataFrame(data)

    print('LOAD DB ...')
    db_file = './evaluation/mimic_db/mimic.db'
    model = query(db_file)
    print('DONE')

    print('LOAD KG ...')
    kg = Graph()
    kg.parse('./evaluation/mimic_simple_kg.xml', format='xml', publicID='/')
    print('DONE')

    lf_permu_correct = 0
    lf_permu_cond_correct = 0
    cond_lf_correct = 0
    lf_correct = 0
    gold_correct = 0
    pred_correct = 0
    ablation_results = []
    for i, sql in enumerate(df['sql']):
        ablation_dic = {}
        sql = sql.lower()

        print("-" * 50)
        print(i, sql)

        ablation_dic['n_inner'] = n_inner_join(sql)
        ablation_dic['n_hop'] = converter.get_max_hop(sql)

        sql_answer = _print_and_collect(
            model.execute_sql(sql).fetchall(), str)
        print()

        sparql_pred = join_entity(sparql_postprocessing(sparql_preds[i]))
        sparql_gold = join_entity(sparql_postprocessing(sparql_golds[i]))

        if sparql_pred.split() == sparql_gold.split():
            lf_correct += 1
            ablation_dic['lf_correct'] = 1

        print(sparql_gold)
        print(sparql_pred)

        cond_sp = replace_cond_val(sparql_pred)
        cond_sg = replace_cond_val(sparql_gold)

        if cond_sp.split() == cond_sg.split():
            cond_lf_correct += 1
            ablation_dic['cond_lf_correct'] = 1
        print(cond_sg)
        print(cond_sp)

        cond_sps, cond_spw = split_triples(cond_sp)
        cond_sgs, cond_sgw = split_triples(cond_sg)

        sps, spw = split_triples(sparql_pred)
        sgs, sgw = split_triples(sparql_gold)

        # Triple sets are compared as sets, so their order is ignored.
        if cond_sps.split() == cond_sgs.split() and set(cond_spw) == set(cond_sgw):
            lf_permu_cond_correct += 1

        if sps.split() == sgs.split() and set(spw) == set(sgw):
            lf_permu_correct += 1

        print(i, sparql_gold)
        sparql_gold_answer = _print_and_collect(
            kg.query(sparql_gold), lambda term: str(term.toPython()))
        # Evaluate the comparison once and reuse it for print and counter.
        gold_match = isequal(sql_answer, sparql_gold_answer)
        print(sql_answer, sparql_gold_answer, gold_match)

        if gold_match:
            gold_correct += 1
        else:
            print('sql gold false')

        print(i, sparql_pred)

        # Predictions without any condition value are not executed.
        if check_no_cond_val(sparql_pred):
            print(f'[NO COND]: {sparql_pred}')
            print()
            ablation_results.append(ablation_dic)
            continue

        try:
            sparql_pred_answer = _print_and_collect(
                kg.query(sparql_pred), lambda term: str(term.toPython()))
            pred_match = isequal(sql_answer, sparql_pred_answer)
            print(sql_answer, sparql_pred_answer, pred_match)

            if pred_match:
                ablation_dic['ex_correct'] = 1
                pred_correct += 1

        except Exception:
            # Best-effort: a prediction that cannot be executed counts as
            # wrong. ``Exception`` (not a bare ``except``) keeps Ctrl-C and
            # interpreter exit working.
            print(sparql_pred)
            print("syntax error")

        ablation_results.append(ablation_dic)

        print()

    total = len(data)
    print(f'[SQL2SPARQL] filenmae: {filename}, Answer Accuracy: {gold_correct / total:.4f}')
    print(f'[SQL2SPARQL] filenmae: {output_filename}, Answer Accuracy: {pred_correct / total:.4f}')
    print(f'[SQL2SPARQL] filenmae: {output_filename}, Logic Form Accuracy: {lf_correct / total:.4f}')
    print(f'[SQL2SPARQL] filenmae: {output_filename}, Cond Invariant Logic Form Accuracy: {cond_lf_correct / total:.4f}')
    print(f'[SQL2SPARQL] filenmae: {output_filename}, Logic Form Accuracy*: {lf_permu_correct / total:.4f}')
    print(f'[SQL2SPARQL] filenmae: {output_filename}, Cond Invariant Logic Form Accuracy*: {lf_permu_cond_correct / total:.4f}')

    df = pd.DataFrame(ablation_results)
    df.fillna(0, inplace=True)
    # The column only exists when at least one prediction was correct;
    # create it so the per-group sums below cannot raise KeyError.
    if 'ex_correct' not in df.columns:
        df['ex_correct'] = 0.0
    df.info()
    print(df['n_inner'].value_counts())
    for n_inner in (0, 1, 2):
        print(df[df['n_inner'] == n_inner]['ex_correct'].sum())

    print('*'*50)
    print(df['n_hop'].value_counts())
    for n_hop in (1, 2, 3, 4):
        print(df[df['n_hop'] == n_hop]['ex_correct'].sum())

    df.to_csv(f'./ablation_results_{output_filename}.csv')
|
|
302 |
|
|
|
303 |
|
|
|
304 |
if __name__ == '__main__':
    # Script entry point: run the evaluation of the predicted queries.
    compare_sql_and_spqral_pred()