a b/mimicsql/evaluation/breakdown_evaluation.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "execution_count": null,
6
   "metadata": {},
7
   "outputs": [],
8
   "source": [
9
    "import re\n",
10
    "import csv\n",
11
    "import pandas\n",
12
    "import sqlite3\n",
13
    "import random\n",
14
    "import json\n",
15
    "import itertools\n",
16
    "import numpy as np\n",
17
    "from sumeval.metrics.rouge import RougeCalculator\n",
18
    "rouge = RougeCalculator(stopwords=False, lang=\"en\")\n",
19
    "\n",
20
    "from utils import *"
21
   ]
22
  },
23
  {
24
   "cell_type": "code",
25
   "execution_count": null,
26
   "metadata": {},
27
   "outputs": [],
28
   "source": [
29
    "db_file = 'mimic_db/mimic.db'\n",
30
    "model = query(db_file)\n",
31
    "(db_meta, db_tabs, db_head) = model._load_db(db_file)"
32
   ]
33
  },
34
  {
35
   "cell_type": "code",
36
   "execution_count": null,
37
   "metadata": {},
38
   "outputs": [],
39
   "source": [
40
    "headerDic = []\n",
41
    "for tb in db_head:\n",
42
    "    for hd in db_head[tb]:\n",
43
    "        headerDic.append('.'.join([tb, hd]).lower())\n",
44
    "# print(headerDic)\n",
45
    "tableDic = []\n",
46
    "for tb in db_head:\n",
47
    "    tableDic.append(tb.lower())\n",
48
    "# print(tableDic)"
49
   ]
50
  },
51
  {
52
   "cell_type": "code",
53
   "execution_count": null,
54
   "metadata": {},
55
   "outputs": [],
56
   "source": [
57
    "def parse_sql(sql):\n",
58
    "    \n",
59
    "    sqlForm = {}\n",
60
    "    \n",
61
    "    arr = re.split('where', sql)\n",
62
    "    qlead = re.split('from', arr[0])\n",
63
    "    qagg = re.split('\\s', qlead[0])\n",
64
    "    qagg = list(filter(None, qagg))\n",
65
    "    if qagg[1] == 'count' or qagg[1] == 'min' or qagg[1] == 'max' or qagg[1] == 'avg':\n",
66
    "        sqlForm['sel'] = qagg[1]\n",
67
    "    else:\n",
68
    "        sqlForm['sel'] = ''\n",
69
    "        \n",
70
    "    itm = []\n",
71
    "    for wd in qagg:\n",
72
    "        if wd in headerDic:\n",
73
    "            itm.append(wd)\n",
74
    "    sqlForm['agg'] = itm\n",
75
    "    \n",
76
    "    itm = []\n",
77
    "    qtab = re.split('\\s', qlead[1])\n",
78
    "    qtab = list(filter(None, qtab))\n",
79
    "    for wd in qtab:\n",
80
    "        if wd in tableDic:\n",
81
    "            itm.append(wd)\n",
82
    "    sqlForm['tab'] = itm\n",
83
    "        \n",
84
    "    qtail = re.split('and', arr[-1])\n",
85
    "    itm = []\n",
86
    "    for cond in qtail:\n",
87
    "        cond = re.split('\\s', cond)\n",
88
    "        cond = list(filter(None, cond))\n",
89
    "        if len(cond) > 2:\n",
90
    "            condVal = ' '.join(cond[2:])\n",
91
    "            condVal = re.split('\\\"|\\s', condVal)\n",
92
    "            condVal = ' '.join(list(filter(None, condVal)))\n",
93
    "            itm.append(cond[:2] + [condVal])\n",
94
    "    sqlForm['cond'] = sorted(itm)\n",
95
    "    \n",
96
    "    return sqlForm\n",
97
    "\n",
98
    "fp = open('generated_sql/output.json', 'r')\n",
99
    "outGen = []\n",
100
    "outTtt = []\n",
101
    "for line in fp:\n",
102
    "    line = json.loads(line)\n",
103
    "    gen = re.split('<stop>', line['sql_pred'])[0]\n",
104
    "    sqlG = parse_sql(gen)\n",
105
    "    outGen.append(sqlG)\n",
106
    "    \n",
107
    "    ttt = line['sql_gold']\n",
108
    "    sqlT = parse_sql(ttt)\n",
109
    "    outTtt.append(sqlT)\n",
110
    "#     print(sqlG)\n",
111
    "#     print(sqlT)\n",
112
    "#     break\n",
113
    "fp.close()"
114
   ]
115
  },
