281 lines (280 with data), 7.9 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"import csv\n",
"import pandas\n",
"import sqlite3\n",
"import random\n",
"import json\n",
"import itertools\n",
"import numpy as np\n",
"from sumeval.metrics.rouge import RougeCalculator\n",
"rouge = RougeCalculator(stopwords=False, lang=\"en\")\n",
"\n",
"from utils import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Database file handed to the project's `query` wrapper (not defined in this\n",
"# notebook; presumably provided by `from utils import *`).\n",
"db_file = 'mimic_db/mimic.db'\n",
"model = query(db_file)\n",
"# Three values returned by the wrapper's loader for the same file.\n",
"db_meta, db_tabs, db_head = model._load_db(db_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the lookup table: lookup[TABLE][COLUMN] holds the candidate values\n",
"# used later to recover condition values.\n",
"# The context manager closes the file even when json.load raises.\n",
"with open('mimic_db/lookup.json', 'r') as fp:\n",
"    lookup = json.load(fp)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Flat list of lower-cased `table.column` names, one per column in `lookup`.\n",
"headerDic = [\n",
"    '.'.join([tb, hd]).lower()\n",
"    for tb, columns in lookup.items()\n",
"    for hd in columns\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def find_best(input_, pool_):\n",
"    \"\"\"Return the entry of ``pool_`` that best matches ``input_``.\n",
"\n",
"    Candidates are ranked by ROUGE-1 (``rouge.rouge_n`` with ``n=1``) against\n",
"    the lower-cased input. If every candidate scores zero, the ranking is\n",
"    repeated with the characters of both texts separated by spaces.\n",
"\n",
"    Args:\n",
"        input_: Predicted condition value; converted with ``str``.\n",
"        pool_: Indexable collection of candidate values; each entry is\n",
"            compared through its lower-cased ``str`` form.\n",
"\n",
"    Returns:\n",
"        ``str`` of the highest-scoring entry (the first one on ties), or\n",
"        ``str(input_)`` unchanged when ``pool_`` is empty.\n",
"    \"\"\"\n",
"    if len(pool_) == 0:\n",
"        # np.argmax raises on an empty sequence; with nothing to match\n",
"        # against, hand the value back untouched.\n",
"        return str(input_)\n",
"\n",
"    # Loop-invariant: normalise the query once instead of once per candidate.\n",
"    query_ = str(input_).lower()\n",
"    score_ = [\n",
"        rouge.rouge_n(summary=query_, references=str(itm).lower(), n=1)\n",
"        for itm in pool_\n",
"    ]\n",
"\n",
"    if np.sum(score_) == 0:\n",
"        # Nothing matched on the text as-is: separate the characters with\n",
"        # spaces so the same metric compares the strings character-wise.\n",
"        spaced_query_ = ' '.join(query_)\n",
"        score_ = [\n",
"            rouge.rouge_n(summary=spaced_query_,\n",
"                          references=' '.join(str(itm)).lower(), n=1)\n",
"            for itm in pool_\n",
"        ]\n",
"\n",
"    return str(pool_[np.argmax(score_)])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Comparison operators in lookup priority: the two-character forms are tried\n",
"# first so that '<=' / '>=' are not reported as '<', '>' or '='.\n",
"_CONDITION_OPS = ('>=', '<=', '<', '>', '=')\n",
"\n",
"\n",
"def _recover_sql(pred):\n",
"    \"\"\"Rebuild one predicted SQL string with condition values from ``lookup``.\n",
"\n",
"    The select part gets a comma between two adjacent tokens that are both in\n",
"    ``headerDic``. Every condition of the where part has its value replaced\n",
"    by the entry of ``lookup[TABLE][COLUMN]`` that ``find_best`` selects.\n",
"\n",
"    Args:\n",
"        pred: Predicted SQL text, already cut at ``<stop>``.\n",
"\n",
"    Returns:\n",
"        The recovered, whitespace-normalised SQL string, or ``None`` when the\n",
"        prediction cannot be processed: no ``where`` keyword, an empty select\n",
"        part, a leading where-fragment without a comparison operator, or a\n",
"        condition whose table/column is not found in ``lookup``.\n",
"    \"\"\"\n",
"    select_part, where_kw, where_part = pred.partition('where')\n",
"    select_tokens = select_part.split()\n",
"    if not where_kw or not select_tokens:\n",
"        return None\n",
"\n",
"    # Insert a comma between two consecutive column names.\n",
"    select_out = []\n",
"    for tok, nxt in zip(select_tokens, select_tokens[1:]):\n",
"        if tok in headerDic and nxt in headerDic:\n",
"            select_out.append(tok + ',')\n",
"        else:\n",
"            select_out.append(tok)\n",
"    select_out.append(select_tokens[-1])\n",
"\n",
"    # Split the where part on the word 'and' (whole word only, so values such\n",
"    # as 'hand' stay intact). A fragment without a comparison operator is the\n",
"    # tail of a value that itself contained 'and': glue it back onto the\n",
"    # previous condition.\n",
"    conditions = []\n",
"    for fragment in re.split(r'\\band\\b', where_part):\n",
"        if '=' in fragment or '<' in fragment or '>' in fragment:\n",
"            conditions.append(fragment)\n",
"        elif conditions:\n",
"            conditions[-1] += ' and ' + fragment\n",
"        else:\n",
"            return None\n",
"\n",
"    for idx, cond in enumerate(conditions):\n",
"        # Every stored condition contains at least one of '=', '<', '>'.\n",
"        op = next(o for o in _CONDITION_OPS if o in cond)\n",
"        pieces = [p for p in re.split('=|<|>', cond) if p]\n",
"        try:\n",
"            column = pieces[0]\n",
"            table_name, column_name = [p for p in re.split(r'\\.|\\s', column) if p][:2]\n",
"            pool_ = lookup[table_name.upper()][column_name.upper()]\n",
"        except (IndexError, ValueError, KeyError):\n",
"            # Malformed condition or unknown table/column.\n",
"            return None\n",
"        # Strip quotes and extra whitespace from the predicted value.\n",
"        value = ' '.join(p for p in re.split(r'\"|\\s', pieces[-1]) if p)\n",
"        conditions[idx] = column + op + ' \"' + find_best(value, pool_) + '\"'\n",
"\n",
"    recovered = ' '.join(select_out) + ' where ' + ' and '.join(conditions)\n",
"    return ' '.join(recovered.split())\n",
"\n",
"\n",
"# sql_rec keeps exactly one [recovered prediction, gold SQL] pair per input\n",
"# line ('Error' replaces predictions that could not be processed);\n",
"# cnt counts the recovered predictions that execute without raising.\n",
"cnt = 0\n",
"sql_rec = []\n",
"with open('generated_sql/output.json', 'r') as fp:\n",
"    for line in fp:\n",
"        record = json.loads(line)\n",
"        pred = re.split('<stop>', record['sql_pred'])[0]\n",
"        ttt = record['sql_gold']\n",
"\n",
"        recovered = _recover_sql(pred)\n",
"        if recovered is None:\n",
"            sql_rec.append(['Error', ttt])\n",
"            continue\n",
"\n",
"        sql_rec.append([recovered, ttt])\n",
"        try:\n",
"            model.execute_sql(recovered).fetchall()\n",
"            cnt += 1\n",
"        except Exception:\n",
"            # A prediction that fails to execute is simply not counted.\n",
"            pass\n",
"\n",
"print(cnt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save the generated SQL after recovering condition values, one JSON object\n",
"# per line. Will be used for break-down performance.\n",
"# The context manager flushes and closes the file even if a write fails.\n",
"with open('generated_sql/recovered_output.json', 'w') as fout:\n",
"    for pred_recovered, gold in sql_rec:\n",
"        row = {'sql_gold': gold, 'sql_pred_recovered': pred_recovered}\n",
"        fout.write(json.dumps(row) + '\\n')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Logic Form Accuracy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# A prediction counts as correct when, after lower-casing and dropping commas\n",
"# and extra whitespace, it is string-equal to the reference SQL.\n",
"cnt = 0\n",
"for itm in sql_rec:\n",
"    normalised = ' '.join(filter(None, re.split(r',|\\s', itm[0].lower())))\n",
"    if normalised == itm[1]:\n",
"        cnt += 1\n",
"# Divide by the number of evaluated queries instead of a hard-coded 1000\n",
"# (max(..., 1) avoids a ZeroDivisionError when nothing was evaluated).\n",
"print('Logic Form Accuracy: {}'.format(cnt / max(len(sql_rec), 1)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Replace the reference SQL of each record with the 'sql' field of the line\n",
"# at the same position in the test set. Records and lines are paired purely\n",
"# by position; an extra line in the file raises IndexError.\n",
"with open(\"../mimicsql_data/mimicsql_natural/test.json\", 'r') as fp:\n",
"    for idx, line in enumerate(fp):\n",
"        sql_rec[idx][1] = json.loads(line)['sql']\n",
"print(sql_rec[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Execution Accuracy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A prediction is correct when both queries execute and fetchall() returns\n",
"# equal row lists; a pair where either query raises counts as wrong.\n",
"cnt = 0\n",
"for pred, ttt in sql_rec:\n",
"    try:\n",
"        outPred = model.execute_sql(pred).fetchall()\n",
"        outTtt = model.execute_sql(ttt).fetchall()\n",
"    except Exception:\n",
"        continue\n",
"    if outPred == outTtt:\n",
"        cnt += 1\n",
"# Divide by the number of evaluated queries instead of a hard-coded 1000\n",
"# (max(..., 1) avoids a ZeroDivisionError when nothing was evaluated).\n",
"print('Execution Accuracy: {}'.format(cnt / max(len(sql_rec), 1)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}