"""Convert the MIMICSQL dataset into the MIMICSQL* (mimicsqlstar) schema.

For each split (train/dev/test) this script translates every SQL query with
SQL2SQL, optionally verifies each translation by executing the original and
translated queries against their respective SQLite databases and comparing
result sets, writes the converted split to ./dataset/mimicsqlstar/, and
finally builds a token vocabulary from the training split.
"""
import sys
sys.path.append('..')
import json
import os
import argparse
import pandas as pd
from collections import Counter

from mimicsql.evaluation.utils import query
from sql2sql import SQL2SQL


def tokenize_sql(q):
    """Tokenize a SQL string, treating '.' and ',' as standalone tokens."""
    q = ' . '.join(q.split('.'))
    q = ' , '.join(q.split(','))
    return q.split()


def _collect_answers(rows):
    """Pretty-print fetched DB rows and return them as tuples of strings.

    Each row is printed pipe-delimited (matching the original log format);
    the stringified row tuples are returned for set-based answer comparison.
    """
    answers = []
    for res in rows:
        val = '|'
        temp = []
        for t in res:
            val += str(t) + '|\t\t|'
            temp.append(str(t))
        print(val[:-1])
        answers.append(tuple(temp))
    return answers


def convert_sql2sparql(filename='train.json', dataset_type='natural', execution=True):
    """Translate one MIMICSQL split into MIMICSQL* SQL and save it.

    Args:
        filename: split file name (e.g. 'train.json') inside the MIMICSQL
            data directory; the converted split is written under the same
            name in the mimicsqlstar directory.
        dataset_type: MIMICSQL variant, 'natural' or 'template'.
        execution: if True, execute both the original and the translated
            query against their databases and report answer accuracy.
    """
    savedir = f'./dataset/mimicsqlstar/{dataset_type}/'
    datadir = f'./dataset/mimicsql/mimicsql_{dataset_type}/'
    # FIX: the output directory was never created, so writing the converted
    # split below failed on a fresh checkout.
    os.makedirs(savedir, exist_ok=True)

    # Each line of the split file is one standalone JSON record.
    data = []
    with open(os.path.join(datadir, filename)) as json_file:
        for line in json_file:
            data.append(json.loads(line))

    df = pd.DataFrame(data)

    if execution:
        print(f'LOAD original mimic_db ... {len(df)}')
        db_file = './mimicsql/evaluation/mimic_db/mimic.db'
        orimodel = query(db_file)
        print('DONE')

        print('LOAD mimicqlstar.db ...')
        db_file = './build_mimicsqlstar_db/mimicsqlstar.db'
        newmodel = query(db_file)
        print('DONE')

    sql2sql_convertor = SQL2SQL()

    correct = 0
    newsqls = []
    for i, sql in enumerate(df['sql']):
        sql_answer = []
        newsql_answer = []

        print("-" * 50)
        print(i, sql)

        if execution:
            sql_res = orimodel.execute_sql(sql.lower()).fetchall()
            sql_answer = _collect_answers(sql_res)
            print()

        new_sql = sql2sql_convertor.translate(sql)

        print(i, new_sql)
        if execution:
            newsql_res = newmodel.execute_sql(new_sql.lower()).fetchall()
            newsql_answer = _collect_answers(newsql_res)

            # Order-insensitive comparison of result sets.
            print(sql_answer, newsql_answer, set(sql_answer) == set(newsql_answer))
            if set(sql_answer) == set(newsql_answer):
                correct += 1
            else:
                print("[incorrect]")
            print()

        new_sql = new_sql.lower()
        newsql_tok = tokenize_sql(new_sql)
        newsqls.append({'sql': new_sql, 'sql_tok': newsql_tok})

    if execution:
        # FIX: message had a typo and a broken placeholder ('filenmae: (unknown)').
        print(f'[SQL2SQL] filename: {filename}, Answer Accuracy: {correct/len(df):.4f}')

    # Replace the SQL fields in each original record with the translation.
    sql_data = []
    for d, sql_d in zip(data, newsqls):
        d['sql'] = sql_d['sql']
        d['sql_tok'] = sql_d['sql_tok']
        sql_data.append(d)

    # Write back in the same one-JSON-record-per-line format.
    save_filename = os.path.join(savedir, filename)
    with open(save_filename, 'w') as json_file:
        for dic in sql_data:
            json.dump(dic, json_file)
            json_file.write('\n')

    print(f"Write to {save_filename}")


def build_vocab(dataset_type='natural'):
    """Build a frequency-sorted vocabulary from the converted training split.

    Counts tokens from both the refined questions and the translated SQL,
    and writes '<token> <count>' lines (most common first) to a 'vocab'
    file in the mimicsqlstar dataset directory.
    """
    datadir = f'./dataset/mimicsqlstar/{dataset_type}/'
    filenames = ['train.json']
    counter = Counter()
    for filename in filenames:
        with open(os.path.join(datadir, filename)) as json_file:
            for line in json_file:
                dic = json.loads(line)
                counter.update(dic['question_refine_tok'])
                counter.update(dic['sql_tok'])

    with open(os.path.join(datadir, 'vocab'), 'w') as f:
        for k, v in counter.most_common():
            # Skip empty / whitespace-only tokens (equivalent to the original
            # len(k.split()) == 0 check plus the redundant k == ' ' check).
            if not k.strip():
                continue
            f.write(f'{k} {v}\n')

    print(f'vocab built: {len(counter)}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='mimicsql to mimicsql*')
    parser.add_argument('--dataset_type', type=str, default='natural', choices=['natural', 'template'])
    # Accepts 'true'/'false' strings on the command line.
    parser.add_argument('--execution', default=False, type=lambda x: (str(x).lower() == 'true'))
    args = parser.parse_args()

    execution = args.execution
    dataset_type = args.dataset_type

    filenames = ['train.json', 'dev.json', 'test.json']
    for filename in filenames:
        convert_sql2sparql(filename=filename, dataset_type=dataset_type, execution=execution)
    build_vocab(dataset_type=dataset_type)