259 lines (258 with data), 6.5 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"import csv\n",
"import pandas\n",
"import sqlite3\n",
"import random\n",
"import json\n",
"import itertools\n",
"import numpy as np\n",
"from sumeval.metrics.rouge import RougeCalculator\n",
"rouge = RougeCalculator(stopwords=False, lang=\"en\")\n",
"\n",
"from utils import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Location of the database file handed to the `query` helper.\n",
"# NOTE(review): `query` is not defined in this notebook; presumably it comes from `from utils import *` -- confirm.\n",
"db_file = 'mimic_db/mimic.db'\n",
"model = query(db_file)\n",
"# _load_db yields three values; only db_head is read by the later cells of this notebook.\n",
"db_meta, db_tabs, db_head = model._load_db(db_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Lower-cased lookup lists consulted by parse_sql.\n",
"# headerDic: one '<key>.<item>' string for every item stored under each key of db_head\n",
"# (tables and their column headers, going by the names -- confirm against utils).\n",
"headerDic = [\n",
"    '.'.join([tb, hd]).lower()\n",
"    for tb in db_head\n",
"    for hd in db_head[tb]\n",
"]\n",
"# tableDic: the keys of db_head on their own.\n",
"tableDic = [tb.lower() for tb in db_head]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def parse_sql(sql):\n",
"    \"\"\"Break a flat SQL query string into a comparable logic form.\n",
"\n",
"    The query is cut into a SELECT part, a FROM part and an optional WHERE\n",
"    part, and each part is reduced to the tokens used for evaluation.\n",
"\n",
"    Args:\n",
"        sql: Query text. Keywords are matched in lower case only, and tokens\n",
"            are looked up in the module-level ``headerDic`` and ``tableDic``\n",
"            lists, which hold lower-cased names.\n",
"\n",
"    Returns:\n",
"        dict with four keys:\n",
"            'sel'  -- the aggregation operator ('count', 'min', 'max' or\n",
"                      'avg') when it is the second token of the query,\n",
"                      otherwise ''.\n",
"            'agg'  -- tokens of the SELECT part that appear in ``headerDic``.\n",
"            'tab'  -- tokens of the FROM part that appear in ``tableDic``.\n",
"            'cond' -- sorted list of [first token, second token, value]\n",
"                      triples, one per AND-separated condition that has more\n",
"                      than two tokens; double quotes are removed from the\n",
"                      value and its whitespace is collapsed.\n",
"\n",
"    Raises:\n",
"        IndexError: if fewer than two tokens precede FROM, or the query has\n",
"            no FROM keyword at all.\n",
"    \"\"\"\n",
"    sqlForm = {}\n",
"\n",
"    # Split on whole keywords only: a plain substring split would also cut\n",
"    # inside identifiers and values that merely contain the letters\n",
"    # (for example 'and' inside 'brand').\n",
"    arr = re.split(r'\\bwhere\\b', sql)\n",
"    qlead = re.split(r'\\bfrom\\b', arr[0])\n",
"\n",
"    # str.split() with no argument splits on whitespace runs and drops empties.\n",
"    qagg = qlead[0].split()\n",
"    sqlForm['sel'] = qagg[1] if qagg[1] in ('count', 'min', 'max', 'avg') else ''\n",
"    sqlForm['agg'] = [wd for wd in qagg if wd in headerDic]\n",
"\n",
"    qtab = qlead[1].split()\n",
"    sqlForm['tab'] = [wd for wd in qtab if wd in tableDic]\n",
"\n",
"    conds = []\n",
"    # Without a WHERE clause there is nothing to parse; falling back to the\n",
"    # whole query would turn the SELECT/FROM text into a bogus condition.\n",
"    if len(arr) > 1:\n",
"        for cond in re.split(r'\\band\\b', arr[-1]):\n",
"            cond = cond.split()\n",
"            if len(cond) > 2:\n",
"                # Everything after the first two tokens is the value: drop\n",
"                # double quotes and collapse whitespace so formatting\n",
"                # differences do not affect the comparison.\n",
"                condVal = ' '.join(' '.join(cond[2:]).replace('\"', ' ').split())\n",
"                conds.append(cond[:2] + [condVal])\n",
"    # Sorted so that the order of conditions does not matter when comparing.\n",
"    sqlForm['cond'] = sorted(conds)\n",
"\n",
"    return sqlForm\n",
"\n",
"# Parse every (prediction, gold) pair of the recovered model output.\n",
"# outGen[k] and outTtt[k] always describe the same record: both queries are\n",
"# parsed before either list is extended, so a failure on the gold side can no\n",
"# longer leave outGen one element longer than outTtt.\n",
"outGen = []\n",
"outTtt = []\n",
"numSkipped = 0\n",
"with open('generated_sql/recovered_output.json', 'r') as fp:\n",
"    for line in fp:\n",
"        if not line.strip():\n",
"            continue  # tolerate blank lines in the JSON-lines file\n",
"        line = json.loads(line)\n",
"        try:\n",
"            gen = line['sql_pred_recovered'].lower()\n",
"            gen = re.split('<stop>', gen)[0]\n",
"            sqlG = parse_sql(gen)\n",
"\n",
"            ttt = line['sql_gold']\n",
"            sqlT = parse_sql(ttt)\n",
"        except Exception:\n",
"            # Best effort: a record that cannot be parsed is left out of both\n",
"            # lists, but it is counted instead of vanishing silently.\n",
"            numSkipped += 1\n",
"            continue\n",
"        outGen.append(sqlG)\n",
"        outTtt.append(sqlT)\n",
"if numSkipped:\n",
"    print('Skipped {} record(s) that could not be parsed'.format(numSkipped))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count the records whose complete predicted logic form equals the gold one.\n",
"# NOTE(review): the denominator is the literal 1000 rather than a value derived\n",
"# from the data -- presumably the size of the test set; confirm.\n",
"cnt = 0\n",
"for k, genForm in enumerate(outGen):\n",
"    cnt += int(genForm == outTtt[k])\n",
"print('Overall logic form accuracy: {}'.format(cnt / 1000))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Aggregation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count the records whose 'sel' entry (the aggregation operator extracted by\n",
"# parse_sql) is the same in the prediction and in the gold query.\n",
"# NOTE(review): the denominator is the literal 1000 -- presumably the test-set size; confirm.\n",
"cnt = 0\n",
"for k, genForm in enumerate(outGen):\n",
"    cnt += int(genForm['sel'] == outTtt[k]['sel'])\n",
"print('Break-down accuracy on AGGREGATION OPERATION: {}'.format(cnt / 1000))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count the records whose 'agg' entry (the SELECT-part columns extracted by\n",
"# parse_sql) is the same in the prediction and in the gold query.\n",
"# NOTE(review): the denominator is the literal 1000 -- presumably the test-set size; confirm.\n",
"cnt = 0\n",
"for k, genForm in enumerate(outGen):\n",
"    cnt += int(genForm['agg'] == outTtt[k]['agg'])\n",
"print('Break-down accuracy on AGGREGATION COLUMN: {}'.format(cnt / 1000))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Table"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count the records whose 'tab' entry (the FROM-part tables extracted by\n",
"# parse_sql) is the same in the prediction and in the gold query.\n",
"# NOTE(review): the denominator is the literal 1000 -- presumably the test-set size; confirm.\n",
"cnt = 0\n",
"for k, genForm in enumerate(outGen):\n",
"    cnt += int(genForm['tab'] == outTtt[k]['tab'])\n",
"print('Break-down accuracy on TABLE: {}'.format(cnt / 1000))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Condition"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count the records whose conditions agree on their first element only\n",
"# (the condition column). This cell prints the raw count, not a ratio.\n",
"cnt = 0\n",
"for k, genForm in enumerate(outGen):\n",
"    arrG = [cond[0] for cond in genForm['cond']]\n",
"    arrT = [cond[0] for cond in outTtt[k]['cond']]\n",
"    cnt += int(arrG == arrT)\n",
"print(cnt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count the records whose conditions agree on their first two elements\n",
"# (condition column and operator).\n",
"# NOTE(review): the denominator is the literal 1000 -- presumably the test-set size; confirm.\n",
"cnt = 0\n",
"for k, genForm in enumerate(outGen):\n",
"    arrG = [cond[:2] for cond in genForm['cond']]\n",
"    arrT = [cond[:2] for cond in outTtt[k]['cond']]\n",
"    cnt += int(arrG == arrT)\n",
"print('Break-down accuracy on CONDITION COLUMN AND OPERATION: {}'.format(cnt / 1000))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count the records whose conditions agree on their first three elements\n",
"# (column, operator and value -- the full triple built by parse_sql).\n",
"# NOTE(review): the denominator is the literal 1000 -- presumably the test-set size; confirm.\n",
"cnt = 0\n",
"for k, genForm in enumerate(outGen):\n",
"    arrG = [cond[:3] for cond in genForm['cond']]\n",
"    arrT = [cond[:3] for cond in outTtt[k]['cond']]\n",
"    cnt += int(arrG == arrT)\n",
"print('Break-down accuracy on CONDITION VALUE: {}'.format(cnt / 1000))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}