a b/mimicsql/evaluation/overall_evaluation.ipynb
1
{
2
 "cells": [
3
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import csv\n",
    "import itertools\n",
    "import json\n",
    "import random\n",
    "import re\n",
    "import sqlite3\n",
    "\n",
    "# Third-party\n",
    "import numpy as np\n",
    "import pandas\n",
    "from sumeval.metrics.rouge import RougeCalculator\n",
    "\n",
    "# ROUGE scorer: English, stopwords kept.\n",
    "rouge = RougeCalculator(stopwords=False, lang=\"en\")\n",
    "\n",
    "# Project helpers (presumably where `query`, used below, is defined).\n",
    "from utils import *"
   ]
  },
23
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the MIMIC SQLite database through the `query` helper.\n",
    "db_file = 'mimic_db/mimic.db'\n",
    "model = query(db_file)\n",
    "# Unpack the three values returned by _load_db (names suggest metadata, tables and headers).\n",
    "db_meta, db_tabs, db_head = model._load_db(db_file)"
   ]
  },
34
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Table -> column lookup; used below to validate columns referenced in WHERE conditions.\n",
    "# `with` guarantees the file is closed even if json.load raises.\n",
    "with open('mimic_db/lookup.json', 'r') as fp:\n",
    "    lookup = json.load(fp)"
   ]
  },
45
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lower-cased \"table.column\" names for every column in the lookup;\n",
    "# used to detect adjacent column names in the SELECT part of a prediction.\n",
    "headerDic = [\n",
    "    '.'.join([table, column]).lower()\n",
    "    for table in lookup\n",
    "    for column in lookup[table]\n",
    "]"
   ]
  },
58
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Post-process the raw predictions into executable SQL:\n",
    "#  - cut each prediction at the first <stop> token,\n",
    "#  - insert commas between adjacent known column names in the SELECT part,\n",
    "#  - re-quote every WHERE value.\n",
    "# Exactly one [prediction, gold] record per example is appended to sql_rec;\n",
    "# examples that cannot be parsed get an \"Error\" placeholder instead.\n",
    "# `cnt` counts post-processed predictions that execute without raising.\n",
    "cnt = 0\n",
    "sql_rec = []\n",
    "with open('generated_sql/output.json', 'r') as fp:\n",
    "    for line in fp:\n",
    "        line = json.loads(line)\n",
    "        pred = re.split('<stop>', line['sql_pred'])[0]\n",
    "        ttt = line['sql_gold']\n",
    "\n",
    "        predArr = re.split(\"where\", pred)\n",
    "        if len(predArr) < 2:\n",
    "            # No WHERE clause to rewrite.\n",
    "            sql_rec.append([\"Error\", ttt])\n",
    "            continue\n",
    "\n",
    "        # SELECT part: add a comma between two consecutive known column names.\n",
    "        predAgg = list(filter(None, re.split(r\"\\s\", predArr[0])))\n",
    "        predAgg2 = []\n",
    "        for k in range(len(predAgg) - 1):\n",
    "            if predAgg[k] in headerDic and predAgg[k + 1] in headerDic:\n",
    "                predAgg2.append(predAgg[k] + ',')\n",
    "            else:\n",
    "                predAgg2.append(predAgg[k])\n",
    "        predAgg2.extend(predAgg[-1:])  # also safe when predAgg is empty\n",
    "        predAgg = ' '.join(predAgg2)\n",
    "\n",
    "        # WHERE part: split on the substring \"and\" and glue fragments that carry\n",
    "        # no comparison operator back onto the previous condition.\n",
    "        # NOTE(review): the fragment following a glued one is skipped; confirm intended.\n",
    "        predCon = re.split(\"and\", predArr[1])\n",
    "        predConNew = []\n",
    "        failed = False\n",
    "        k = 0\n",
    "        while k < len(predCon):\n",
    "            if \"=\" in predCon[k] or \"<\" in predCon[k] or \">\" in predCon[k]:\n",
    "                predConNew.append(predCon[k])\n",
    "            elif predConNew:\n",
    "                predConNew[-1] += \" and \" + predCon[k]\n",
    "                k += 1\n",
    "            else:\n",
    "                failed = True  # first fragment has no operator to attach to\n",
    "                break\n",
    "            k += 1\n",
    "        if failed:\n",
    "            sql_rec.append([\"Error\", ttt])\n",
    "            continue\n",
    "\n",
    "        for k, cond in enumerate(predConNew):\n",
    "            # Most specific operator wins: >= and <= take precedence over <, > and =.\n",
    "            conOp = next((op for op in (\">=\", \"<=\", \"<\", \">\", \"=\") if op in cond), \"=\")\n",
    "            parts = list(filter(None, re.split(\"=|<|>\", cond)))\n",
    "            try:\n",
    "                conCol = parts[0]\n",
    "                conColArr = list(filter(None, re.split(r'\\.|\\s', conCol)))\n",
    "                # Validity check only: table.column must exist in the lookup.\n",
    "                lookup[conColArr[0].upper()][conColArr[1].upper()]\n",
    "            except (KeyError, IndexError, TypeError):\n",
    "                # Abort the whole example; the outer loop must not also record a prediction.\n",
    "                failed = True\n",
    "                break\n",
    "            conVal = ' '.join(filter(None, re.split(r'\"|\\s', parts[-1])))\n",
    "            predConNew[k] = conCol + conOp + ' \"' + conVal + '\"'\n",
    "        if failed:\n",
    "            sql_rec.append([\"Error\", ttt])\n",
    "            continue\n",
    "\n",
    "        pred = predAgg + ' where ' + ' and '.join(predConNew)\n",
    "        pred = ' '.join(filter(None, re.split(r\"\\s\", pred)))\n",
    "        sql_rec.append([pred, ttt])\n",
    "        try:\n",
    "            model.execute_sql(pred).fetchall()\n",
    "            cnt += 1\n",
    "        except Exception:\n",
    "            pass  # non-executable prediction: not counted\n",
    "print(cnt)"
   ]
  },
140
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Logic Form Accuracy\n",
    "\n",
    "Fraction of test examples whose post-processed predicted SQL exactly matches the gold SQL string."
   ]
  },
147
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Exact-match accuracy: lower-case the prediction, drop commas, collapse\n",
    "# whitespace, then compare with the gold SQL string.\n",
    "# The denominator assumes sql_rec holds one record per test example.\n",
    "cnt = 0\n",
    "for itm in sql_rec:\n",
    "    arr = ' '.join(filter(None, re.split(r',|\\s', itm[0].lower())))\n",
    "    if arr == itm[1]:\n",
    "        cnt += 1\n",
    "total = len(sql_rec)\n",
    "print('Logic Form Accuracy: {}'.format(cnt / total if total else 0.0))"
   ]
  },
165
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second post-processing pass, used for execution accuracy. Same rewriting as\n",
    "# the first pass, except that non-integer values are compared against lower(column).\n",
    "# Exactly one record per example is appended to sql_rec2 (\"Error\" when the\n",
    "# prediction cannot be parsed); `cnt` counts predictions that execute.\n",
    "cnt = 0\n",
    "sql_rec2 = []\n",
    "with open('generated_sql/output.json', 'r') as fp:\n",
    "    for line in fp:\n",
    "        line = json.loads(line)\n",
    "        pred = re.split('<stop>', line['sql_pred'])[0]\n",
    "        ttt = line['sql_gold']\n",
    "\n",
    "        predArr = re.split(\"where\", pred)\n",
    "        if len(predArr) < 2:\n",
    "            # No WHERE clause to rewrite.\n",
    "            sql_rec2.append([\"Error\", ttt])\n",
    "            continue\n",
    "\n",
    "        # SELECT part: add a comma between two consecutive known column names.\n",
    "        predAgg = list(filter(None, re.split(r\"\\s\", predArr[0])))\n",
    "        predAgg2 = []\n",
    "        for k in range(len(predAgg) - 1):\n",
    "            if predAgg[k] in headerDic and predAgg[k + 1] in headerDic:\n",
    "                predAgg2.append(predAgg[k] + ',')\n",
    "            else:\n",
    "                predAgg2.append(predAgg[k])\n",
    "        predAgg2.extend(predAgg[-1:])  # also safe when predAgg is empty\n",
    "        predAgg = ' '.join(predAgg2)\n",
    "\n",
    "        # WHERE part: split on the substring \"and\" and glue fragments that carry\n",
    "        # no comparison operator back onto the previous condition.\n",
    "        # NOTE(review): the fragment following a glued one is skipped; confirm intended.\n",
    "        predCon = re.split(\"and\", predArr[1])\n",
    "        predConNew = []\n",
    "        failed = False\n",
    "        k = 0\n",
    "        while k < len(predCon):\n",
    "            if \"=\" in predCon[k] or \"<\" in predCon[k] or \">\" in predCon[k]:\n",
    "                predConNew.append(predCon[k])\n",
    "            elif predConNew:\n",
    "                predConNew[-1] += \" and \" + predCon[k]\n",
    "                k += 1\n",
    "            else:\n",
    "                failed = True  # first fragment has no operator to attach to\n",
    "                break\n",
    "            k += 1\n",
    "        if failed:\n",
    "            sql_rec2.append([\"Error\", ttt])\n",
    "            continue\n",
    "\n",
    "        for k, cond in enumerate(predConNew):\n",
    "            # Most specific operator wins: >= and <= take precedence over <, > and =.\n",
    "            conOp = next((op for op in (\">=\", \"<=\", \"<\", \">\", \"=\") if op in cond), \"=\")\n",
    "            parts = list(filter(None, re.split(\"=|<|>\", cond)))\n",
    "            try:\n",
    "                conCol = parts[0]\n",
    "                conColArr = list(filter(None, re.split(r'\\.|\\s', conCol)))\n",
    "                # Validity check only: table.column must exist in the lookup.\n",
    "                lookup[conColArr[0].upper()][conColArr[1].upper()]\n",
    "            except (KeyError, IndexError, TypeError):\n",
    "                failed = True\n",
    "                break\n",
    "            conVal = ' '.join(filter(None, re.split(r'\"|\\s', parts[-1])))\n",
    "            try:\n",
    "                int(conVal)\n",
    "                colExpr = conCol\n",
    "            except ValueError:\n",
    "                # Non-integer value: compare against the lower-cased column.\n",
    "                colExpr = 'lower(' + conCol + ')'\n",
    "            predConNew[k] = colExpr + conOp + ' \"' + conVal + '\"'\n",
    "        if failed:\n",
    "            sql_rec2.append([\"Error\", ttt])\n",
    "            continue\n",
    "\n",
    "        pred = predAgg + ' where ' + ' and '.join(predConNew)\n",
    "        pred = ' '.join(filter(None, re.split(r\"\\s\", pred)))\n",
    "        sql_rec2.append([pred, ttt])\n",
    "        try:\n",
    "            model.execute_sql(pred).fetchall()\n",
    "            cnt += 1\n",
    "        except Exception:\n",
    "            pass  # non-executable prediction: not counted\n",
    "print(cnt)"
   ]
  },
256
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the gold SQL of each record with the query from the natural-language\n",
    "# test set, matching records to test.json lines by position.\n",
    "cnt = 0\n",
    "with open(\"../mimicsql_data/mimicsql_natural/test.json\", 'r') as fp:\n",
    "    for line in fp:\n",
    "        data = json.loads(line)\n",
    "        sql_rec2[cnt][1] = data['sql']\n",
    "        cnt += 1\n",
    "# Positional matching is only meaningful when both sides have the same length.\n",
    "if cnt != len(sql_rec2):\n",
    "    print('Warning: {} gold queries for {} predictions'.format(cnt, len(sql_rec2)))\n",
    "print(cnt)"
   ]
  },
272
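  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Execution Accuracy\n",
    "\n",
    "Fraction of test examples whose predicted SQL returns the same result rows as the gold SQL when both are executed on the MIMIC database."
   ]
  },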
279
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execution accuracy: run predicted and gold SQL and compare the fetched rows\n",
    "# with == (so row order matters if fetchall returns lists).\n",
    "# Records whose SQL fails to execute count as incorrect.\n",
    "cnt = 0\n",
    "for pred, ttt in sql_rec2:\n",
    "    try:\n",
    "        outPred = model.execute_sql(pred).fetchall()\n",
    "        outTtt = model.execute_sql(ttt).fetchall()\n",
    "    except Exception:\n",
    "        continue\n",
    "    if outPred == outTtt:\n",
    "        cnt += 1\n",
    "total = len(sql_rec2)\n",
    "print('Execution Accuracy: {}'.format(cnt / total if total else 0.0))"
   ]
  }
310
 ],
311
 "metadata": {
312
  "kernelspec": {
313
   "display_name": "Python 3",
314
   "language": "python",
315
   "name": "python3"
316
  },
317
  "language_info": {
318
   "codemirror_mode": {
319
    "name": "ipython",
320
    "version": 3
321
   },
322
   "file_extension": ".py",
323
   "mimetype": "text/x-python",
324
   "name": "python",
325
   "nbconvert_exporter": "python",
326
   "pygments_lexer": "ipython3",
327
   "version": "3.6.7"
328
  }
329
 },
330
 "nbformat": 4,
331
 "nbformat_minor": 2
332
}