[507a54]: / development / qa-server / DPR_extractive_QA.ipynb

Download this file

1024 lines (1024 with data), 37.5 kB

{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "DPR_extractive_QA.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "adb6799fa25e4bc7adf1212f63f04973": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_d79b5216eea84819adca4b226b63cdf2",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_a31c486541af49d29a33720649112c4d",
              "IPY_MODEL_daf91846ad164d55a741bf9c7894d7de",
              "IPY_MODEL_6614bcfbf2e64961b8d4a36a8fa0942f"
            ]
          }
        },
        "d79b5216eea84819adca4b226b63cdf2": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "a31c486541af49d29a33720649112c4d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_5295e24c499249ee814331ddd6d0f984",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": "Create embeddings: 100%",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_8eb53342f51c4307b5c2f98de2baf3e0"
          }
        },
        "daf91846ad164d55a741bf9c7894d7de": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_1e85d3a19e6941c08a43fb531cf407dc",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "",
            "max": 7344,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 7344,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_982358c88a9a44b48ae915b4e3b9b46b"
          }
        },
        "6614bcfbf2e64961b8d4a36a8fa0942f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_1c518d3f76424c609b6f085ae15de0c8",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 7328/7344 [03:48<00:00, 32.26 Docs/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_b2aaec2787ff4e17bddb01515d5d7068"
          }
        },
        "5295e24c499249ee814331ddd6d0f984": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "8eb53342f51c4307b5c2f98de2baf3e0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "1e85d3a19e6941c08a43fb531cf407dc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "982358c88a9a44b48ae915b4e3b9b46b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "1c518d3f76424c609b6f085ae15de0c8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "b2aaec2787ff4e17bddb01515d5d7068": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y0R8IwSCWYVk"
      },
      "source": [
        "# Colab dependency bootstrap. Haystack is installed from the tip of its git\n",
        "# repository and most packages are unpinned, so a later re-run may resolve to\n",
        "# different versions than the ones that produced the saved outputs.\n",
        "!pip install -q -U transformers[sentencepiece] rouge git+https://github.com/deepset-ai/haystack.git grpcio-tools==1.34.1 spacy"
      ],
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1EHR4i9FiCAa"
      },
      "source": [
        "import spacy\n",
        "import nltk\n",
        "import json\n",
        "from tqdm import tqdm\n",
        "import pandas as pd \n",
        "from rouge import Rouge\n",
        "from pprint import pprint\n",
        "from typing import List\n",
        "from haystack import Document\n",
        "from haystack.reader import TransformersReader\n",
        "from haystack.pipeline import ExtractiveQAPipeline \n",
        "from haystack.retriever.dense import DensePassageRetriever \n",
        "from haystack.document_store.faiss import FAISSDocumentStore\n",
        "from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction"
      ],
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gyHgtCYAXEZb",
        "outputId": "824bdab5-4c5b-4064-8efd-b0bfc2798827"
      },
      "source": [
        "# Fetch the medium English pipeline package.\n",
        "# No `spacy link` alias is created: spaCy v3 dropped model symlinks (the command\n",
        "# only prints a deprecation warning), so the package is loaded by its full name,\n",
        "# e.g. spacy.load('en_core_web_md').\n",
        "!spacy download en_core_web_md"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting en-core-web-md==3.1.0\n",
            "  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.1.0/en_core_web_md-3.1.0-py3-none-any.whl (45.4 MB)\n",
            "\u001b[K     |████████████████████████████████| 45.4 MB 17 kB/s \n",
            "\u001b[?25hRequirement already satisfied: spacy<3.2.0,>=3.1.0 in /usr/local/lib/python3.7/dist-packages (from en-core-web-md==3.1.0) (3.1.3)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (21.0)\n",
            "Requirement already satisfied: typer<0.5.0,>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (0.4.0)\n",
            "Requirement already satisfied: pydantic!=1.8,!=1.8.1,<1.9.0,>=1.7.4 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (1.8.2)\n",
            "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.0.5)\n",
            "Requirement already satisfied: blis<0.8.0,>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (0.4.1)\n",
            "Requirement already satisfied: wasabi<1.1.0,>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (0.8.2)\n",
            "Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (1.19.5)\n",
            "Requirement already satisfied: pathy>=0.3.5 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (0.6.0)\n",
            "Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (4.62.2)\n",
            "Requirement already satisfied: thinc<8.1.0,>=8.0.9 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (8.0.10)\n",
            "Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.8 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (3.0.8)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.11.3)\n",
            "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (3.0.5)\n",
            "Requirement already satisfied: srsly<3.0.0,>=2.4.1 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.4.1)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (57.4.0)\n",
            "Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.0.6)\n",
            "Requirement already satisfied: typing-extensions<4.0.0.0,>=3.7.4 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (3.7.4.3)\n",
            "Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.23.0)\n",
            "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (1.0.5)\n",
            "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from catalogue<2.1.0,>=2.0.6->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (3.5.0)\n",
            "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.4.7)\n",
            "Requirement already satisfied: smart-open<6.0.0,>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from pathy>=0.3.5->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (5.2.1)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (1.25.11)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2021.5.30)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.10)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (3.0.4)\n",
            "Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.7/dist-packages (from typer<0.5.0,>=0.3.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (7.1.2)\n",
            "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from jinja2->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.0.1)\n",
            "\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
            "You can now load the package via spacy.load('en_core_web_md')\n",
            "\u001b[31mDeprecationWarning: The command link is deprecated.\u001b[0m\n",
            "\u001b[38;5;3m⚠ As of spaCy v3.0, model symlinks are not supported anymore. You can\n",
            "load trained pipeline packages using their full names or from a directory\n",
            "path.\u001b[0m\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "K7ipay4bguxX",
        "outputId": "a54705c6-fea4-4533-a4c9-638e46d49c89"
      },
      "source": [
        "# Punkt tokenizer data, needed by the nltk.word_tokenize calls in the scoring cell.\n",
        "nltk.download('punkt')"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[nltk_data] Downloading package punkt to /root/nltk_data...\n",
            "[nltk_data]   Unzipping tokenizers/punkt.zip.\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "True"
            ]
          },
          "metadata": {},
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fhLYC0ivc0A_",
        "outputId": "a48b2a6c-f1f9-4ab3-e14f-f777c295901c"
      },
      "source": [
        "import spacy\n",
        "from spacy.lang.en import English  # was missing: English() raised NameError on a fresh kernel\n",
        "\n",
        "# The evaluation only reads `doc.sents`, so a blank English pipeline with the\n",
        "# sentencizer component is used instead of the full downloaded model.\n",
        "# nlp = spacy.load('en_core_web_md')\n",
        "nlp = English()\n",
        "nlp.add_pipe(\"sentencizer\")"
      ],
      "execution_count": 61,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<spacy.pipeline.sentencizer.Sentencizer at 0x7f4ea4ea7640>"
            ]
          },
          "metadata": {},
          "execution_count": 61
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZkXrmWA1ViH-",
        "outputId": "bf834313-8fff-49bf-9cce-cab06f0f7d1d"
      },
      "source": [
        "# Colab only: mount Google Drive so the dataset files under drive/MyDrive\n",
        "# (read by the loading cell) are available.\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "S8XOUzQpwLw9"
      },
      "source": [
        "# Evaluation set: the 'data' entry of the QA json file.\n",
        "with open('drive/MyDrive/qa_test.json', \"r\") as qa_file:\n",
        "    qa = json.load(qa_file)['data']\n",
        "\n",
        "# Passage table (first CSV column is the index); newlines inside string cells\n",
        "# are flattened to single spaces.\n",
        "df = pd.read_csv('drive/MyDrive/ex-QA.csv', index_col=0).replace(r'\\n', ' ', regex=True)"
      ],
      "execution_count": 16,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fbsJedAQ1mBF"
      },
      "source": [
        "titles = list(df[\"title\"].values)\n",
        "texts = list(df[\"text\"].values)\n",
        "\n",
        "# One haystack Document per table row; the title is stored under meta[\"name\"],\n",
        "# with falsy titles replaced by an empty string.\n",
        "documents: List[Document] = [\n",
        "    Document(text=passage, meta={\"name\": heading or \"\"})\n",
        "    for heading, passage in zip(titles, texts)\n",
        "]"
      ],
      "execution_count": 17,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 108,
          "referenced_widgets": [
            "adb6799fa25e4bc7adf1212f63f04973",
            "d79b5216eea84819adca4b226b63cdf2",
            "a31c486541af49d29a33720649112c4d",
            "daf91846ad164d55a741bf9c7894d7de",
            "6614bcfbf2e64961b8d4a36a8fa0942f",
            "5295e24c499249ee814331ddd6d0f984",
            "8eb53342f51c4307b5c2f98de2baf3e0",
            "1e85d3a19e6941c08a43fb531cf407dc",
            "982358c88a9a44b48ae915b4e3b9b46b",
            "1c518d3f76424c609b6f085ae15de0c8",
            "b2aaec2787ff4e17bddb01515d5d7068"
          ]
        },
        "id": "pDemzc4-kPww",
        "outputId": "85a2e13c-4c01-4238-ce7b-576247e989a6"
      },
      "source": [
        "# FAISS-backed document store. \"Flat\" is passed through as the FAISS index-factory\n",
        "# string (presumably an exhaustive, non-approximate index -- confirm in FAISS docs).\n",
        "document_store = FAISSDocumentStore(\n",
        "    faiss_index_factory_str=\"Flat\",\n",
        "    return_embedding=True\n",
        ")\n",
        "\n",
        "# Dense Passage Retrieval with the public single-NQ question/context encoders.\n",
        "# NOTE(review): embed_title=True presumably feeds each document's title\n",
        "# (meta[\"name\"]) to the passage encoder together with its text -- confirm.\n",
        "retriever = DensePassageRetriever(\n",
        "    document_store=document_store,\n",
        "    query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n",
        "    passage_embedding_model=\"facebook/dpr-ctx_encoder-single-nq-base\",\n",
        "    use_gpu=True,\n",
        "    embed_title=True,\n",
        ")\n",
        "\n",
        "# Disabled alternative: the same retriever built on a checkpoint stored on Drive.\n",
        "# retriever = DensePassageRetriever(\n",
        "#     document_store=document_store,\n",
        "#     query_embedding_model=\"drive/MyDrive/bert-large-finetuned\",\n",
        "#     passage_embedding_model=\"drive/MyDrive/bert-large-finetuned\",\n",
        "#     use_gpu=True,\n",
        "#     embed_title=True,\n",
        "# )\n",
        "\n",
        "# Clear any previously stored documents, write the passages, then compute their\n",
        "# embeddings with the retriever defined above.\n",
        "document_store.delete_documents()\n",
        "document_store.write_documents(documents)\n",
        "document_store.update_embeddings(\n",
        "    retriever=retriever\n",
        ")"
      ],
      "execution_count": 85,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "09/24/2021 21:45:40 - WARNING - farm.utils -   Failed to log params: Changing param values is not allowed. Param with key='lm1_name' was already logged with value='drive/MyDrive/squeeze-bert-finetuned' for run ID='57f858c4db2047c3a6013ad6593d2c9b'. Attempted logging new value 'facebook/dpr-question_encoder-single-nq-base'.\n",
            "09/24/2021 21:46:02 - INFO - haystack.document_store.faiss -   Updating embeddings for 7330 docs...\n",
            "Updating Embedding:   0%|          | 0/7330 [00:00<?, ? docs/s]"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "adb6799fa25e4bc7adf1212f63f04973",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "Create embeddings:   0%|          | 0/7344 [00:00<?, ? Docs/s]"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Documents Processed: 10000 docs [03:58, 41.92 docs/s]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TifruTY50Lsn"
      },
      "source": [
        "# Disabled alternative reader checkpoints:\n",
        "# reader = TransformersReader(model_name_or_path=\"ahotrod/albert_xxlargev1_squad2_512\", use_gpu=0)\n",
        "# reader = TransformersReader(model_name_or_path=\"bert-large-uncased-whole-word-masking-finetuned-squad\", use_gpu=0)\n",
        "\n",
        "# Extractive reader used for the run below.\n",
        "# NOTE(review): use_gpu=0 looks like a device ordinal (first GPU) rather than a\n",
        "# boolean \"off\" -- confirm against the TransformersReader documentation.\n",
        "reader = TransformersReader(model_name_or_path=\"ktrapeznikov/albert-xlarge-v2-squad-v2\", \n",
        "                            context_window_size=70,\n",
        "                            max_seq_len=256,\n",
        "                            doc_stride=128,\n",
        "                            use_gpu=0)\n",
        "\n",
        "# Disabled alternative: reader checkpoint stored on Drive, same settings.\n",
        "# reader = TransformersReader(model_name_or_path=\"drive/MyDrive/bert_basefi_qafi\", \n",
        "#                             context_window_size=70,\n",
        "#                             max_seq_len=256,\n",
        "#                             doc_stride=128,\n",
        "#                             use_gpu=0)\n",
        "\n",
        "# Retriever -> reader pipeline queried by the evaluation loop.\n",
        "pipe = ExtractiveQAPipeline(reader, retriever)"
      ],
      "execution_count": 165,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aPhP2TZdA_5l"
      },
      "source": [
        "# Answers: BLEU and ROUGE"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nuH1yj2f5rer",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "88e86ed4-15f2-4f6e-af14-1ffed31c3500"
      },
      "source": [
        "# Score the pipeline on every question of the evaluation set.\n",
        "# For each question the predicted answer spans are expanded to the retrieved\n",
        "# sentences that contain them; that text is compared with the gold answers using\n",
        "# unigram BLEU and ROUGE-1/2/L F-scores. `context_detection` gets one 0/1 entry per\n",
        "# retrieved passage (1 = passage text occurs verbatim in the gold context).\n",
        "bleu_scores = []\n",
        "rouge1_scores = []\n",
        "rouge2_scores = []\n",
        "rougel_scores = []\n",
        "context_detection = []\n",
        "\n",
        "rouge = Rouge()\n",
        "smoothie = SmoothingFunction().method4\n",
        "\n",
        "for data in tqdm(qa):\n",
        "    true_context = data['context'].replace('\\n', ' ')\n",
        "\n",
        "    for q_a in data['qas']:\n",
        "        question = q_a['question']\n",
        "        reference = \" \".join(q_a['answers'])\n",
        "        preds = pipe.run(\n",
        "            query=question,\n",
        "            params={\"Retriever\": {\"top_k\": 3}, \"Reader\": {\"top_k\": 2}}\n",
        "        )\n",
        "\n",
        "        pred_context_list = [pred.to_dict()['text'] for pred in preds['documents']]\n",
        "        pred_context = ' '.join(pred_context_list).replace('\\n', ' ')\n",
        "        pred_context_sents = list(nlp(pred_context).sents)\n",
        "\n",
        "        # Retrieval hit/miss, one entry per retrieved passage.\n",
        "        for pred_co in pred_context_list:\n",
        "            context_detection.append(1 if pred_co.strip() in true_context else 0)\n",
        "\n",
        "        # Expand each predicted span to the retrieved sentence(s) containing it.\n",
        "        candidate_sent_list = []\n",
        "        for pred in preds['answers']:\n",
        "            pred_answer = pred['answer']\n",
        "            if pred_answer is None:\n",
        "                continue\n",
        "\n",
        "            pred_answer_sents = [\n",
        "                sent.text.strip()\n",
        "                for sent in nlp(pred_answer.replace('\\n', ' ')).sents\n",
        "            ]\n",
        "            for pred_context_sent in pred_context_sents:\n",
        "                for pred_answer_sent in pred_answer_sents:\n",
        "                    if pred_answer_sent in pred_context_sent.text:\n",
        "                        candidate_sent_list.append(pred_context_sent.text)\n",
        "\n",
        "        # De-duplicate in first-seen order. A plain set() has no stable order for\n",
        "        # strings across runs, which made ROUGE-2 / ROUGE-L irreproducible.\n",
        "        candidate = \" \".join(dict.fromkeys(candidate_sent_list))\n",
        "        token_reference = nltk.word_tokenize(reference)\n",
        "        token_candidate = nltk.word_tokenize(candidate)\n",
        "\n",
        "        # sentence_bleu expects a *list of references*, each one a token list.\n",
        "        # Passing the bare token list made NLTK treat every token string as its own\n",
        "        # reference (a sequence of characters), which deflated the score.\n",
        "        bleu_score = sentence_bleu([token_reference],\n",
        "                                   token_candidate,\n",
        "                                   smoothing_function=smoothie,\n",
        "                                   weights=(1, 0, 0, 0))\n",
        "\n",
        "        # The rouge package raises ValueError on an empty hypothesis or reference;\n",
        "        # score that case as zero overlap instead of aborting the whole run.\n",
        "        if candidate.strip() and reference.strip():\n",
        "            rouge_score = rouge.get_scores(candidate, reference)[0]\n",
        "            rouge1, rouge2, rougel = (\n",
        "                rouge_score[key]['f'] for key in ('rouge-1', 'rouge-2', 'rouge-l')\n",
        "            )\n",
        "        else:\n",
        "            rouge1 = rouge2 = rougel = 0.0\n",
        "\n",
        "        bleu_scores.append(bleu_score)\n",
        "        rouge1_scores.append(rouge1)\n",
        "        rouge2_scores.append(rouge2)\n",
        "        rougel_scores.append(rougel)"
      ],
      "execution_count": 166,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 38/38 [10:04<00:00, 15.90s/it]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fbNvejJDz_Pf",
        "outputId": "a94abda7-6172-4bd8-efcf-806f12647fda"
      },
      "source": [
        "# Share of entries in context_detection that are equal to 1.\n",
        "sum(1 for flag in context_detection if flag == 1) / len(context_detection)"
      ],
      "execution_count": 167,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.046413502109704644"
            ]
          },
          "metadata": {},
          "execution_count": 167
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TJoBQVepyXVo",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "91927d4e-aad4-4f83-c0f8-393b213d0274"
      },
      "source": [
        "# Print the arithmetic mean of every score list filled by the evaluation loop.\n",
        "metric_values_by_name = {\n",
        "    \"bleu\": bleu_scores,\n",
        "    \"rouge1\": rouge1_scores,\n",
        "    \"rouge2\": rouge2_scores,\n",
        "    \"rougel\": rougel_scores,\n",
        "}\n",
        "for metric_name, metric_values in metric_values_by_name.items():\n",
        "    print(f\"{metric_name} -->\", sum(metric_values) / len(metric_values))"
      ],
      "execution_count": 168,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "bleu --> 0.0700302475143219\n",
            "rouge1 --> 0.2025044353996794\n",
            "rouge2 --> 0.11069682602623314\n",
            "rougel --> 0.1903201590307057\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aa7Wjif33yoC"
      },
      "source": [
        "# facebook/dpr-question_encoder-single-nq-base + Fine-Tuned BERT on (SQuAD + Our Dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SJqLRURR5Nwu"
      },
      "source": [
        "# 3, 2\n",
        "\n",
        "```\n",
        "bleu --> 0.06931287859597637\n",
        "rouge1 --> 0.19821744629020724\n",
        "rouge2 --> 0.10658866102635696\n",
        "rougel --> 0.1868117643779736\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-3mFCty4EXQ6"
      },
      "source": [
        "# 3, 5\n",
        "\n",
        "```\n",
        "bleu --> 0.03326668172125407\n",
        "rouge1 --> 0.20946994043485553\n",
        "rouge2 --> 0.10167689243447016\n",
        "rougel --> 0.19906621627604376\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "isTlazUFE1V8"
      },
      "source": [
        "# 3, 7\n",
        "\n",
        "```\n",
        "bleu --> 0.02560506763859031\n",
        "rouge1 --> 0.19673322385299405\n",
        "rouge2 --> 0.08888356388695941\n",
        "rougel --> 0.18626961745270976\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g1wJpZ8r4GlW"
      },
      "source": [
        "# 5, 2\n",
        "\n",
        "```\n",
        "bleu --> 0.06345328647077997\n",
        "rouge1 --> 0.20406716478695625\n",
        "rouge2 --> 0.1060285107744637\n",
        "rougel --> 0.19283290551462437\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vz93n0QzGKeF"
      },
      "source": [
        "# 5, 3\n",
        "\n",
        "```\n",
        "bleu --> 0.04911886384092217\n",
        "rouge1 --> 0.21052577428485436\n",
        "rouge2 --> 0.10274851606324212\n",
        "rougel --> 0.19822657706745186\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bQlxsR8cFfDL"
      },
      "source": [
        "# 5, 5\n",
        "\n",
        "```\n",
        "bleu --> 0.03170644661963682\n",
        "rouge1 --> 0.21191987555043731\n",
        "rouge2 --> 0.1037352111344695\n",
        "rougel --> 0.20209905658726562\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LYliZ6kU4_IP"
      },
      "source": [
        "# 10, 2\n",
        "\n",
        "```\n",
        "bleu --> 0.057428091668584556\n",
        "rouge1 --> 0.20646246870472995\n",
        "rouge2 --> 0.10519838762813453\n",
        "rougel --> 0.1950256050028359\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wkQtE2rArKnM"
      },
      "source": [
        "# 10, 5\n",
        "\n",
        "```\n",
        "bleu --> 0.028692153647818627\n",
        "rouge1 --> 0.21279947143169525\n",
        "rouge2 --> 0.1004407817858144\n",
        "rougel --> 0.20142881253757677\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Nvfq3joB6AgT"
      },
      "source": [
        "# ktrapeznikov/albert-xlarge-v2-squad-v2 + Fine-Tuned BERT on (SQuAD + Our Dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cvTGLrvqARNL"
      },
      "source": [
        "# 3, 2\n",
        "\n",
        "```\n",
        "bleu --> 0.0700302475143219\n",
        "rouge1 --> 0.2025044353996794\n",
        "rouge2 --> 0.11069682602623314\n",
        "rougel --> 0.1903201590307057\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NsMM-t20CsPz"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}