a b/development/qa-server/DPR_extractive_QA.ipynb
1
{
2
  "nbformat": 4,
3
  "nbformat_minor": 0,
4
  "metadata": {
5
    "colab": {
6
      "name": "DPR_extractive_QA.ipynb",
7
      "provenance": [],
8
      "collapsed_sections": []
9
    },
10
    "kernelspec": {
11
      "display_name": "Python 3",
12
      "name": "python3"
13
    },
14
    "language_info": {
15
      "name": "python"
16
    },
17
    "widgets": {
18
      "application/vnd.jupyter.widget-state+json": {
19
        "adb6799fa25e4bc7adf1212f63f04973": {
20
          "model_module": "@jupyter-widgets/controls",
21
          "model_name": "HBoxModel",
22
          "model_module_version": "1.5.0",
23
          "state": {
24
            "_view_name": "HBoxView",
25
            "_dom_classes": [],
26
            "_model_name": "HBoxModel",
27
            "_view_module": "@jupyter-widgets/controls",
28
            "_model_module_version": "1.5.0",
29
            "_view_count": null,
30
            "_view_module_version": "1.5.0",
31
            "box_style": "",
32
            "layout": "IPY_MODEL_d79b5216eea84819adca4b226b63cdf2",
33
            "_model_module": "@jupyter-widgets/controls",
34
            "children": [
35
              "IPY_MODEL_a31c486541af49d29a33720649112c4d",
36
              "IPY_MODEL_daf91846ad164d55a741bf9c7894d7de",
37
              "IPY_MODEL_6614bcfbf2e64961b8d4a36a8fa0942f"
38
            ]
39
          }
40
        },
41
        "d79b5216eea84819adca4b226b63cdf2": {
42
          "model_module": "@jupyter-widgets/base",
43
          "model_name": "LayoutModel",
44
          "model_module_version": "1.2.0",
45
          "state": {
46
            "_view_name": "LayoutView",
47
            "grid_template_rows": null,
48
            "right": null,
49
            "justify_content": null,
50
            "_view_module": "@jupyter-widgets/base",
51
            "overflow": null,
52
            "_model_module_version": "1.2.0",
53
            "_view_count": null,
54
            "flex_flow": null,
55
            "width": null,
56
            "min_width": null,
57
            "border": null,
58
            "align_items": null,
59
            "bottom": null,
60
            "_model_module": "@jupyter-widgets/base",
61
            "top": null,
62
            "grid_column": null,
63
            "overflow_y": null,
64
            "overflow_x": null,
65
            "grid_auto_flow": null,
66
            "grid_area": null,
67
            "grid_template_columns": null,
68
            "flex": null,
69
            "_model_name": "LayoutModel",
70
            "justify_items": null,
71
            "grid_row": null,
72
            "max_height": null,
73
            "align_content": null,
74
            "visibility": null,
75
            "align_self": null,
76
            "height": null,
77
            "min_height": null,
78
            "padding": null,
79
            "grid_auto_rows": null,
80
            "grid_gap": null,
81
            "max_width": null,
82
            "order": null,
83
            "_view_module_version": "1.2.0",
84
            "grid_template_areas": null,
85
            "object_position": null,
86
            "object_fit": null,
87
            "grid_auto_columns": null,
88
            "margin": null,
89
            "display": null,
90
            "left": null
91
          }
92
        },
93
        "a31c486541af49d29a33720649112c4d": {
94
          "model_module": "@jupyter-widgets/controls",
95
          "model_name": "HTMLModel",
96
          "model_module_version": "1.5.0",
97
          "state": {
98
            "_view_name": "HTMLView",
99
            "style": "IPY_MODEL_5295e24c499249ee814331ddd6d0f984",
100
            "_dom_classes": [],
101
            "description": "",
102
            "_model_name": "HTMLModel",
103
            "placeholder": "​",
104
            "_view_module": "@jupyter-widgets/controls",
105
            "_model_module_version": "1.5.0",
106
            "value": "Create embeddings: 100%",
107
            "_view_count": null,
108
            "_view_module_version": "1.5.0",
109
            "description_tooltip": null,
110
            "_model_module": "@jupyter-widgets/controls",
111
            "layout": "IPY_MODEL_8eb53342f51c4307b5c2f98de2baf3e0"
112
          }
113
        },
114
        "daf91846ad164d55a741bf9c7894d7de": {
115
          "model_module": "@jupyter-widgets/controls",
116
          "model_name": "FloatProgressModel",
117
          "model_module_version": "1.5.0",
118
          "state": {
119
            "_view_name": "ProgressView",
120
            "style": "IPY_MODEL_1e85d3a19e6941c08a43fb531cf407dc",
121
            "_dom_classes": [],
122
            "description": "",
123
            "_model_name": "FloatProgressModel",
124
            "bar_style": "",
125
            "max": 7344,
126
            "_view_module": "@jupyter-widgets/controls",
127
            "_model_module_version": "1.5.0",
128
            "value": 7344,
129
            "_view_count": null,
130
            "_view_module_version": "1.5.0",
131
            "orientation": "horizontal",
132
            "min": 0,
133
            "description_tooltip": null,
134
            "_model_module": "@jupyter-widgets/controls",
135
            "layout": "IPY_MODEL_982358c88a9a44b48ae915b4e3b9b46b"
136
          }
137
        },
138
        "6614bcfbf2e64961b8d4a36a8fa0942f": {
139
          "model_module": "@jupyter-widgets/controls",
140
          "model_name": "HTMLModel",
141
          "model_module_version": "1.5.0",
142
          "state": {
143
            "_view_name": "HTMLView",
144
            "style": "IPY_MODEL_1c518d3f76424c609b6f085ae15de0c8",
145
            "_dom_classes": [],
146
            "description": "",
147
            "_model_name": "HTMLModel",
148
            "placeholder": "​",
149
            "_view_module": "@jupyter-widgets/controls",
150
            "_model_module_version": "1.5.0",
151
            "value": " 7328/7344 [03:48<00:00, 32.26 Docs/s]",
152
            "_view_count": null,
153
            "_view_module_version": "1.5.0",
154
            "description_tooltip": null,
155
            "_model_module": "@jupyter-widgets/controls",
156
            "layout": "IPY_MODEL_b2aaec2787ff4e17bddb01515d5d7068"
157
          }
158
        },
159
        "5295e24c499249ee814331ddd6d0f984": {
160
          "model_module": "@jupyter-widgets/controls",
161
          "model_name": "DescriptionStyleModel",
162
          "model_module_version": "1.5.0",
163
          "state": {
164
            "_view_name": "StyleView",
165
            "_model_name": "DescriptionStyleModel",
166
            "description_width": "",
167
            "_view_module": "@jupyter-widgets/base",
168
            "_model_module_version": "1.5.0",
169
            "_view_count": null,
170
            "_view_module_version": "1.2.0",
171
            "_model_module": "@jupyter-widgets/controls"
172
          }
173
        },
174
        "8eb53342f51c4307b5c2f98de2baf3e0": {
175
          "model_module": "@jupyter-widgets/base",
176
          "model_name": "LayoutModel",
177
          "model_module_version": "1.2.0",
178
          "state": {
179
            "_view_name": "LayoutView",
180
            "grid_template_rows": null,
181
            "right": null,
182
            "justify_content": null,
183
            "_view_module": "@jupyter-widgets/base",
184
            "overflow": null,
185
            "_model_module_version": "1.2.0",
186
            "_view_count": null,
187
            "flex_flow": null,
188
            "width": null,
189
            "min_width": null,
190
            "border": null,
191
            "align_items": null,
192
            "bottom": null,
193
            "_model_module": "@jupyter-widgets/base",
194
            "top": null,
195
            "grid_column": null,
196
            "overflow_y": null,
197
            "overflow_x": null,
198
            "grid_auto_flow": null,
199
            "grid_area": null,
200
            "grid_template_columns": null,
201
            "flex": null,
202
            "_model_name": "LayoutModel",
203
            "justify_items": null,
204
            "grid_row": null,
205
            "max_height": null,
206
            "align_content": null,
207
            "visibility": null,
208
            "align_self": null,
209
            "height": null,
210
            "min_height": null,
211
            "padding": null,
212
            "grid_auto_rows": null,
213
            "grid_gap": null,
214
            "max_width": null,
215
            "order": null,
216
            "_view_module_version": "1.2.0",
217
            "grid_template_areas": null,
218
            "object_position": null,
219
            "object_fit": null,
220
            "grid_auto_columns": null,
221
            "margin": null,
222
            "display": null,
223
            "left": null
224
          }
225
        },
226
        "1e85d3a19e6941c08a43fb531cf407dc": {
227
          "model_module": "@jupyter-widgets/controls",
228
          "model_name": "ProgressStyleModel",
229
          "model_module_version": "1.5.0",
230
          "state": {
231
            "_view_name": "StyleView",
232
            "_model_name": "ProgressStyleModel",
233
            "description_width": "",
234
            "_view_module": "@jupyter-widgets/base",
235
            "_model_module_version": "1.5.0",
236
            "_view_count": null,
237
            "_view_module_version": "1.2.0",
238
            "bar_color": null,
239
            "_model_module": "@jupyter-widgets/controls"
240
          }
241
        },
242
        "982358c88a9a44b48ae915b4e3b9b46b": {
243
          "model_module": "@jupyter-widgets/base",
244
          "model_name": "LayoutModel",
245
          "model_module_version": "1.2.0",
246
          "state": {
247
            "_view_name": "LayoutView",
248
            "grid_template_rows": null,
249
            "right": null,
250
            "justify_content": null,
251
            "_view_module": "@jupyter-widgets/base",
252
            "overflow": null,
253
            "_model_module_version": "1.2.0",
254
            "_view_count": null,
255
            "flex_flow": null,
256
            "width": null,
257
            "min_width": null,
258
            "border": null,
259
            "align_items": null,
260
            "bottom": null,
261
            "_model_module": "@jupyter-widgets/base",
262
            "top": null,
263
            "grid_column": null,
264
            "overflow_y": null,
265
            "overflow_x": null,
266
            "grid_auto_flow": null,
267
            "grid_area": null,
268
            "grid_template_columns": null,
269
            "flex": null,
270
            "_model_name": "LayoutModel",
271
            "justify_items": null,
272
            "grid_row": null,
273
            "max_height": null,
274
            "align_content": null,
275
            "visibility": null,
276
            "align_self": null,
277
            "height": null,
278
            "min_height": null,
279
            "padding": null,
280
            "grid_auto_rows": null,
281
            "grid_gap": null,
282
            "max_width": null,
283
            "order": null,
284
            "_view_module_version": "1.2.0",
285
            "grid_template_areas": null,
286
            "object_position": null,
287
            "object_fit": null,
288
            "grid_auto_columns": null,
289
            "margin": null,
290
            "display": null,
291
            "left": null
292
          }
293
        },
294
        "1c518d3f76424c609b6f085ae15de0c8": {
295
          "model_module": "@jupyter-widgets/controls",
296
          "model_name": "DescriptionStyleModel",
297
          "model_module_version": "1.5.0",
298
          "state": {
299
            "_view_name": "StyleView",
300
            "_model_name": "DescriptionStyleModel",
301
            "description_width": "",
302
            "_view_module": "@jupyter-widgets/base",
303
            "_model_module_version": "1.5.0",
304
            "_view_count": null,
305
            "_view_module_version": "1.2.0",
306
            "_model_module": "@jupyter-widgets/controls"
307
          }
308
        },
309
        "b2aaec2787ff4e17bddb01515d5d7068": {
310
          "model_module": "@jupyter-widgets/base",
311
          "model_name": "LayoutModel",
312
          "model_module_version": "1.2.0",
313
          "state": {
314
            "_view_name": "LayoutView",
315
            "grid_template_rows": null,
316
            "right": null,
317
            "justify_content": null,
318
            "_view_module": "@jupyter-widgets/base",
319
            "overflow": null,
320
            "_model_module_version": "1.2.0",
321
            "_view_count": null,
322
            "flex_flow": null,
323
            "width": null,
324
            "min_width": null,
325
            "border": null,
326
            "align_items": null,
327
            "bottom": null,
328
            "_model_module": "@jupyter-widgets/base",
329
            "top": null,
330
            "grid_column": null,
331
            "overflow_y": null,
332
            "overflow_x": null,
333
            "grid_auto_flow": null,
334
            "grid_area": null,
335
            "grid_template_columns": null,
336
            "flex": null,
337
            "_model_name": "LayoutModel",
338
            "justify_items": null,
339
            "grid_row": null,
340
            "max_height": null,
341
            "align_content": null,
342
            "visibility": null,
343
            "align_self": null,
344
            "height": null,
345
            "min_height": null,
346
            "padding": null,
347
            "grid_auto_rows": null,
348
            "grid_gap": null,
349
            "max_width": null,
350
            "order": null,
351
            "_view_module_version": "1.2.0",
352
            "grid_template_areas": null,
353
            "object_position": null,
354
            "object_fit": null,
355
            "grid_auto_columns": null,
356
            "margin": null,
357
            "display": null,
358
            "left": null
359
          }
360
        }
361
      }
362
    },
363
    "accelerator": "GPU"
364
  },
365
  "cells": [
366
    {
367
      "cell_type": "code",
368
      "metadata": {
369
        "id": "Y0R8IwSCWYVk"
370
      },
371
      "source": [
372
        "%pip install -q -U \"transformers[sentencepiece]\" rouge git+https://github.com/deepset-ai/haystack.git grpcio-tools==1.34.1 spacy"
373
      ],
374
      "execution_count": 11,
375
      "outputs": []
376
    },
377
    {
378
      "cell_type": "code",
379
      "metadata": {
380
        "id": "1EHR4i9FiCAa"
381
      },
382
      "source": [
383
        "import spacy\n",
384
        "import nltk\n",
385
        "import json\n",
386
        "from tqdm import tqdm\n",
387
        "import pandas as pd \n",
388
        "from rouge import Rouge\n",
389
        "from pprint import pprint\n",
390
        "from typing import List\n",
391
        "from haystack import Document\n",
392
        "from haystack.reader import TransformersReader\n",
393
        "from haystack.pipeline import ExtractiveQAPipeline \n",
394
        "from haystack.retriever.dense import DensePassageRetriever \n",
395
        "from haystack.document_store.faiss import FAISSDocumentStore\n",
396
        "from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction"
397
      ],
398
      "execution_count": 12,
399
      "outputs": []
400
    },
401
    {
402
      "cell_type": "code",
403
      "metadata": {
404
        "colab": {
405
          "base_uri": "https://localhost:8080/"
406
        },
407
        "id": "gyHgtCYAXEZb",
408
        "outputId": "824bdab5-4c5b-4064-8efd-b0bfc2798827"
409
      },
410
      "source": [
411
        "# spaCy v3 no longer supports `spacy link`; load the model by its full package name instead.\n",
412
        "!spacy download en_core_web_md"
413
      ],
414
      "execution_count": 13,
415
      "outputs": [
416
        {
417
          "output_type": "stream",
418
          "name": "stdout",
419
          "text": [
420
            "Collecting en-core-web-md==3.1.0\n",
421
            "  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.1.0/en_core_web_md-3.1.0-py3-none-any.whl (45.4 MB)\n",
422
            "\u001b[K     |████████████████████████████████| 45.4 MB 17 kB/s \n",
423
            "\u001b[?25hRequirement already satisfied: spacy<3.2.0,>=3.1.0 in /usr/local/lib/python3.7/dist-packages (from en-core-web-md==3.1.0) (3.1.3)\n",
424
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (21.0)\n",
425
            "Requirement already satisfied: typer<0.5.0,>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (0.4.0)\n",
426
            "Requirement already satisfied: pydantic!=1.8,!=1.8.1,<1.9.0,>=1.7.4 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (1.8.2)\n",
427
            "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.0.5)\n",
428
            "Requirement already satisfied: blis<0.8.0,>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (0.4.1)\n",
429
            "Requirement already satisfied: wasabi<1.1.0,>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (0.8.2)\n",
430
            "Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (1.19.5)\n",
431
            "Requirement already satisfied: pathy>=0.3.5 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (0.6.0)\n",
432
            "Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (4.62.2)\n",
433
            "Requirement already satisfied: thinc<8.1.0,>=8.0.9 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (8.0.10)\n",
434
            "Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.8 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (3.0.8)\n",
435
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.11.3)\n",
436
            "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (3.0.5)\n",
437
            "Requirement already satisfied: srsly<3.0.0,>=2.4.1 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.4.1)\n",
438
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (57.4.0)\n",
439
            "Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.0.6)\n",
440
            "Requirement already satisfied: typing-extensions<4.0.0.0,>=3.7.4 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (3.7.4.3)\n",
441
            "Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.23.0)\n",
442
            "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (1.0.5)\n",
443
            "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from catalogue<2.1.0,>=2.0.6->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (3.5.0)\n",
444
            "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.4.7)\n",
445
            "Requirement already satisfied: smart-open<6.0.0,>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from pathy>=0.3.5->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (5.2.1)\n",
446
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (1.25.11)\n",
447
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2021.5.30)\n",
448
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.10)\n",
449
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (3.0.4)\n",
450
            "Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.7/dist-packages (from typer<0.5.0,>=0.3.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (7.1.2)\n",
451
            "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from jinja2->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.0.1)\n",
452
            "\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
453
            "You can now load the package via spacy.load('en_core_web_md')\n",
454
            "\u001b[31mDeprecationWarning: The command link is deprecated.\u001b[0m\n",
455
            "\u001b[38;5;3m⚠ As of spaCy v3.0, model symlinks are not supported anymore. You can\n",
456
            "load trained pipeline packages using their full names or from a directory\n",
457
            "path.\u001b[0m\n"
458
          ]
459
        }
460
      ]
