Switch to unified view

a b/convert_mimicsql2sqlstar_dataset.py
1
import sys
2
sys.path.append('..')
3
import json
4
import os
5
import argparse
6
import pandas as pd
7
from collections import Counter
8
9
from mimicsql.evaluation.utils import query
10
from sql2sql import SQL2SQL
11
12
13
def tokenize_sql(q):
    """Tokenize a SQL string, making '.' and ',' standalone tokens."""
    spaced = q.replace('.', ' . ').replace(',', ' , ')
    return spaced.split()
17
18
19
def _exec_and_print(model, sql):
    """Execute *sql* (lowercased) against *model*, echo each row, and return rows.

    Each row is printed in the original '|val|\\t\\t|val...' format and collected
    as a tuple of stringified cell values for set-based answer comparison.
    """
    answers = []
    for res in model.execute_sql(sql.lower()).fetchall():
        row = tuple(str(t) for t in res)
        # Reproduce the original display format: leading '|', cells separated
        # by '|\t\t|', with the very last '|' trimmed off.
        print(('|' + ''.join(s + '|\t\t|' for s in row))[:-1])
        answers.append(row)
    return answers


def convert_sql2sparql(filename='train.json', dataset_type='natural', execution=True):
    """Convert mimicsql SQL queries in *filename* to the mimicsql* schema.

    Reads line-delimited JSON from ./dataset/mimicsql/mimicsql_<dataset_type>/,
    translates each record's 'sql' field with SQL2SQL, optionally executes both
    the original and translated queries against their respective SQLite
    databases to verify answer equivalence, and writes the converted records
    (with retokenized SQL) to ./dataset/mimicsqlstar/<dataset_type>/<filename>.

    Args:
        filename: JSON-lines split to convert (e.g. 'train.json').
        dataset_type: 'natural' or 'template' question variant.
        execution: if True, run both SQL forms and report answer accuracy.
    """
    savedir = f'./dataset/mimicsqlstar/{dataset_type}/'
    datadir = f'./dataset/mimicsql/mimicsql_{dataset_type}/'
    # FIX: the output directory was never created, so the final write
    # crashed on a fresh checkout.
    os.makedirs(savedir, exist_ok=True)

    data = []
    with open(os.path.join(datadir, filename)) as json_file:
        for line in json_file:
            data.append(json.loads(line))

    df = pd.DataFrame(data)

    orimodel = newmodel = None
    if execution:
        print(f'LOAD original mimic_db ... {len(df)}')
        db_file = './mimicsql/evaluation/mimic_db/mimic.db'
        orimodel = query(db_file)
        print('DONE')

        print('LOAD mimicqlstar.db ...')
        db_file = './build_mimicsqlstar_db/mimicsqlstar.db'
        newmodel = query(db_file)
        print('DONE')

    sql2sql_convertor = SQL2SQL()

    correct = 0
    newsqls = []
    for i, sql in enumerate(df['sql']):
        print("-" * 50)
        print(i, sql)

        sql_answer = _exec_and_print(orimodel, sql) if execution else []
        print()

        new_sql = sql2sql_convertor.translate(sql)
        print(i, new_sql)

        if execution:
            newsql_answer = _exec_and_print(newmodel, new_sql)
            print(sql_answer, newsql_answer, set(sql_answer) == set(newsql_answer))
            if set(sql_answer) == set(newsql_answer):
                correct += 1
            else:
                print("[incorrect]")
        print()

        new_sql = new_sql.lower()
        newsql_tok = tokenize_sql(new_sql)
        newsqls.append({'sql': new_sql, 'sql_tok': newsql_tok})

    if execution:
        # FIX: the original printed the literal '(unknown)' (typo'd as
        # 'filenmae') instead of the actual filename being converted.
        print(f'[SQL2SQL] filename: {filename}, Answer Accuracy: {correct/len(df):.4f}')

    # Merge the converted SQL back into the original records.
    sql_data = []
    for d, sql_d in zip(data, newsqls):
        d['sql'] = sql_d['sql']
        d['sql_tok'] = sql_d['sql_tok']
        sql_data.append(d)

    save_filename = os.path.join(savedir, filename)
    with open(save_filename, 'w') as json_file:
        for dic in sql_data:
            json.dump(dic, json_file)
            json_file.write('\n')

    print(f"Write to {save_filename}")
104
105
106
def build_vocab(dataset_type='natural'):
    """Build a 'token count' vocab file from the converted training split.

    Counts question tokens ('question_refine_tok') and SQL tokens ('sql_tok')
    from ./dataset/mimicsqlstar/<dataset_type>/train.json and writes them,
    most frequent first, one 'token count' pair per line, to <datadir>/vocab.

    Args:
        dataset_type: 'natural' or 'template' question variant.
    """
    datadir = f'./dataset/mimicsqlstar/{dataset_type}/'
    filenames = ['train.json']
    counter = Counter()
    for filename in filenames:
        with open(os.path.join(datadir, filename)) as json_file:
            for line in json_file:
                dic = json.loads(line)
                counter.update(dic['question_refine_tok'])
                counter.update(dic['sql_tok'])

    with open(os.path.join(datadir, 'vocab'), 'w') as f:
        for k, v in counter.most_common():
            # Skip empty/whitespace-only tokens. (This single test subsumes
            # the original's redundant `k == ' '` check, which was
            # unreachable after the split() test.)
            if not k.split():
                continue
            f.write(f'{k} {v}\n')

    # FIX: corrected 'builded' -> 'built' in the status message.
    print(f'vocab built: {len(counter)}')
128
129
130
if __name__ == '__main__':
    # CLI entry point: convert every split to mimicsql*, then build the vocab.
    parser = argparse.ArgumentParser(description='mimicsql to mimicsql*')
    parser.add_argument('--dataset_type', type=str, default='natural', choices=['natural', 'template'])
    parser.add_argument('--execution', default=False, type=lambda x: (str(x).lower() == 'true'))
    args = parser.parse_args()

    for split in ('train.json', 'dev.json', 'test.json'):
        convert_sql2sparql(filename=split, dataset_type=args.dataset_type, execution=args.execution)
    build_vocab(dataset_type=args.dataset_type)