[ab27bc]: / mimicsql / evaluation / overall_evaluation_with_recover.ipynb

Download this file

281 lines (280 with data), 7.9 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import csv\n",
    "import pandas\n",
    "import sqlite3\n",
    "import random\n",
    "import json\n",
    "import itertools\n",
    "import numpy as np\n",
    "from sumeval.metrics.rouge import RougeCalculator\n",
    "rouge = RougeCalculator(stopwords=False, lang=\"en\")\n",
    "\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the SQLite database through `query` (presumably provided by\n",
    "# `from utils import *`) and keep the three values `_load_db` returns.\n",
    "db_file = 'mimic_db/mimic.db'\n",
    "model = query(db_file)\n",
    "db_meta, db_tabs, db_head = model._load_db(db_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the lookup table; later cells index it as lookup[TABLE][COLUMN].\n",
    "# The context manager closes the file even if JSON parsing fails.\n",
    "with open('mimic_db/lookup.json', 'r') as fp:\n",
    "    lookup = json.load(fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lower-cased 'table.column' name for every column listed in the lookup table.\n",
    "headerDic = [\n",
    "    '.'.join([table, column]).lower()\n",
    "    for table in lookup\n",
    "    for column in lookup[table]\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_best(input_, pool_):\n",
    "    \"\"\"Return the entry of ``pool_`` that best matches ``input_``.\n",
    "\n",
    "    Candidates are ranked by ROUGE-1 against the lower-cased input. When\n",
    "    every score is zero, the ranking is repeated on space-separated\n",
    "    characters so that candidates sharing characters with the input can\n",
    "    still be ranked.\n",
    "\n",
    "    Args:\n",
    "        input_: Predicted condition value (string).\n",
    "        pool_: Indexable collection of candidate values; each entry is\n",
    "            converted with ``str`` before comparison.\n",
    "\n",
    "    Returns:\n",
    "        The best-scoring candidate as a string (the first one on ties), or\n",
    "        ``input_`` itself when ``pool_`` is empty.\n",
    "    \"\"\"\n",
    "    if len(pool_) == 0:\n",
    "        # Nothing to match against; np.argmax raises on an empty sequence.\n",
    "        return str(input_)\n",
    "\n",
    "    query_ = input_.lower()  # loop-invariant, so lower-case only once\n",
    "    score_ = [\n",
    "        rouge.rouge_n(summary=query_, references=str(itm).lower(), n=1)\n",
    "        for itm in pool_\n",
    "    ]\n",
    "\n",
    "    if np.sum(score_) == 0:\n",
    "        # No word-level overlap at all: compare character unigrams instead.\n",
    "        query_chars = ' '.join(query_)\n",
    "        score_ = [\n",
    "            rouge.rouge_n(\n",
    "                summary=query_chars,\n",
    "                references=' '.join(str(itm)).lower(),\n",
    "                n=1,\n",
    "            )\n",
    "            for itm in pool_\n",
    "        ]\n",
    "\n",
    "    return str(pool_[np.argmax(score_)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _join_select(select_part):\n",
    "    \"\"\"Normalize the text before 'where', comma-separating adjacent columns.\"\"\"\n",
    "    tokens = list(filter(None, re.split(r'\\s', select_part)))\n",
    "    joined = []\n",
    "    for cur, nxt in zip(tokens, tokens[1:]):\n",
    "        # Insert a comma between two consecutive known 'table.column' names.\n",
    "        if cur in headerDic and nxt in headerDic:\n",
    "            joined.append(cur + ',')\n",
    "        else:\n",
    "            joined.append(cur)\n",
    "    joined.extend(tokens[-1:])  # the last token (if any) never gets a comma\n",
    "    return ' '.join(joined)\n",
    "\n",
    "\n",
    "def _split_conditions(where_part):\n",
    "    \"\"\"Split the WHERE text on 'and', re-attaching operator-less fragments.\n",
    "\n",
    "    A fragment without '=', '<' or '>' is the tail of a value that itself\n",
    "    contained 'and', so it is glued back onto the previous condition.\n",
    "    \"\"\"\n",
    "    conditions = []\n",
    "    for fragment in re.split('and', where_part):\n",
    "        has_operator = any(op in fragment for op in '=<>')\n",
    "        if has_operator or not conditions:\n",
    "            conditions.append(fragment)\n",
    "        else:\n",
    "            # The previous version advanced the index twice here, which\n",
    "            # silently dropped the condition following the merged fragment.\n",
    "            conditions[-1] += ' and ' + fragment\n",
    "    return conditions\n",
    "\n",
    "\n",
    "def _recover_condition(condition):\n",
    "    \"\"\"Rewrite one 'column op value' condition with the closest lookup value.\n",
    "\n",
    "    Returns None when the condition has no comparison operator or its\n",
    "    column cannot be found in ``lookup``.\n",
    "    \"\"\"\n",
    "    con_op = None\n",
    "    # Two-character operators are checked last so they win over '<', '>', '='.\n",
    "    for op in ('=', '>', '<', '<=', '>='):\n",
    "        if op in condition:\n",
    "            con_op = op\n",
    "    parts = list(filter(None, re.split('=|<|>', condition)))\n",
    "    if con_op is None or not parts:\n",
    "        return None\n",
    "\n",
    "    con_col = parts[0]\n",
    "    col_path = list(filter(None, re.split(r'\\.|\\s', con_col)))\n",
    "    try:\n",
    "        pool_ = lookup[col_path[0].upper()][col_path[1].upper()]\n",
    "    except (KeyError, IndexError):\n",
    "        # Unknown table/column, or the column is not of the form table.column.\n",
    "        return None\n",
    "\n",
    "    con_val = ' '.join(filter(None, re.split(r'\"|\\s', parts[-1])))\n",
    "    return con_col + con_op + ' \"' + find_best(con_val, pool_) + '\"'\n",
    "\n",
    "\n",
    "def recover_sql(pred):\n",
    "    \"\"\"Return ``pred`` with every condition value replaced by a lookup value.\n",
    "\n",
    "    The query is split at the first 'where' only, so a later 'where' inside\n",
    "    a value stays part of the conditions. A prediction without a WHERE\n",
    "    clause is returned with just its select part normalized. Returns None\n",
    "    when any condition cannot be recovered.\n",
    "    \"\"\"\n",
    "    select_part, _, where_part = pred.partition('where')\n",
    "    recovered = _join_select(select_part)\n",
    "    if where_part.strip():\n",
    "        conditions = []\n",
    "        for condition in _split_conditions(where_part):\n",
    "            fixed = _recover_condition(condition)\n",
    "            if fixed is None:\n",
    "                return None\n",
    "            conditions.append(fixed)\n",
    "        recovered += ' where ' + ' and '.join(conditions)\n",
    "    # Collapse all runs of whitespace to single spaces.\n",
    "    return ' '.join(filter(None, re.split(r'\\s', recovered)))\n",
    "\n",
    "\n",
    "# Recover every predicted query. Exactly one [prediction, gold] record is\n",
    "# appended per input line, with 'Error' standing in for unrecoverable\n",
    "# predictions. `cnt` counts the recovered queries that execute without error.\n",
    "cnt = 0\n",
    "sql_rec = []\n",
    "with open('generated_sql/output.json', 'r') as fp:\n",
    "    for line in fp:\n",
    "        rec = json.loads(line)\n",
    "        ttt = rec['sql_gold']\n",
    "        pred = recover_sql(re.split('<stop>', rec['sql_pred'])[0])\n",
    "        if pred is None:\n",
    "            sql_rec.append(['Error', ttt])\n",
    "            continue\n",
    "\n",
    "        sql_rec.append([pred, ttt])\n",
    "        try:\n",
    "            model.execute_sql(pred).fetchall()\n",
    "            cnt += 1\n",
    "        except Exception:\n",
    "            # Predictions that fail to execute are simply not counted.\n",
    "            pass\n",
    "\n",
    "print(cnt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the generated sql after recovering condition values.\n",
    "# Will be used for break-down performance.\n",
    "# One JSON object is written per line; the context manager flushes and\n",
    "# closes the file even if serialization fails part-way through.\n",
    "with open('generated_sql/recovered_output.json', 'w') as fout:\n",
    "    for sql_pred, sql_gold in sql_rec:\n",
    "        record = {'sql_gold': sql_gold, 'sql_pred_recovered': sql_pred}\n",
    "        fout.write(json.dumps(record) + '\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Logic Form Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# A prediction is correct when, after lower-casing and removing commas and\n",
    "# redundant whitespace, it is identical to the gold query string.\n",
    "cnt = 0\n",
    "for pred, gold in sql_rec:\n",
    "    normalized = ' '.join(filter(None, re.split(r',|\\s', pred.lower())))\n",
    "    if normalized == gold:\n",
    "        cnt += 1\n",
    "\n",
    "# Divide by the number of evaluated examples instead of a hard-coded 1000,\n",
    "# so the figure stays correct for inputs of any size.\n",
    "total = len(sql_rec)\n",
    "print('Logic Form Accuracy: {}'.format(cnt / total if total else 0.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the gold entry of every record with the 'sql' field of the\n",
    "# matching line in test.json. Records are matched by position, so all gold\n",
    "# queries are read first and the counts are checked before anything is\n",
    "# overwritten.\n",
    "# NOTE: this overwrites sql_rec in place; the Logic Form Accuracy cell must\n",
    "# be run before this one.\n",
    "with open(\"../mimicsql_data/mimicsql_natural/test.json\", 'r') as fp:\n",
    "    gold_sql = [json.loads(line)['sql'] for line in fp]\n",
    "\n",
    "if len(gold_sql) != len(sql_rec):\n",
    "    raise ValueError(\n",
    "        'test.json has {} examples but {} predictions were recovered'.format(\n",
    "            len(gold_sql), len(sql_rec)))\n",
    "\n",
    "for rec, sql in zip(sql_rec, gold_sql):\n",
    "    rec[1] = sql\n",
    "\n",
    "if sql_rec:\n",
    "    print(sql_rec[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Execution Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A prediction is correct when it executes and returns exactly the same\n",
    "# rows as the gold query.\n",
    "cnt = 0\n",
    "for pred, ttt in sql_rec:\n",
    "    try:\n",
    "        outPred = model.execute_sql(pred).fetchall()\n",
    "        outTtt = model.execute_sql(ttt).fetchall()\n",
    "    except Exception:\n",
    "        # A query that cannot be executed (including the 'Error'\n",
    "        # placeholders) counts as incorrect.\n",
    "        continue\n",
    "    if outPred == outTtt:\n",
    "        cnt += 1\n",
    "\n",
    "# Divide by the number of evaluated examples instead of a hard-coded 1000.\n",
    "total = len(sql_rec)\n",
    "print('Execution Accuracy: {}'.format(cnt / total if total else 0.0))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}