Diff of /evaluation_sparql.py [000000] .. [ab27bc]

Switch to unified view

a b/evaluation_sparql.py
1
import sys
2
sys.path.append('.')
3
sys.path.append('..')
4
import json
5
import os
6
import re
7
import pandas as pd
8
from rdflib import Graph
9
from sql2sparql import SQL2SPARQL, sparql_postprocessing, join_entity
10
from mimicsql.evaluation.utils import query
11
from build_mimicsparql_kg.build_complex_kg_from_mimicsqlstar_db import clean_text
12
13
14
def split_triples(sparql):
15
    try:
16
        select_part, where_part = sparql.split(' where ')
17
    except:
18
        print(sparql)
19
        select_part, where_part = sparql.split(' where ')[0], sparql.split(' where ')[-1]
20
21
    where_part = where_part.replace('{','').replace('}','')
22
    triple = [t.strip() for t in where_part.split('. ')]
23
    return select_part, [t for t in triple if len(t) != 0]
24
25
26
def none2zero(answer):
27
    if answer is None:
28
        return 0.0
29
30
    if type(answer) != str:
31
        return answer
32
33
    if answer.lower() == 'none':
34
        return 0.0
35
36
    try:
37
        answer = float(answer)
38
    except:
39
        pass
40
41
    return answer
42
43
44
def answer_normalization(answers):
45
    if len(answers) == 0:
46
        answers = [(0.0, )]
47
    return [tuple([none2zero(val) for val in answer]) for answer in answers]
48
49
50
def entity2value(entity):
51
    match = re.findall('/[a-z_\d]+/[a-z\d]+', entity)
52
    if len(match) > 0:
53
        return re.sub('/[a-z_\d]+/', '', entity)
54
    else:
55
        return entity
56
57
58
def replace_cond_val(sparql):
59
    try:
60
        where_part = re.findall('{[^{^}].*}', sparql)[0]
61
    except Exception as e:
62
        print(e)
63
        return sparql
64
65
    where_part = where_part.replace('{', '').replace('}', '').strip()
66
67
    try:
68
        ent_rel_cond = re.findall('\?[a-z_\d]+ </[a-z_\d]+> [^?][^.^]+', where_part)
69
        for m in ent_rel_cond:
70
            token = m.split()
71
            ent, rel = token[0], token[1]
72
            re_m = ' '.join([ent, rel, '<COND>'])
73
            sparql = sparql.replace(m, re_m)
74
75
        cond_rel_ent = re.findall('<[^?^ ]+> </[a-z_\d]+> \?[a-z_\d]+', where_part)
76
        for m in cond_rel_ent:
77
            cond, rel, ent = m.split()
78
            re_m = ' '.join(['<COND>', rel, ent])
79
            sparql = sparql.replace(m, re_m)
80
81
        filter_cond = re.findall('filter\( \?[a-z_\d]+ [<=>]+ [^?]+ \)', where_part)
82
        for m in filter_cond:
83
            ft, var, op, cond, _ = m.split()
84
            re_m = ' '.join([ft, var, op, '<COND>', _])
85
            sparql = sparql.replace(m, re_m)
86
87
    except Exception as e:
88
        print(e)
89
90
    return sparql
91
92
93
def isequal(sql_answer, sparql_answer): # list of tuple
94
    sql_answer = [row for row in sql_answer if 'None' not in row]
95
    sql_answer = [tuple([clean_text(a.lower()) if type(a) == str else a for a in row]) for row in sql_answer]
96
97
    sparql_answer = [tuple([entity2value(a) if type(a) == str else a for a in row]) for row in sparql_answer]
98
99
    if set(sql_answer) == set(sparql_answer):
100
        return True
101
102
    sparql_answer = answer_normalization(sparql_answer)
103
    sql_answer = answer_normalization(sql_answer)
104
105
    if set(sql_answer) == set(sparql_answer):
106
        return True
107
108
    return False
109
110
111
def check_no_cond_val(sparql):
112
    cond = []
