# convert_mimicsql2sqlstar_dataset.py
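"""Convert the MIMICSQL dataset splits into MIMICSQL* form.

For each split, every SQL query is translated to the MIMICSQL* schema with
SQL2SQL. With --execution, the original and translated queries are run
against their respective databases and the answer sets are compared. A token
vocabulary is then built from the converted training split.
"""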
|
|
import sys

sys.path.append('..')  # make modules in the parent directory importable

import json
import os
import argparse
import pandas as pd
from collections import Counter

from mimicsql.evaluation.utils import query
from sql2sql import SQL2SQL


def tokenize_sql(q):
    """Break out '.' and ',' as separate tokens, then split on whitespace."""
    q = ' . '.join(q.split('.'))
    q = ' , '.join(q.split(','))
    return q.split()


def convert_sql2sparql(filename='train.json', dataset_type='natural', execution=True):
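    """Convert one MIMICSQL split file to MIMICSQL* SQL.

    Despite the name, this translates SQL to SQL (not SPARQL). When
    `execution` is True, both the original and the translated query are
    executed and their answer sets compared to verify the translation.
    """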
|
|
    savedir = f'./dataset/mimicsqlstar/{dataset_type}/'
    datadir = f'./dataset/mimicsql/mimicsql_{dataset_type}/'
    data = []
    with open(os.path.join(datadir, filename)) as json_file:
        for line in json_file:
            data.append(json.loads(line))

    df = pd.DataFrame(data)

    if execution:
        print(f'LOAD original mimic_db ... {len(df)}')
        db_file = './mimicsql/evaluation/mimic_db/mimic.db'
        orimodel = query(db_file)
        print('DONE')

        print('LOAD mimicsqlstar.db ...')
        db_file = './build_mimicsqlstar_db/mimicsqlstar.db'
        newmodel = query(db_file)
        print('DONE')

    sql2sql_convertor = SQL2SQL()

    correct = 0
    newsqls = []
    for i, sql in enumerate(df['sql']):
        sql_answer = []
        newsql_answer = []

        print("-" * 50)
        print(i, sql)

        # Execute the original query against the original mimic.db.
        if execution:
            sql_res = orimodel.execute_sql(sql.lower()).fetchall()
            for res in sql_res:
                val = '|'
                temp = []
                for t in res:
                    val += str(t) + '|\t\t|'
                    temp.append(str(t))
                print(val[:-1])
                sql_answer.append(tuple(temp))
            print()

        # Translate the query to the MIMICSQL* schema.
        new_sql = sql2sql_convertor.translate(sql)

        print(i, new_sql)
        # Execute the translated query against mimicsqlstar.db.
        if execution:
            newsql_res = newmodel.execute_sql(new_sql.lower()).fetchall()
            for res in newsql_res:
                val = '|'
                temp = []
                for t in res:
                    val += str(t) + '|\t\t|'
                    temp.append(str(t))
                print(val[:-1])
                newsql_answer.append(tuple(temp))

        # Set comparison ignores row order and duplicate rows.
        print(sql_answer, newsql_answer, set(sql_answer) == set(newsql_answer))
        if set(sql_answer) == set(newsql_answer):
            correct += 1
        else:
            print("[incorrect]")
        print()

        new_sql = new_sql.lower()
        newsql_tok = tokenize_sql(new_sql)
        newsqls.append({'sql': new_sql, 'sql_tok': newsql_tok})

    if execution:
        print(f'[SQL2SQL] filename: {filename}, Answer Accuracy: {correct/len(df):.4f}')

    # Write the converted split, one JSON object per line.
    sql_data = []
    for d, sql_d in zip(data, newsqls):
        d['sql'] = sql_d['sql']
        d['sql_tok'] = sql_d['sql_tok']
        sql_data.append(d)

    os.makedirs(savedir, exist_ok=True)  # make sure the output directory exists
    save_filename = os.path.join(savedir, filename)
    with open(save_filename, 'w') as json_file:
        for dic in sql_data:
            json.dump(dic, json_file)
            json_file.write('\n')

    print(f"Write to {save_filename}")


def build_vocab(dataset_type='natural'):
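    """Count tokens in the converted train split and write a `token count` vocab file."""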
|
|
    datadir = f'./dataset/mimicsqlstar/{dataset_type}/'
    filenames = ['train.json']
    counter = Counter()
    for filename in filenames:
        with open(os.path.join(datadir, filename)) as json_file:
            for line in json_file:
                dic = json.loads(line)
                counter.update(dic['question_refine_tok'])
                counter.update(dic['sql_tok'])

    with open(os.path.join(datadir, 'vocab'), 'w') as f:
        for k, v in counter.most_common():
            # Skip empty or whitespace-only tokens.
            if len(k.split()) == 0:
                continue
            if k == ' ':
                continue
            f.write(f'{k} {v}\n')

    print(f'vocab built: {len(counter)}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='mimicsql to mimicsql*')
    parser.add_argument('--dataset_type', type=str, default='natural', choices=['natural', 'template'])
    # argparse does not parse booleans natively, so accept the literal string 'true'.
    parser.add_argument('--execution', default=False, type=lambda x: (str(x).lower() == 'true'))
    args = parser.parse_args()

    execution = args.execution
    dataset_type = args.dataset_type

    filenames = ['train.json', 'dev.json', 'test.json']
    for filename in filenames:
        convert_sql2sparql(filename=filename, dataset_type=dataset_type, execution=execution)
    build_vocab(dataset_type=dataset_type)
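
# Example invocation (flags as defined above):
#   python convert_mimicsql2sqlstar_dataset.py --dataset_type natural --execution True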