[ab27bc]: / mimicsql / evaluation / breakdown_evaluation_with_recover.ipynb

Download this file

259 lines (258 with data), 6.5 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import csv\n",
    "import pandas\n",
    "import sqlite3\n",
    "import random\n",
    "import json\n",
    "import itertools\n",
    "import numpy as np\n",
    "from sumeval.metrics.rouge import RougeCalculator\n",
    "rouge = RougeCalculator(stopwords=False, lang=\"en\")\n",
    "\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "db_file = 'mimic_db/mimic.db'\n",
    "model = query(db_file)\n",
    "(db_meta, db_tabs, db_head) = model._load_db(db_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "headerDic = []\n",
    "for tb in db_head:\n",
    "    for hd in db_head[tb]:\n",
    "        headerDic.append('.'.join([tb, hd]).lower())\n",
    "# print(headerDic)\n",
    "tableDic = []\n",
    "for tb in db_head:\n",
    "    tableDic.append(tb.lower())\n",
    "# print(tableDic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_sql(sql):\n",
    "    \n",
    "    sqlForm = {}\n",
    "    \n",
    "    arr = re.split('where', sql)\n",
    "    qlead = re.split('from', arr[0])\n",
    "    qagg = re.split('\\s', qlead[0])\n",
    "    qagg = list(filter(None, qagg))\n",
    "#     print(qagg)\n",
    "    if qagg[1] == 'count' or qagg[1] == 'min' or qagg[1] == 'max' or qagg[1] == 'avg':\n",
    "        sqlForm['sel'] = qagg[1]\n",
    "    else:\n",
    "        sqlForm['sel'] = ''\n",
    "        \n",
    "    itm = []\n",
    "    for wd in qagg:\n",
    "        if wd in headerDic:\n",
    "            itm.append(wd)\n",
    "    sqlForm['agg'] = itm\n",
    "    \n",
    "    itm = []\n",
    "    qtab = re.split('\\s', qlead[1])\n",
    "    qtab = list(filter(None, qtab))\n",
    "    for wd in qtab:\n",
    "        if wd in tableDic:\n",
    "            itm.append(wd)\n",
    "    sqlForm['tab'] = itm\n",
    "        \n",
    "    qtail = re.split('and', arr[-1])\n",
    "    itm = []\n",
    "    for cond in qtail:\n",
    "        cond = re.split('\\s', cond)\n",
    "        cond = list(filter(None, cond))\n",
    "        if len(cond) > 2:\n",
    "            condVal = ' '.join(cond[2:])\n",
    "            condVal = re.split('\\\"|\\s', condVal)\n",
    "            condVal = ' '.join(list(filter(None, condVal)))\n",
    "            itm.append(cond[:2] + [condVal])\n",
    "    sqlForm['cond'] = sorted(itm)\n",
    "    \n",
    "    return sqlForm\n",
    "\n",
    "fp = open('generated_sql/recovered_output.json', 'r')\n",
    "outGen = []\n",
    "outTtt = []\n",
    "for line in fp:\n",
    "    line = json.loads(line)\n",
    "    try:\n",
    "        gen = line['sql_pred_recovered'].lower()\n",
    "        gen = re.split('<stop>', gen)[0]\n",
    "        sqlG = parse_sql(gen)\n",
    "        outGen.append(sqlG)\n",
    "\n",
    "        ttt = line['sql_gold']\n",
    "        sqlT = parse_sql(ttt)\n",
    "        outTtt.append(sqlT)\n",
    "    except:\n",
    "        continue\n",
    "fp.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnt = 0\n",
    "for k in range(len(outGen)):\n",
    "    if outGen[k] == outTtt[k]:\n",
    "        cnt += 1\n",
    "print('Overall logic form accuracy: {}'.format(cnt/1000))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aggregation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnt = 0\n",
    "for k in range(len(outGen)):\n",
    "    if outGen[k]['sel'] == outTtt[k]['sel']:\n",
    "        cnt += 1\n",
    "print('Break-down accuracy on AGGREGATION OPERATION: {}'.format(cnt/1000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnt = 0\n",
    "for k in range(len(outGen)):\n",
    "    if outGen[k]['agg'] == outTtt[k]['agg']:\n",
    "        cnt += 1\n",
    "print('Break-down accuracy on AGGREGATION COLUMN: {}'.format(cnt/1000))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnt = 0\n",
    "for k in range(len(outGen)):\n",
    "    if outGen[k]['tab'] == outTtt[k]['tab']:\n",
    "        cnt += 1\n",
    "print('Break-down accuracy on TABLE: {}'.format(cnt/1000))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Condition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnt = 0\n",
    "for k in range(len(outGen)):\n",
    "    arrG = [wd[0] for wd in outGen[k]['cond']]\n",
    "    arrT = [wd[0] for wd in outTtt[k]['cond']]\n",
    "    if arrG == arrT:\n",
    "        cnt += 1\n",
    "print(cnt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnt = 0\n",
    "for k in range(len(outGen)):\n",
    "    arrG = [wd[:2] for wd in outGen[k]['cond']]\n",
    "    arrT = [wd[:2] for wd in outTtt[k]['cond']]\n",
    "    if arrG == arrT:\n",
    "        cnt += 1\n",
    "print('Break-down accuracy on CONDITION COLUMN AND OPERATION: {}'.format(cnt/1000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnt = 0\n",
    "for k in range(len(outGen)):\n",
    "    arrG = [wd[:3] for wd in outGen[k]['cond']]\n",
    "    arrT = [wd[:3] for wd in outTtt[k]['cond']]\n",
    "    if arrG == arrT:\n",
    "        cnt += 1\n",
    "print('Break-down accuracy on CONDITION VALUE: {}'.format(cnt/1000))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}