461
    },
462
    {
463
      "cell_type": "code",
464
      "metadata": {
465
        "colab": {
466
          "base_uri": "https://localhost:8080/"
467
        },
468
        "id": "K7ipay4bguxX",
469
        "outputId": "a54705c6-fea4-4533-a4c9-638e46d49c89"
470
      },
471
      "source": [
472
        "nltk.download(info_or_id=\"punkt\")"
473
      ],
474
      "execution_count": 22,
475
      "outputs": [
476
        {
477
          "output_type": "stream",
478
          "name": "stdout",
479
          "text": [
480
            "[nltk_data] Downloading package punkt to /root/nltk_data...\n",
481
            "[nltk_data]   Unzipping tokenizers/punkt.zip.\n"
482
          ]
483
        },
484
        {
485
          "output_type": "execute_result",
486
          "data": {
487
            "text/plain": [
488
              "True"
489
            ]
490
          },
491
          "metadata": {},
492
          "execution_count": 22
493
        }
494
      ]
495
    },
496
    {
497
      "cell_type": "code",
498
      "metadata": {
499
        "colab": {
500
          "base_uri": "https://localhost:8080/"
501
        },
502
        "id": "fhLYC0ivc0A_",
503
        "outputId": "a48b2a6c-f1f9-4ab3-e14f-f777c295901c"
504
      },
505
      "source": [
506
        "from spacy.lang.en import English\n",
507
        "# Blank English pipeline + rule-based sentencizer (swap in spacy.load('en_core_web_md') for a statistical parser).\n",
508
        "nlp = English()\n",
509
        "nlp.add_pipe(\"sentencizer\")"
510
      ],
511
      "execution_count": 61,
512
      "outputs": [
513
        {
514
          "output_type": "execute_result",
515
          "data": {
516
            "text/plain": [
517
              "<spacy.pipeline.sentencizer.Sentencizer at 0x7f4ea4ea7640>"
518
            ]
519
          },
520
          "metadata": {},
521
          "execution_count": 61
522
        }
523
      ]
524
    },
