[ab27bc]: / mimicsql / evaluation / overall_evaluation.ipynb

Download this file

333 lines (332 with data), 10.0 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import csv\n",
    "import pandas\n",
    "import sqlite3\n",
    "import random\n",
    "import json\n",
    "import itertools\n",
    "import numpy as np\n",
    "from sumeval.metrics.rouge import RougeCalculator\n",
    "rouge = RougeCalculator(stopwords=False, lang=\"en\")\n",
    "\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the `query` helper (presumably provided by `from utils import *`)\n",
    "# for the SQLite file and keep the three values returned by `_load_db`.\n",
    "db_file = 'mimic_db/mimic.db'\n",
    "model = query(db_file)\n",
    "db_meta, db_tabs, db_head = model._load_db(db_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the nested lookup (first-level keys, then second-level keys) that the\n",
    "# later cells index as lookup[TABLE][COLUMN].\n",
    "# The context manager closes the file even when json.load raises.\n",
    "with open('mimic_db/lookup.json', 'r') as fp:\n",
    "    lookup = json.load(fp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every first-level key of the lookup joined with each of its second-level\n",
    "# keys as 'key.subkey', lower-cased, in iteration order.\n",
    "headerDic = [\n",
    "    '.'.join([tb, hd]).lower()\n",
    "    for tb in lookup\n",
    "    for hd in lookup[tb]\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild every predicted query into a normalised SQL string and store it\n",
    "# next to its gold query in sql_rec (exactly one entry per input line).\n",
    "# cnt counts the rebuilt predictions that execute without raising.\n",
    "cnt = 0\n",
    "sql_rec = []\n",
    "with open('generated_sql/output.json', 'r') as fp:\n",
    "    for line in fp:\n",
    "        record = json.loads(line)\n",
    "        pred = re.split('<stop>', record['sql_pred'])[0]\n",
    "        ttt = record['sql_gold']\n",
    "\n",
    "        predArr = re.split(\"where\", pred)\n",
    "        predAgg = list(filter(None, re.split(r\"\\s\", predArr[0])))\n",
    "        predCon = re.split(\"and\", predArr[1]) if len(predArr) > 1 else []\n",
    "        has_first_cond = bool(predCon) and any(op in predCon[0] for op in \"=<>\")\n",
    "        if not predAgg or not has_first_cond:\n",
    "            # Nothing before 'where', no 'where' at all, or a WHERE part that\n",
    "            # does not start with a condition: record a failure instead of\n",
    "            # raising IndexError further down.\n",
    "            sql_rec.append([\"Error\", ttt])\n",
    "            continue\n",
    "\n",
    "        # SELECT part: a token found in headerDic that is directly followed by\n",
    "        # another such token gets a trailing comma.\n",
    "        predAgg2 = []\n",
    "        for tok, nxt in zip(predAgg, predAgg[1:]):\n",
    "            if tok in headerDic and nxt in headerDic:\n",
    "                predAgg2.append(tok + ',')\n",
    "            else:\n",
    "                predAgg2.append(tok)\n",
    "        predAgg2.append(predAgg[-1])\n",
    "        predAgg = ' '.join(predAgg2)\n",
    "\n",
    "        # WHERE part: a piece without a comparison operator is glued back onto\n",
    "        # the previous condition.\n",
    "        predConNew = []\n",
    "        k = 0\n",
    "        while k < len(predCon):\n",
    "            if \"=\" in predCon[k] or \"<\" in predCon[k] or \">\" in predCon[k]:\n",
    "                predConNew.append(predCon[k])\n",
    "            else:\n",
    "                predConNew[-1] += \" and \" + predCon[k]\n",
    "                # NOTE(review): this extra step also skips the piece that follows\n",
    "                # the glued fragment; kept unchanged -- confirm it is intended.\n",
    "                k += 1\n",
    "            k += 1\n",
    "\n",
    "        bad_column = False\n",
    "        for k in range(len(predConNew)):\n",
    "            cond = predConNew[k]\n",
    "            # Later entries win, so '<=' / '>=' override a single-character match.\n",
    "            for op in (\"=\", \">\", \"<\", \"<=\", \">=\"):\n",
    "                if op in cond:\n",
    "                    conOp = op\n",
    "            conVal = list(filter(None, re.split(\"=|<|>\", cond)))\n",
    "            conCol = conVal[0] if conVal else ''\n",
    "            conColArr = list(filter(None, re.split(r'\\.|\\s', conCol)))\n",
    "            try:\n",
    "                # Only checks that lookup[TABLE][COLUMN] exists; the value is unused.\n",
    "                lookup[conColArr[0].upper()][conColArr[1].upper()]\n",
    "            except (KeyError, IndexError, TypeError):\n",
    "                bad_column = True\n",
    "                break\n",
    "            conVal = ' '.join(filter(None, re.split(r'\"|\\s', conVal[-1])))\n",
    "            predConNew[k] = conCol + conOp + ' \"' + conVal + '\"'\n",
    "\n",
    "        if bad_column:\n",
    "            # Skip the whole query so sql_rec keeps one entry per input line.\n",
    "            sql_rec.append([\"Error\", ttt])\n",
    "            continue\n",
    "\n",
    "        pred = predAgg + ' where ' + ' and '.join(predConNew)\n",
    "        pred = \" \".join(filter(None, re.split(r\"\\s\", pred)))\n",
    "        sql_rec.append([pred, ttt])\n",
    "        try:\n",
    "            myres = model.execute_sql(pred).fetchall()\n",
    "            myres = list({row[0]: {} for row in myres if row[0] is not None})\n",
    "            cnt += 1\n",
    "        except Exception:\n",
    "            # Best effort on purpose: a prediction that fails to run is not counted.\n",
    "            pass\n",
    "\n",
    "print(cnt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Logic Form Accuracy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# A prediction counts as correct when its lower-cased text, with commas and\n",
    "# whitespace collapsed to single spaces, equals the gold query string.\n",
    "cnt = 0\n",
    "for pred_sql, gold_sql in sql_rec:\n",
    "    # Raw string: '\\s' in a plain literal is an invalid escape sequence.\n",
    "    normalised = ' '.join(filter(None, re.split(r',|\\s', pred_sql.lower())))\n",
    "    if normalised == gold_sql:\n",
    "        cnt += 1\n",
    "# NOTE(review): the denominator is hard-coded to 1000 -- presumably the number\n",
    "# of evaluated examples; confirm before running on another split.\n",
    "print('Logic Form Accuracy: {}'.format(cnt/1000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild every predicted query into the SQL string stored in sql_rec2\n",
    "# (exactly one entry per input line). Conditions whose value is not an\n",
    "# integer are compared against lower(column).\n",
    "# cnt counts the rebuilt predictions that execute without raising.\n",
    "cnt = 0\n",
    "sql_rec2 = []\n",
    "with open('generated_sql/output.json', 'r') as fp:\n",
    "    for line in fp:\n",
    "        record = json.loads(line)\n",
    "        pred = re.split('<stop>', record['sql_pred'])[0]\n",
    "        ttt = record['sql_gold']\n",
    "\n",
    "        predArr = re.split(\"where\", pred)\n",
    "        predAgg = list(filter(None, re.split(r\"\\s\", predArr[0])))\n",
    "        predCon = re.split(\"and\", predArr[1]) if len(predArr) > 1 else []\n",
    "        has_first_cond = bool(predCon) and any(op in predCon[0] for op in \"=<>\")\n",
    "        if not predAgg or not has_first_cond:\n",
    "            # Nothing before 'where', no 'where' at all, or a WHERE part that\n",
    "            # does not start with a condition: record a failure instead of\n",
    "            # raising IndexError further down.\n",
    "            sql_rec2.append([\"Error\", ttt])\n",
    "            continue\n",
    "\n",
    "        # SELECT part: a token found in headerDic that is directly followed by\n",
    "        # another such token gets a trailing comma.\n",
    "        predAgg2 = []\n",
    "        for tok, nxt in zip(predAgg, predAgg[1:]):\n",
    "            if tok in headerDic and nxt in headerDic:\n",
    "                predAgg2.append(tok + ',')\n",
    "            else:\n",
    "                predAgg2.append(tok)\n",
    "        predAgg2.append(predAgg[-1])\n",
    "        predAgg = ' '.join(predAgg2)\n",
    "\n",
    "        # WHERE part: a piece without a comparison operator is glued back onto\n",
    "        # the previous condition.\n",
    "        predConNew = []\n",
    "        k = 0\n",
    "        while k < len(predCon):\n",
    "            if \"=\" in predCon[k] or \"<\" in predCon[k] or \">\" in predCon[k]:\n",
    "                predConNew.append(predCon[k])\n",
    "            else:\n",
    "                predConNew[-1] += \" and \" + predCon[k]\n",
    "                # NOTE(review): this extra step also skips the piece that follows\n",
    "                # the glued fragment; kept unchanged -- confirm it is intended.\n",
    "                k += 1\n",
    "            k += 1\n",
    "\n",
    "        bad_column = False\n",
    "        for k in range(len(predConNew)):\n",
    "            cond = predConNew[k]\n",
    "            # Later entries win, so '<=' / '>=' override a single-character match.\n",
    "            for op in (\"=\", \">\", \"<\", \"<=\", \">=\"):\n",
    "                if op in cond:\n",
    "                    conOp = op\n",
    "            conVal = list(filter(None, re.split(\"=|<|>\", cond)))\n",
    "            conCol = conVal[0] if conVal else ''\n",
    "            conColArr = list(filter(None, re.split(r'\\.|\\s', conCol)))\n",
    "            try:\n",
    "                # Only checks that lookup[TABLE][COLUMN] exists; the value is unused.\n",
    "                lookup[conColArr[0].upper()][conColArr[1].upper()]\n",
    "            except (KeyError, IndexError, TypeError):\n",
    "                bad_column = True\n",
    "                break\n",
    "            conVal = ' '.join(filter(None, re.split(r'\"|\\s', conVal[-1])))\n",
    "            try:\n",
    "                int(conVal)\n",
    "            except ValueError:\n",
    "                # Non-integer value: compare against the lower-cased column.\n",
    "                predConNew[k] = 'lower(' + conCol + ')' + conOp + ' \"' + conVal + '\"'\n",
    "            else:\n",
    "                predConNew[k] = conCol + conOp + ' \"' + conVal + '\"'\n",
    "\n",
    "        if bad_column:\n",
    "            # Skip the whole query so sql_rec2 keeps one entry per input line.\n",
    "            sql_rec2.append([\"Error\", ttt])\n",
    "            continue\n",
    "\n",
    "        pred = predAgg + ' where ' + ' and '.join(predConNew)\n",
    "        pred = \" \".join(filter(None, re.split(r\"\\s\", pred)))\n",
    "        sql_rec2.append([pred, ttt])\n",
    "        try:\n",
    "            myres = model.execute_sql(pred).fetchall()\n",
    "            myres = list({row[0]: {} for row in myres if row[0] is not None})\n",
    "            cnt += 1\n",
    "        except Exception:\n",
    "            # Best effort on purpose: a prediction that fails to run is not counted.\n",
    "            pass\n",
    "\n",
    "print(cnt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overwrite the gold query of the n-th sql_rec2 entry with the 'sql' field\n",
    "# of the n-th line of the test file, then print how many lines were read.\n",
    "# The context manager closes the file even when a line fails to parse.\n",
    "cnt = 0\n",
    "with open(\"../mimicsql_data/mimicsql_natural/test.json\", 'r') as fp:\n",
    "    for line in fp:\n",
    "        data = json.loads(line)\n",
    "        sql_rec2[cnt][1] = data['sql']\n",
    "        cnt += 1\n",
    "print(cnt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Execution Accuracy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run each rebuilt prediction and its gold query; a pair counts as correct\n",
    "# when both execute and return equal result lists.\n",
    "cnt = 0\n",
    "for pred, ttt in sql_rec2:\n",
    "    try:\n",
    "        outPred = model.execute_sql(pred).fetchall()\n",
    "        outTtt = model.execute_sql(ttt).fetchall()\n",
    "    except Exception:\n",
    "        # A query that cannot be executed counts as a miss.\n",
    "        continue\n",
    "    if outPred == outTtt:\n",
    "        cnt += 1\n",
    "\n",
    "# Divide by the number of evaluated pairs instead of a hard-coded 1000.\n",
    "total = len(sql_rec2)\n",
    "print('Execution Accuracy: {}'.format(cnt / total if total else 0.0))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}