113
    cond += re.findall('\^\^<http://', sparql) # value
114
    cond += re.findall('</[a-z_\d]+/[\d]+>', sparql) # entity
115
    cond += re.findall('"[a-z\d ]+"', sparql) # value
116
    cond += re.findall('filter', sparql) # fiter
117
    if len(cond) == 0:
118
        return True
119
    else:
120
        return False
121
122
123
def n_inner_join(x):
124
    return len(re.findall('inner join', x))
125
126
127
def compare_sql_and_spqral_pred():
128
    datadir = '../TREQS/mimicsql_data/mimicsql_natural/'
129
    filename = 'test.json'
130
    outputdir = '../TREQS/evaluation/generated_sql/'
131
    output_filename = 'output.json'
132
133
    covertor = SQL2SPARQL()
134
    print('LOAD output.json')
135
    sparql_preds = []
136
    sparql_golds = []
137
    with open(os.path.join(outputdir, output_filename)) as json_file:
138
        for line in json_file:
139
            dic = json.loads(line)
140
            sparql_preds.append(dic['sql_pred'])
141
            sparql_golds.append(dic['sql_gold'])
142
    print('DONE')
143
144
    data = []
145
    with open(os.path.join(datadir, filename)) as json_file:
146
        for line in json_file:
147
            data.append(json.loads(line))
148
149
    df = pd.DataFrame(data)
150
151
    print('LOAD DB ...')
152
    db_file = './evaluation/mimic_db/mimic.db'
153
    model = query(db_file)
154
    print('DONE')
155
156
    print('LOAD KG ...')
157
    kg = Graph()
158
    kg.parse('./evaluation/mimic_simple_kg.xml', format='xml', publicID='/')
159
    print('DONE')
160
161
    lf_permu_correct = 0
162
    lf_permu_cond_correct = 0
163
    cond_lf_correct = 0
164
    lf_correct = 0
165
    gold_correct = 0
166
    pred_correct = 0
167
    ablation_results = []
168
    for i, sql in enumerate(df['sql']):
169
        ablation_dic = {}
170
        sql = sql.lower()
171
172
        sql_answer = []
173
        sparql_pred_answer = []
174
        sparql_gold_answer = []
175
176
        print("-" * 50)
177
        print(i, sql)
178
179
        ablation_dic['n_inner'] = n_inner_join(sql)
180
        ablation_dic['n_hop'] = covertor.get_max_hop(sql)
181
182
        sql_res = model.execute_sql(sql).fetchall()
183
        for res in sql_res:
184
            val = '|'
185
            temp = []
186
            for t in res:
187
                val += str(t) + '|\t\t|'
188
                temp.append(str(t))
189
            print(val[:-1])
190
            sql_answer.append(tuple(temp))
191
        print()
192
193
        sparql_pred = sparql_preds[i]
194
        sparql_gold = sparql_golds[i]
195
196
        sparql_pred = sparql_postprocessing(sparql_pred)
197
        sparql_pred = join_entity(sparql_pred)
198
        sparql_gold = sparql_postprocessing(sparql_gold)
199
        sparql_gold = join_entity(sparql_gold)
200
201
202
        if sparql_pred.split() == sparql_gold.split():
203
            lf_correct += 1
204
            ablation_dic['lf_correct'] = 1
205
206
        print(sparql_gold)
207
        print(sparql_pred)
208
209
        cond_sp = replace_cond_val(sparql_pred)
210
        cond_sg = replace_cond_val(sparql_gold)
211
212
        if cond_sp.split() == cond_sg.split():
213
            cond_lf_correct += 1
214
            ablation_dic['cond_lf_correct'] = 1
215
            print(cond_sg)
216
            print(cond_sp)
217
218
        cond_sps, cond_spw = split_triples(cond_sp)
219
        cond_sgs, cond_sgw = split_triples(cond_sg)
220
221
        sps, spw = split_triples(sparql_pred)