525
    {
526
      "cell_type": "code",
527
      "metadata": {
528
        "colab": {
529
          "base_uri": "https://localhost:8080/"
530
        },
531
        "id": "ZkXrmWA1ViH-",
532
        "outputId": "bf834313-8fff-49bf-9cce-cab06f0f7d1d"
533
      },
534
      "source": [
535
        "from google.colab import drive\n",
536
        "drive.mount(mountpoint='/content/drive')"
537
      ],
538
      "execution_count": 15,
539
      "outputs": [
540
        {
541
          "output_type": "stream",
542
          "name": "stdout",
543
          "text": [
544
            "Mounted at /content/drive\n"
545
          ]
546
        }
547
      ]
548
    },
549
    {
550
      "cell_type": "code",
551
      "metadata": {
552
        "id": "S8XOUzQpwLw9"
553
      },
554
      "source": [
555
        "with open('drive/MyDrive/qa_test.json', \"r\", encoding=\"utf-8\") as f:\n",
556
        "    qa = json.load(f)['data']\n",
557
        "\n",
558
        "df = pd.read_csv('drive/MyDrive/ex-QA.csv', index_col=0)\n",
559
        "df = df.replace(r'\\n',' ', regex=True)"
560
      ],
561
      "execution_count": 16,
562
      "outputs": []
563
    },