116
  {
117
   "cell_type": "code",
118
   "execution_count": null,
119
   "metadata": {},
120
   "outputs": [],
121
   "source": [
122
    "cnt = 0\n",
123
    "for k in range(len(outGen)):\n",
124
    "    if outGen[k] == outTtt[k]:\n",
125
    "        cnt += 1\n",
126
    "print('Overall logic form accuracy: {}'.format(cnt/1000))"
127
   ]
128
  },
129
  {
130
   "cell_type": "markdown",
131
   "metadata": {},
132
   "source": [
133
    "## Aggregation"
134
   ]
135
  },
136
  {
137
   "cell_type": "code",
138
   "execution_count": null,
139
   "metadata": {},
140
   "outputs": [],
141
   "source": [
142
    "cnt = 0\n",
143
    "for k in range(len(outGen)):\n",
144
    "    if outGen[k]['sel'] == outTtt[k]['sel']:\n",
145
    "        cnt += 1\n",
146
    "print('Break-down accuracy on AGGREGATION OPERATION: {}'.format(cnt/1000))"
147
   ]
148
  },
149
  {
150
   "cell_type": "code",
151
   "execution_count": null,
152
   "metadata": {},
153
   "outputs": [],
154
   "source": [
155
    "cnt = 0\n",
156
    "for k in range(len(outGen)):\n",
157
    "    if outGen[k]['agg'] == outTtt[k]['agg']:\n",
158
    "        cnt += 1\n",
159
    "print('Break-down accuracy on AGGREGATION COLUMN: {}'.format(cnt/1000))"
160
   ]
161
  },
162
  {
163
   "cell_type": "markdown",
164
   "metadata": {},
165
   "source": [
166
    "## Table"
167
   ]
168
  },
169
  {
170
   "cell_type": "code",
171
   "execution_count": null,
172
   "metadata": {},
173
   "outputs": [],
174
   "source": [
175
    "cnt = 0\n",
176
    "for k in range(len(outGen)):\n",
177
    "    if outGen[k]['tab'] == outTtt[k]['tab']:\n",
178
    "        cnt += 1\n",
179
    "print('Break-down accuracy on TABLE: {}'.format(cnt/1000))"
180
   ]
181
  },
182
  {
183
   "cell_type": "markdown",
184
   "metadata": {},
185
   "source": [
186
    "## Condition"
187
   ]
188
  },
189
  {
190
   "cell_type": "code",
191
   "execution_count": null,
192
   "metadata": {},
193
   "outputs": [],
194
   "source": [
195
    "cnt = 0\n",
196
    "for k in range(len(outGen)):\n",
197
    "    arrG = [wd[0] for wd in outGen[k]['cond']]\n",
198
    "    arrT = [wd[0] for wd in outTtt[k]['cond']]\n",
199
    "    if arrG == arrT:\n",
200
    "        cnt += 1\n",
201
    "print(cnt)"
202
   ]
203
  },
204
  {
205
   "cell_type": "code",
206
   "execution_count": null,
207
   "metadata": {},
208
   "outputs": [],
209
   "source": [
210
    "cnt = 0\n",
211
    "for k in range(len(outGen)):\n",
212
    "    arrG = [wd[:2] for wd in outGen[k]['cond']]\n",
213
    "    arrT = [wd[:2] for wd in outTtt[k]['cond']]\n",
214
    "    if arrG == arrT:\n",
215
    "        cnt += 1\n",
216
    "print('Break-down accuracy on CONDITION COLUMN AND OPERATION: {}'.format(cnt/1000))"
217
   ]
218
  },
219
  {
220
   "cell_type": "code",
221
   "execution_count": null,
222
   "metadata": {},
223
   "outputs": [],
224
   "source": [
225
    "cnt = 0\n",
226
    "for k in range(len(outGen)):\n",
227
    "    arrG = [wd[:3] for wd in outGen[k]['cond']]\n",
228
    "    arrT = [wd[:3] for wd in outTtt[k]['cond']]\n",
229
    "    if arrG == arrT:\n",
230
    "        cnt += 1\n",
231
    "print('Break-down accuracy on CONDITION VALUE: {}'.format(cnt/1000))"
232
   ]
233
  }
234
 ],
235
 "metadata": {
236
  "kernelspec": {
237
   "display_name": "Python 3",
238
   "language": "python",
239
   "name": "python3"
240
  },
241
  "language_info": {
242
   "codemirror_mode": {
243
    "name": "ipython",
244
    "version": 3
245
   },
246
   "file_extension": ".py",
247
   "mimetype": "text/x-python",
248
   "name": "python",
249
   "nbconvert_exporter": "python",
250
   "pygments_lexer": "ipython3",
251
   "version": "3.6.7"
252
  }
253
 },
254
 "nbformat": 4,
255
 "nbformat_minor": 2
256
}