# convert_sql2sparql_dataset.py
import json
import os
import pandas as pd
import argparse
from rdflib import Graph
from collections import Counter

from mimicsql.evaluation.utils import query
from sql2sparql import SQL2SPARQL, split_entity
from evaluation_sparql import isequal
from build_mimicsparql_kg.build_complex_kg_from_mimicsqlstar_db import clean_text


def sparql_tokenize(sparql):
    sparql = split_entity(sparql)
    sparql = ' ^^'.join(sparql.split('^^'))
    sparql_tok = ' '.join(sparql.split(' '))
    return sparql_tok.split()
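
# Note on sparql_tokenize: the ' ^^'.join(sparql.split('^^')) step puts a space
# in front of every '^^' datatype marker, so a typed literal such as "3"^^<int>
# splits into the tokens '"3"' and '^^<int>' instead of one fused token.
# (Illustrative only; the final token list also depends on what split_entity
# does to the query string.)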


def convert_sql2sparql(complex=True, filename='train.json', dataset_type='natural', execution=True):
    if complex:
        savedir = f'./dataset/mimic_sparqlstar/{dataset_type}/'
        datadir = f'./dataset/mimicsqlstar/{dataset_type}/'

        sql2sparql = SQL2SPARQL(complex=complex, root='subject_id')

        if execution:
            print('LOAD ... mimicsqlstar.db')
            db_file = './build_mimicsqlstar_db/mimicsqlstar.db'
            model = query(db_file)
            print('DONE')

            print('LOAD KG ... mimic_kg')
            kg = Graph()
            kg.parse('./build_mimicsparql_kg/mimic_sparqlstar_kg.xml', format='xml', publicID='/')
            print('DONE')

    else:
        print('This dataset is simple')
        savedir = f'./dataset/mimic_sparql/{dataset_type}/'
        datadir = f'./dataset/mimicsql/mimicsql_{dataset_type}/'

        sql2sparql = SQL2SPARQL(complex=complex, root='hadm_id')

        if execution:
            print('LOAD ... mimic.db')
            db_file = './mimicsql/evaluation/mimic_db/mimic.db'
            model = query(db_file)
            print('DONE')

            print('LOAD KG ... mimic_sparql_kg')
            kg = Graph()
            kg.parse('./build_mimicsparql_kg/mimic_sparql_kg.xml', format='xml', publicID='/')
            print('DONE')
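
    # At this point: savedir/datadir point at the complex (MIMICSQL*) or simple
    # (MIMICSQL) dataset, and, when execution is enabled, `model` wraps the
    # matching SQLite database and `kg` holds the matching RDF knowledge graph.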

    data = []
    with open(os.path.join(datadir, filename)) as json_file:
        for line in json_file:
            data.append(json.loads(line))

    df = pd.DataFrame(data)
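
    # Each input line is one JSON record; the fields used downstream are 'sql'
    # (the MIMICSQL query, overwritten below with its SPARQL translation) and,
    # in build_vocab, 'question_refine_tok' and 'sql_tok'.
    # Illustrative record (hypothetical values):
    #   {"question_refine_tok": ["how", "many", "female", "patients", "..."],
    #    "sql": "select count ( distinct demographic.\"subject_id\" ) from demographic where demographic.\"gender\" = \"f\"", ...}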

    correct = 0
    sparqls = []
    for i, sql in enumerate(df['sql']):
        sql = sql.lower()
        sql_answer = []
        sparql_answer = []

        print("-" * 50)
        print(i, sql)

        if execution:
            sql_res = model.execute_sql(sql).fetchall()
            for res in sql_res:
                val = '|'
                temp = []
                for t in res:
                    val += str(t) + '|\t\t|'
                    temp.append(str(t))
                print(val[:-1])
                sql_answer.append(tuple(temp))
        print()

        sparql = sql2sparql.convert(sql)
        sparql = clean_text(sparql)
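
        # Illustrative conversion (hypothetical, shape only; the actual predicate
        # names and query form are whatever SQL2SPARQL and clean_text produce):
        #   SQL:    select count ( distinct demographic."hadm_id" ) from demographic where demographic."gender" = "f"
        #   SPARQL: select ( count ( distinct ?hadm_id ) as ?agg ) where { ?hadm_id </gender> "f". }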

        print(i, sparql)
        if execution:
            sparql_res = kg.query(sparql)
            for res in sparql_res:
                val = '|'
                temp = []
                for t in res:
                    val += str(t.toPython()) + '|\t\t|'
                    temp.append(str(t.toPython()))
                print(val[:-1])
                sparql_answer.append(tuple(temp))

            print(sql_answer, sparql_answer, isequal(sql_answer, sparql_answer))
            if isequal(sql_answer, sparql_answer):
                correct += 1
            else:
                print("[incorrect]")
        print()

        sparql = sparql.lower()
        sparql_tok = sparql_tokenize(sparql)
        sparqls.append({'sql': sparql, 'sql_tok': sparql_tok})

    if execution:
        print(f'[SQL2SPARQL] filename: {filename}, Answer Accuracy: {correct/len(df):.4f}')

    sparql_data = []
    for d, sparql_d in zip(data, sparqls):
        d['sql'] = sparql_d['sql']
        d['sql_tok'] = sparql_d['sql_tok']
        sparql_data.append(d)

    save_filename = os.path.join(savedir, filename)
    with open(save_filename, 'w') as json_file:
        for dic in sparql_data:
            json.dump(dic, json_file)
            json_file.write('\n')

    print(f"Write to {save_filename}")


def build_vocab(complex=True, dataset_type='natural'):
    if complex:
        datadir = f'./dataset/mimic_sparqlstar/{dataset_type}'
    else:
        datadir = f'./dataset/mimic_sparql/{dataset_type}'

    filenames = ['train.json']
    counter = Counter()
    for filename in filenames:
        with open(os.path.join(datadir, filename)) as json_file:
            for line in json_file:
                dic = json.loads(line)
                counter.update(dic['question_refine_tok'])
                counter.update(dic['sql_tok'])

    with open(os.path.join(datadir, 'vocab'), 'w') as f:
        for k, v in counter.most_common():

            if len(k.split()) == 0:
                continue

            if k == ' ':
                continue
            f.write(f'{k} {v}\n')

    print(f'vocab built: {len(counter)}')
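
# The vocab file written above has one "<token> <count>" pair per line, most
# frequent token first, e.g. (hypothetical counts):
#   where 9124
#   select 9101
#   ?hadm_id 4382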


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='mimicsql to mimic-sparql')
    parser.add_argument('--complex', default=False, type=lambda x: (str(x).lower() == 'true'))
    parser.add_argument('--dataset_type', type=str, default='natural', choices=['natural', 'template'])
    parser.add_argument('--execution', default=False, type=lambda x: (str(x).lower() == 'true'))
    args = parser.parse_args()

    execution = args.execution
    dataset_type = args.dataset_type
    complex = args.complex

    filenames = ['train.json', 'dev.json', 'test.json']
    for filename in filenames:
        convert_sql2sparql(complex=complex, filename=filename, dataset_type=dataset_type, execution=execution)
    build_vocab(complex=complex, dataset_type=dataset_type)
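
# Example invocation (flags as defined above; the boolean flags are parsed by
# the lambda converters, so pass the literal strings 'True'/'False'):
#   python convert_sql2sparql_dataset.py --complex True --dataset_type natural --execution True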