564
    {
565
      "cell_type": "code",
566
      "metadata": {
567
        "id": "fbsJedAQ1mBF"
568
      },
569
      "source": [
570
        "# Missing CSV cells load as NaN, which is truthy, so `title or \"\"` alone would not\n",
571
        "# normalise them; fill them explicitly since the retriever below uses embed_title=True.\n",
572
        "titles = list(df[\"title\"].fillna(\"\").values)\n",
573
        "texts  = list(df[\"text\"].values)\n",
574
        "documents: List[Document] = [\n",
575
        "    Document(text=text, meta={\"name\": title or \"\"})\n",
576
        "    for title, text in zip(titles, texts)\n",
577
        "]"
583
      ],
584
      "execution_count": 17,
585
      "outputs": []
586
    },
587
    {
588
      "cell_type": "code",
589
      "metadata": {
590
        "colab": {
591
          "base_uri": "https://localhost:8080/",
592
          "height": 108,
593
          "referenced_widgets": [
594
            "adb6799fa25e4bc7adf1212f63f04973",
595
            "d79b5216eea84819adca4b226b63cdf2",
596
            "a31c486541af49d29a33720649112c4d",
597
            "daf91846ad164d55a741bf9c7894d7de",
598
            "6614bcfbf2e64961b8d4a36a8fa0942f",
599
            "5295e24c499249ee814331ddd6d0f984",
600
            "8eb53342f51c4307b5c2f98de2baf3e0",
601
            "1e85d3a19e6941c08a43fb531cf407dc",
602
            "982358c88a9a44b48ae915b4e3b9b46b",
603
            "1c518d3f76424c609b6f085ae15de0c8",
604
            "b2aaec2787ff4e17bddb01515d5d7068"
605
          ]
606
        },
607
        "id": "pDemzc4-kPww",
608
        "outputId": "85a2e13c-4c01-4238-ce7b-576247e989a6"
609
      },
610
      "source": [
611
        "document_store = FAISSDocumentStore(\n",
612
        "    faiss_index_factory_str=\"Flat\",\n",
613
        "    return_embedding=True\n",
614
        ")\n",
615
        "\n",
616
        "retriever = DensePassageRetriever(\n",
617
        "    document_store=document_store,\n",
618
        "    query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n",
619
        "    passage_embedding_model=\"facebook/dpr-ctx_encoder-single-nq-base\",\n",
620
        "    use_gpu=True,\n",
621
        "    embed_title=True,\n",
622
        ")\n",
623
        "\n",
624
        "# Clear the index first so re-running this cell does not index the documents twice.\n",
625
        "document_store.delete_documents()\n",
626
        "document_store.write_documents(documents)\n",
627
        "document_store.update_embeddings(retriever=retriever)"
637
      ],
638
      "execution_count": 85,
639
      "outputs": [
640
        {
641
          "output_type": "stream",
642
          "name": "stderr",
643
          "text": [
644
            "09/24/2021 21:45:40 - WARNING - farm.utils -   Failed to log params: Changing param values is not allowed. Param with key='lm1_name' was already logged with value='drive/MyDrive/squeeze-bert-finetuned' for run ID='57f858c4db2047c3a6013ad6593d2c9b'. Attempted logging new value 'facebook/dpr-question_encoder-single-nq-base'.\n",
645
            "09/24/2021 21:46:02 - INFO - haystack.document_store.faiss -   Updating embeddings for 7330 docs...\n",
646
            "Updating Embedding:   0%|          | 0/7330 [00:00<?, ? docs/s]"
647
          ]
648
        },
649
        {
650
          "output_type": "display_data",
651
          "data": {
652
            "application/vnd.jupyter.widget-view+json": {
653
              "model_id": "adb6799fa25e4bc7adf1212f63f04973",
654
              "version_minor": 0,
655
              "version_major": 2
656
            },
657
            "text/plain": [
658
              "Create embeddings:   0%|          | 0/7344 [00:00<?, ? Docs/s]"
659
            ]
660
          },
661
          "metadata": {}
662
        },
663
        {
664
          "output_type": "stream",
665
          "name": "stderr",
666
          "text": [
667
            "Documents Processed: 10000 docs [03:58, 41.92 docs/s]\n"
668
          ]
669
        }
670
      ]
671
    },
