Diff of /evaluation_sparql.py [000000] .. [ab27bc]

Switch to side-by-side view

--- a
+++ b/evaluation_sparql.py
@@ -0,0 +1,306 @@
+import sys
+sys.path.append('.')
+sys.path.append('..')
+import json
+import os
+import re
+import pandas as pd
+from rdflib import Graph
+from sql2sparql import SQL2SPARQL, sparql_postprocessing, join_entity
+from mimicsql.evaluation.utils import query
+from build_mimicsparql_kg.build_complex_kg_from_mimicsqlstar_db import clean_text
+
+
+def split_triples(sparql):
+    try:
+        select_part, where_part = sparql.split(' where ')
+    except:
+        print(sparql)
+        select_part, where_part = sparql.split(' where ')[0], sparql.split(' where ')[-1]
+
+    where_part = where_part.replace('{','').replace('}','')
+    triple = [t.strip() for t in where_part.split('. ')]
+    return select_part, [t for t in triple if len(t) != 0]
+
+
+def none2zero(answer):
+    if answer is None:
+        return 0.0
+
+    if type(answer) != str:
+        return answer
+
+    if answer.lower() == 'none':
+        return 0.0
+
+    try:
+        answer = float(answer)
+    except:
+        pass
+
+    return answer
+
+
+def answer_normalization(answers):
+    if len(answers) == 0:
+        answers = [(0.0, )]
+    return [tuple([none2zero(val) for val in answer]) for answer in answers]
+
+
+def entity2value(entity):
+    match = re.findall('/[a-z_\d]+/[a-z\d]+', entity)
+    if len(match) > 0:
+        return re.sub('/[a-z_\d]+/', '', entity)
+    else:
+        return entity
+
+
+def replace_cond_val(sparql):
+    try:
+        where_part = re.findall('{[^{^}].*}', sparql)[0]
+    except Exception as e:
+        print(e)
+        return sparql
+
+    where_part = where_part.replace('{', '').replace('}', '').strip()
+
+    try:
+        ent_rel_cond = re.findall('\?[a-z_\d]+ </[a-z_\d]+> [^?][^.^]+', where_part)
+        for m in ent_rel_cond:
+            token = m.split()
+            ent, rel = token[0], token[1]
+            re_m = ' '.join([ent, rel, '<COND>'])
+            sparql = sparql.replace(m, re_m)
+
+        cond_rel_ent = re.findall('<[^?^ ]+> </[a-z_\d]+> \?[a-z_\d]+', where_part)
+        for m in cond_rel_ent:
+            cond, rel, ent = m.split()
+            re_m = ' '.join(['<COND>', rel, ent])
+            sparql = sparql.replace(m, re_m)
+
+        filter_cond = re.findall('filter\( \?[a-z_\d]+ [<=>]+ [^?]+ \)', where_part)
+        for m in filter_cond:
+            ft, var, op, cond, _ = m.split()
+            re_m = ' '.join([ft, var, op, '<COND>', _])
+            sparql = sparql.replace(m, re_m)
+
+    except Exception as e:
+        print(e)
+
+    return sparql
+
+
+def isequal(sql_answer, sparql_answer): # list of tuple
+    sql_answer = [row for row in sql_answer if 'None' not in row]
+    sql_answer = [tuple([clean_text(a.lower()) if type(a) == str else a for a in row]) for row in sql_answer]
+
+    sparql_answer = [tuple([entity2value(a) if type(a) == str else a for a in row]) for row in sparql_answer]
+
+    if set(sql_answer) == set(sparql_answer):
+        return True
+
+    sparql_answer = answer_normalization(sparql_answer)
+    sql_answer = answer_normalization(sql_answer)
+
+    if set(sql_answer) == set(sparql_answer):
+        return True
+
+    return False
+
+
+def check_no_cond_val(sparql):
+    cond = []
+    cond += re.findall('\^\^<http://', sparql) # value
+    cond += re.findall('</[a-z_\d]+/[\d]+>', sparql) # entity
+    cond += re.findall('"[a-z\d ]+"', sparql) # value
+    cond += re.findall('filter', sparql) # fiter
+    if len(cond) == 0:
+        return True
+    else:
+        return False
+
+
+def n_inner_join(x):
+    return len(re.findall('inner join', x))
+
+
+def compare_sql_and_spqral_pred():
+    datadir = '../TREQS/mimicsql_data/mimicsql_natural/'
+    filename = 'test.json'
+    outputdir = '../TREQS/evaluation/generated_sql/'
+    output_filename = 'output.json'
+
+    covertor = SQL2SPARQL()
+    print('LOAD output.json')
+    sparql_preds = []
+    sparql_golds = []
+    with open(os.path.join(outputdir, output_filename)) as json_file:
+        for line in json_file:
+            dic = json.loads(line)
+            sparql_preds.append(dic['sql_pred'])
+            sparql_golds.append(dic['sql_gold'])
+    print('DONE')
+
+    data = []
+    with open(os.path.join(datadir, filename)) as json_file:
+        for line in json_file:
+            data.append(json.loads(line))
+
+    df = pd.DataFrame(data)
+
+    print('LOAD DB ...')
+    db_file = './evaluation/mimic_db/mimic.db'
+    model = query(db_file)
+    print('DONE')
+
+    print('LOAD KG ...')
+    kg = Graph()
+    kg.parse('./evaluation/mimic_simple_kg.xml', format='xml', publicID='/')
+    print('DONE')
+
+    lf_permu_correct = 0
+    lf_permu_cond_correct = 0
+    cond_lf_correct = 0
+    lf_correct = 0
+    gold_correct = 0
+    pred_correct = 0
+    ablation_results = []
+    for i, sql in enumerate(df['sql']):
+        ablation_dic = {}
+        sql = sql.lower()
+
+        sql_answer = []
+        sparql_pred_answer = []
+        sparql_gold_answer = []
+
+        print("-" * 50)
+        print(i, sql)
+
+        ablation_dic['n_inner'] = n_inner_join(sql)
+        ablation_dic['n_hop'] = covertor.get_max_hop(sql)
+
+        sql_res = model.execute_sql(sql).fetchall()
+        for res in sql_res:
+            val = '|'
+            temp = []
+            for t in res:
+                val += str(t) + '|\t\t|'
+                temp.append(str(t))
+            print(val[:-1])
+            sql_answer.append(tuple(temp))
+        print()
+
+        sparql_pred = sparql_preds[i]
+        sparql_gold = sparql_golds[i]
+
+        sparql_pred = sparql_postprocessing(sparql_pred)
+        sparql_pred = join_entity(sparql_pred)
+        sparql_gold = sparql_postprocessing(sparql_gold)
+        sparql_gold = join_entity(sparql_gold)
+
+
+        if sparql_pred.split() == sparql_gold.split():
+            lf_correct += 1
+            ablation_dic['lf_correct'] = 1
+
+        print(sparql_gold)
+        print(sparql_pred)
+
+        cond_sp = replace_cond_val(sparql_pred)
+        cond_sg = replace_cond_val(sparql_gold)
+
+        if cond_sp.split() == cond_sg.split():
+            cond_lf_correct += 1
+            ablation_dic['cond_lf_correct'] = 1
+            print(cond_sg)
+            print(cond_sp)
+
+        cond_sps, cond_spw = split_triples(cond_sp)
+        cond_sgs, cond_sgw = split_triples(cond_sg)
+
+        sps, spw = split_triples(sparql_pred)
+        sgs, sgw = split_triples(sparql_gold)
+
+        if cond_sps.split() == cond_sgs.split() and set(cond_spw) == set(cond_sgw):
+            lf_permu_cond_correct += 1
+
+        if sps.split() == sgs.split() and set(spw) == set(sgw):
+            lf_permu_correct += 1
+
+        print(i, sparql_gold)
+        sparql_res = kg.query(sparql_gold)
+        for res in sparql_res:
+            val = '|'
+            temp = []
+            for t in res:
+                val += str(t.toPython()) + '|\t\t|'
+                temp.append(str(t.toPython()))
+            print(val[:-1])
+            sparql_gold_answer.append(tuple(temp))
+        print(sql_answer, sparql_gold_answer, isequal(sql_answer, sparql_gold_answer))
+
+        if isequal(sql_answer, sparql_gold_answer):
+            gold_correct += 1
+        else:
+            print('sql gold false')
+
+        print(i, sparql_pred)
+
+        if check_no_cond_val(sparql_pred):
+            print(f'[NO COND]: {sparql_pred}')
+            print()
+            ablation_results.append(ablation_dic)
+            continue
+
+        try:
+            sparql_res = kg.query(sparql_pred)
+            for res in sparql_res:
+                val = '|'
+                temp = []
+                for t in res:
+                    val += str(t.toPython()) + '|\t\t|'
+                    temp.append(str(t.toPython()))
+                print(val[:-1])
+                sparql_pred_answer.append(tuple(temp))
+            print(sql_answer, sparql_pred_answer, isequal(sql_answer, sparql_pred_answer))
+
+            if isequal(sql_answer, sparql_pred_answer):
+                ablation_dic['ex_correct'] = 1
+                pred_correct += 1
+
+        except:
+            print(sparql_pred)
+            print("syntax error")
+
+        ablation_results.append(ablation_dic)
+
+        print()
+
+    print(f'[SQL2SPARQL] filenmae: {filename}, Answer Accuracy: {gold_correct / len(data):.4f}')
+    print(f'[SQL2SPARQL] filenmae: {output_filename}, Answer Accuracy: {pred_correct / len(data):.4f}')
+    print(f'[SQL2SPARQL] filenmae: {output_filename}, Logic Form Accuracy: {lf_correct / len(data):.4f}')
+    print(f'[SQL2SPARQL] filenmae: {output_filename}, Cond Invariant Logic Form Accuracy: {cond_lf_correct / len(data):.4f}')
+    print(f'[SQL2SPARQL] filenmae: {output_filename}, Logic Form Accuracy*: {lf_permu_correct / len(data):.4f}')
+    print(f'[SQL2SPARQL] filenmae: {output_filename}, Cond Invariant Logic Form Accuracy*: {lf_permu_cond_correct / len(data):.4f}')
+
+    df = pd.DataFrame(ablation_results)
+    df.fillna(0, inplace=True)
+    df.info()
+    print(df['n_inner'].value_counts())
+    print(df[df['n_inner'] == 0]['ex_correct'].sum())
+    print(df[df['n_inner'] == 1]['ex_correct'].sum())
+    print(df[df['n_inner'] == 2]['ex_correct'].sum())
+
+    print('*'*50)
+    print(df['n_hop'].value_counts())
+    print(df[df['n_hop'] == 1]['ex_correct'].sum())
+    print(df[df['n_hop'] == 2]['ex_correct'].sum())
+    print(df[df['n_hop'] == 3]['ex_correct'].sum())
+    print(df[df['n_hop'] == 4]['ex_correct'].sum())
+
+    df.to_csv(f'./ablation_results_{output_filename}.csv')
+
+
+if __name__ == '__main__':
+    #compare_sql_and_spqral()
+    compare_sql_and_spqral_pred()