222
        sgs, sgw = split_triples(sparql_gold)
223
224
        if cond_sps.split() == cond_sgs.split() and set(cond_spw) == set(cond_sgw):
225
            lf_permu_cond_correct += 1
226
227
        if sps.split() == sgs.split() and set(spw) == set(sgw):
228
            lf_permu_correct += 1
229
230
        print(i, sparql_gold)
231
        sparql_res = kg.query(sparql_gold)
232
        for res in sparql_res:
233
            val = '|'
234
            temp = []
235
            for t in res:
236
                val += str(t.toPython()) + '|\t\t|'
237
                temp.append(str(t.toPython()))
238
            print(val[:-1])
239
            sparql_gold_answer.append(tuple(temp))
240
        print(sql_answer, sparql_gold_answer, isequal(sql_answer, sparql_gold_answer))
241
242
        if isequal(sql_answer, sparql_gold_answer):
243
            gold_correct += 1
244
        else:
245
            print('sql gold false')
246
247
        print(i, sparql_pred)
248
249
        if check_no_cond_val(sparql_pred):
250
            print(f'[NO COND]: {sparql_pred}')
251
            print()
252
            ablation_results.append(ablation_dic)
253
            continue
254
255
        try:
256
            sparql_res = kg.query(sparql_pred)
257
            for res in sparql_res:
258
                val = '|'
259
                temp = []
260
                for t in res:
261
                    val += str(t.toPython()) + '|\t\t|'
262
                    temp.append(str(t.toPython()))
263
                print(val[:-1])
264
                sparql_pred_answer.append(tuple(temp))
265
            print(sql_answer, sparql_pred_answer, isequal(sql_answer, sparql_pred_answer))
266
267
            if isequal(sql_answer, sparql_pred_answer):
268
                ablation_dic['ex_correct'] = 1
269
                pred_correct += 1
270
271
        except:
272
            print(sparql_pred)
273
            print("syntax error")
274
275
        ablation_results.append(ablation_dic)
276
277
        print()
278
279
    print(f'[SQL2SPARQL] filenmae: {filename}, Answer Accuracy: {gold_correct / len(data):.4f}')
280
    print(f'[SQL2SPARQL] filenmae: {output_filename}, Answer Accuracy: {pred_correct / len(data):.4f}')
281
    print(f'[SQL2SPARQL] filenmae: {output_filename}, Logic Form Accuracy: {lf_correct / len(data):.4f}')
282
    print(f'[SQL2SPARQL] filenmae: {output_filename}, Cond Invariant Logic Form Accuracy: {cond_lf_correct / len(data):.4f}')
283
    print(f'[SQL2SPARQL] filenmae: {output_filename}, Logic Form Accuracy*: {lf_permu_correct / len(data):.4f}')
284
    print(f'[SQL2SPARQL] filenmae: {output_filename}, Cond Invariant Logic Form Accuracy*: {lf_permu_cond_correct / len(data):.4f}')
285
286
    df = pd.DataFrame(ablation_results)
287
    df.fillna(0, inplace=True)
288
    df.info()
289
    print(df['n_inner'].value_counts())
290
    print(df[df['n_inner'] == 0]['ex_correct'].sum())
291
    print(df[df['n_inner'] == 1]['ex_correct'].sum())
292
    print(df[df['n_inner'] == 2]['ex_correct'].sum())
293
294
    print('*'*50)
295
    print(df['n_hop'].value_counts())
296
    print(df[df['n_hop'] == 1]['ex_correct'].sum())
297
    print(df[df['n_hop'] == 2]['ex_correct'].sum())
298
    print(df[df['n_hop'] == 3]['ex_correct'].sum())
299
    print(df[df['n_hop'] == 4]['ex_correct'].sum())
300
301
    df.to_csv(f'./ablation_results_{output_filename}.csv')
302
303
304
if __name__ == '__main__':
305
    #compare_sql_and_spqral()
306
    compare_sql_and_spqral_pred()