672
    {
673
      "cell_type": "code",
674
      "metadata": {
675
        "id": "TifruTY50Lsn"
676
      },
677
      "source": [
678
        "# use_gpu is a device index for TransformersReader here (0 = first GPU, -1 = CPU) -- confirm for the installed Haystack version.\n",
679
        "reader = TransformersReader(model_name_or_path=\"ktrapeznikov/albert-xlarge-v2-squad-v2\",\n",
680
        "                            context_window_size=70,\n",
681
        "                            max_seq_len=256,\n",
682
        "                            doc_stride=128,\n",
683
        "                            use_gpu=0)\n",
684
        "\n",
685
        "pipe = ExtractiveQAPipeline(reader, retriever)"
694
      ],
695
      "execution_count": 165,
696
      "outputs": []
697
    },
698
    {
699
      "cell_type": "markdown",
700
      "metadata": {
701
        "id": "aPhP2TZdA_5l"
702
      },
703
      "source": [
704
        "# Answer quality: BLEU and ROUGE scores"
705
      ]
706
    },
707
    {
708
      "cell_type": "code",
709
      "metadata": {
710
        "id": "nuH1yj2f5rer",
711
        "colab": {
712
          "base_uri": "https://localhost:8080/"
713
        },
714
        "outputId": "88e86ed4-15f2-4f6e-af14-1ffed31c3500"
715
      },
716
      "source": [
717
        "bleu_scores = []\n",
718
        "rouge1_scores = []\n",
719
        "rouge2_scores = []\n",
720
        "rougel_scores = []\n",
721
        "context_detection = []  # 1/0 per retrieved document: is its text a substring of the gold context?\n",
722
        "# Shared scorers: ROUGE F-scores and unigram-only BLEU (weights=(1, 0, 0, 0)) with method4 smoothing.\n",
723
        "rouge = Rouge()\n",
724
        "smoothie = SmoothingFunction().method4\n",
725
        "EMPTY_ROUGE = [{m: {'f': 0.0} for m in ('rouge-1', 'rouge-2', 'rouge-l')}]  # zero scores for empty texts\n",
726
        "for data in tqdm(qa):\n",
727
        "    true_context = data['context']\n",
728
        "    true_context = true_context.replace('\\n', ' ')\n",
729
        "\n",
730
        "    for q_a in data['qas']:\n",
731
        "        question = q_a['question']\n",
732
        "        reference = \" \".join(q_a['answers'])\n",
733
        "        preds = pipe.run(\n",
734
        "            query=question,\n",
735
        "            params={\"Retriever\": {\"top_k\": 3}, \"Reader\": {\"top_k\": 2}}\n",
736
        "            )\n",
737
        "        # Candidate = retrieved-context sentences that contain a predicted answer sentence.\n",
738
        "        candidate_sent_list = []\n",
739
        "        pred_context_list = [pred.to_dict()['text'] for pred in preds['documents']]\n",
740
        "        pred_context = ' '.join(pred_context_list)\n",
741
        "        pred_context = pred_context.replace('\\n', ' ')\n",
742
        "        doc = nlp(pred_context)\n",
743
        "        pred_context_sents = list(doc.sents)\n",
744
        "        # Normalise newlines the same way as true_context; otherwise multi-line documents never match.\n",
745
        "        for pred_co in pred_context_list:\n",
746
        "            pred_co = pred_co.replace('\\n', ' ').strip()\n",
747
        "            if pred_co in true_context:\n",
748
        "                context_detection.append(1)\n",
749
        "            else:\n",
750
        "                context_detection.append(0)\n",
751
        "\n",
752
        "        for pred in preds['answers']:\n",
753
        "            pred_answer = pred['answer']\n",
754
        "\n",
755
        "            if pred_answer is not None:\n",
756
        "                pred_answer = pred_answer.replace('\\n', ' ')\n",
757
        "                doc = nlp(pred_answer)\n",
758
        "                pred_answer_sents = list(doc.sents)\n",
759
        "\n",
760
        "                for pred_context_sent in pred_context_sents:\n",
761
        "                    for pred_answer_sent in pred_answer_sents:\n",
762
        "                        pred_answer_sent = pred_answer_sent.text.strip()\n",
763
        "\n",
764
        "                        if pred_answer_sent in pred_context_sent.text:\n",
765
        "                            candidate_sent_list.append(pred_context_sent.text)\n",
766
        "\n",
767
        "        # Deduplicate keeping first-seen order; set order varies between runs and changes ROUGE-2/ROUGE-L.\n",
768
        "        unique_candidate_sents = dict.fromkeys(candidate_sent_list)\n",
769
        "        candidate = \" \".join(unique_candidate_sents)\n",
770
        "        token_reference = nltk.word_tokenize(reference)\n",
771
        "        token_candidate = nltk.word_tokenize(candidate)\n",
772
        "        # sentence_bleu takes a list of references (each a token list); a bare token list is misread.\n",
773
        "        bleu_score = sentence_bleu([token_reference],\n",
774
        "                                   token_candidate,\n",
775
        "                                   smoothing_function=smoothie, weights=(1, 0, 0, 0))\n",
776
        "        # Rouge.get_scores raises on empty text; score such pairs as 0 instead of aborting the loop.\n",
777
        "        rouge_score = rouge.get_scores(candidate, reference) if candidate.strip() and reference.strip() else EMPTY_ROUGE\n",
778
        "\n",
779
        "        bleu_scores.append(bleu_score)\n",
780
        "        rouge1_scores.append(rouge_score[0]['rouge-1']['f'])\n",
781
        "        rouge2_scores.append(rouge_score[0]['rouge-2']['f'])\n",
782
        "        rougel_scores.append(rouge_score[0]['rouge-l']['f'])"
783
      ],
784
      "execution_count": 166,
785
      "outputs": [
786
        {
787
          "output_type": "stream",
788
          "name": "stderr",
789
          "text": [
790
            "100%|██████████| 38/38 [10:04<00:00, 15.90s/it]\n"
791
          ]
792
        }
793
      ]
794
    },
795
    {
796
      "cell_type": "code",
797
      "metadata": {
798
        "colab": {
799
          "base_uri": "https://localhost:8080/"
800
        },
801
        "id": "fbNvejJDz_Pf",
802
        "outputId": "a94abda7-6172-4bd8-efcf-806f12647fda"
803
      },
804
      "source": [
805
        "sum(1 for hit in context_detection if hit == 1) / len(context_detection)  # share of retrieved docs found in the gold context"
806
      ],
807
      "execution_count": 167,
808
      "outputs": [
809
        {
810
          "output_type": "execute_result",
811
          "data": {
812
            "text/plain": [
813
              "0.046413502109704644"
814
            ]
815
          },
816
          "metadata": {},
817
          "execution_count": 167
818
        }
819
      ]
820
    },
821
    {
822
      "cell_type": "code",
823
      "metadata": {
824
        "id": "TJoBQVepyXVo",
825
        "colab": {
826
          "base_uri": "https://localhost:8080/"
827
        },
828
        "outputId": "91927d4e-aad4-4f83-c0f8-393b213d0274"
829
      },
830
      "source": [
831
        "for metric_name, metric_scores in [(\"bleu\", bleu_scores), (\"rouge1\", rouge1_scores),\n",
832
        "                                   (\"rouge2\", rouge2_scores), (\"rougel\", rougel_scores)]:\n",
833
        "    # Mean of the per-question scores collected by the evaluation loop.\n",
834
        "    print(metric_name, \"-->\", sum(metric_scores) / len(metric_scores))"
835
      ],
836
      "execution_count": 168,
837
      "outputs": [
838
        {
839
          "output_type": "stream",
840
          "name": "stdout",
841
          "text": [
842
            "bleu --> 0.0700302475143219\n",
843
            "rouge1 --> 0.2025044353996794\n",
844
            "rouge2 --> 0.11069682602623314\n",
845
            "rougel --> 0.1903201590307057\n"
846
          ]
847
        }
848
      ]
849
    },
850
    {
851
      "cell_type": "markdown",
852
      "metadata": {
853
        "id": "aa7Wjif33yoC"
854
      },
855
      "source": [
856
        "# facebook/dpr-question_encoder-single-nq-base + BERT fine-tuned on SQuAD and our dataset"
857
      ]
858
    },
859
    {
860
      "cell_type": "markdown",
861
      "metadata": {
862
        "id": "SJqLRURR5Nwu"
863
      },
864
      "source": [
865
        "# Retriever top_k = 3, Reader top_k = 2\n",
866
        "\n",
867
        "```\n",
868
        "bleu --> 0.06931287859597637\n",
869
        "rouge1 --> 0.19821744629020724\n",
870
        "rouge2 --> 0.10658866102635696\n",
871
        "rougel --> 0.1868117643779736\n",
872
        "```"
873
      ]
874
    },
875
    {
876
      "cell_type": "markdown",
877
      "metadata": {
878
        "id": "-3mFCty4EXQ6"
879
      },
880
      "source": [
881
        "# Retriever top_k = 3, Reader top_k = 5\n",
882
        "\n",
883
        "```\n",
884
        "bleu --> 0.03326668172125407\n",
885
        "rouge1 --> 0.20946994043485553\n",
886
        "rouge2 --> 0.10167689243447016\n",
887
        "rougel --> 0.19906621627604376\n",
888
        "```"
889
      ]
890
    },
891
    {
892
      "cell_type": "markdown",
893
      "metadata": {
894
        "id": "isTlazUFE1V8"
895
      },
896
      "source": [
897
        "# Retriever top_k = 3, Reader top_k = 7\n",
898
        "\n",
899
        "```\n",
900
        "bleu --> 0.02560506763859031\n",
901
        "rouge1 --> 0.19673322385299405\n",
902
        "rouge2 --> 0.08888356388695941\n",
903
        "rougel --> 0.18626961745270976\n",
904
        "```"
905
      ]
906
    },
907
    {
908
      "cell_type": "markdown",
909
      "metadata": {
910
        "id": "g1wJpZ8r4GlW"
911
      },
912
      "source": [
913
        "# Retriever top_k = 5, Reader top_k = 2\n",
914
        "\n",
915
        "```\n",
916
        "bleu --> 0.06345328647077997\n",
917
        "rouge1 --> 0.20406716478695625\n",
918
        "rouge2 --> 0.1060285107744637\n",
919
        "rougel --> 0.19283290551462437\n",
920
        "```"
921
      ]
922
    },
923
    {
924
      "cell_type": "markdown",
925
      "metadata": {
926
        "id": "vz93n0QzGKeF"
927
      },
928
      "source": [
929
        "# Retriever top_k = 5, Reader top_k = 3\n",
930
        "\n",
931
        "```\n",
932
        "bleu --> 0.04911886384092217\n",
933
        "rouge1 --> 0.21052577428485436\n",
934
        "rouge2 --> 0.10274851606324212\n",
935
        "rougel --> 0.19822657706745186\n",
936
        "```"
937
      ]
938
    },
939
    {
940
      "cell_type": "markdown",
941
      "metadata": {
942
        "id": "bQlxsR8cFfDL"
943
      },
944
      "source": [
945
        "# Retriever top_k = 5, Reader top_k = 5\n",
946
        "\n",
947
        "```\n",
948
        "bleu --> 0.03170644661963682\n",
949
        "rouge1 --> 0.21191987555043731\n",
950
        "rouge2 --> 0.1037352111344695\n",
951
        "rougel --> 0.20209905658726562\n",
952
        "```"
953
      ]
954
    },
955
    {
956
      "cell_type": "markdown",
957
      "metadata": {
958
        "id": "LYliZ6kU4_IP"
959
      },
960
      "source": [
961
        "# Retriever top_k = 10, Reader top_k = 2\n",
962
        "\n",
963
        "```\n",
964
        "bleu --> 0.057428091668584556\n",
965
        "rouge1 --> 0.20646246870472995\n",
966
        "rouge2 --> 0.10519838762813453\n",
967
        "rougel --> 0.1950256050028359\n",
968
        "```"
969
      ]
970
    },
971
    {
972
      "cell_type": "markdown",
973
      "metadata": {
974
        "id": "wkQtE2rArKnM"
975
      },
976
      "source": [
977
        "# Retriever top_k = 10, Reader top_k = 5\n",
978
        "\n",
979
        "```\n",
980
        "bleu --> 0.028692153647818627\n",
981
        "rouge1 --> 0.21279947143169525\n",
982
        "rouge2 --> 0.1004407817858144\n",
983
        "rougel --> 0.20142881253757677\n",
984
        "```"
985
      ]
986
    },
987
    {
988
      "cell_type": "markdown",
989
      "metadata": {
990
        "id": "Nvfq3joB6AgT"
991
      },
992
      "source": [
993
        "# ktrapeznikov/albert-xlarge-v2-squad-v2 + BERT fine-tuned on SQuAD and our dataset"
994
      ]
995
    },
996
    {
997
      "cell_type": "markdown",
998
      "metadata": {
999
        "id": "cvTGLrvqARNL"
1000
      },
1001
      "source": [
1002
        "# Retriever top_k = 3, Reader top_k = 2\n",
1003
        "\n",
1004
        "```\n",
1005
        "bleu --> 0.0700302475143219\n",
1006
        "rouge1 --> 0.2025044353996794\n",
1007
        "rouge2 --> 0.11069682602623314\n",
1008
        "rougel --> 0.1903201590307057\n",
1009
        "```"
1010
      ]
1011
    },
1012
    {
1013
      "cell_type": "code",
1014
      "metadata": {
1015
        "id": "NsMM-t20CsPz"
1016
      },
1017
      "source": [
1018
        ""
1019
      ],
1020
      "execution_count": null,
1021
      "outputs": []
1022
    }
1023
  ]
1024
}