1024 lines (1024 with data), 37.5 kB
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "DPR_extractive_QA.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"adb6799fa25e4bc7adf1212f63f04973": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_d79b5216eea84819adca4b226b63cdf2",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_a31c486541af49d29a33720649112c4d",
"IPY_MODEL_daf91846ad164d55a741bf9c7894d7de",
"IPY_MODEL_6614bcfbf2e64961b8d4a36a8fa0942f"
]
}
},
"d79b5216eea84819adca4b226b63cdf2": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"a31c486541af49d29a33720649112c4d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_5295e24c499249ee814331ddd6d0f984",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": "Create embeddings: 100%",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_8eb53342f51c4307b5c2f98de2baf3e0"
}
},
"daf91846ad164d55a741bf9c7894d7de": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_1e85d3a19e6941c08a43fb531cf407dc",
"_dom_classes": [],
"description": "",
"_model_name": "FloatProgressModel",
"bar_style": "",
"max": 7344,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 7344,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_982358c88a9a44b48ae915b4e3b9b46b"
}
},
"6614bcfbf2e64961b8d4a36a8fa0942f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_1c518d3f76424c609b6f085ae15de0c8",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 7328/7344 [03:48<00:00, 32.26 Docs/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_b2aaec2787ff4e17bddb01515d5d7068"
}
},
"5295e24c499249ee814331ddd6d0f984": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"8eb53342f51c4307b5c2f98de2baf3e0": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"1e85d3a19e6941c08a43fb531cf407dc": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"982358c88a9a44b48ae915b4e3b9b46b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"1c518d3f76424c609b6f085ae15de0c8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"b2aaec2787ff4e17bddb01515d5d7068": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "Y0R8IwSCWYVk"
},
"source": [
"!pip install -q -U transformers[sentencepiece] rouge git+https://github.com/deepset-ai/haystack.git grpcio-tools==1.34.1 spacy"
],
"execution_count": 11,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "1EHR4i9FiCAa"
},
"source": [
"import spacy\n",
"import nltk\n",
"import json\n",
"from tqdm import tqdm\n",
"import pandas as pd \n",
"from rouge import Rouge\n",
"from pprint import pprint\n",
"from typing import List\n",
"from haystack import Document\n",
"from haystack.reader import TransformersReader\n",
"from haystack.pipeline import ExtractiveQAPipeline \n",
"from haystack.retriever.dense import DensePassageRetriever \n",
"from haystack.document_store.faiss import FAISSDocumentStore\n",
"from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction"
],
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gyHgtCYAXEZb",
"outputId": "824bdab5-4c5b-4064-8efd-b0bfc2798827"
},
"source": [
"!spacy download en_core_web_md \n",
"!spacy link en_core_web_md en"
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting en-core-web-md==3.1.0\n",
" Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.1.0/en_core_web_md-3.1.0-py3-none-any.whl (45.4 MB)\n",
"\u001b[K |████████████████████████████████| 45.4 MB 17 kB/s \n",
"\u001b[?25hRequirement already satisfied: spacy<3.2.0,>=3.1.0 in /usr/local/lib/python3.7/dist-packages (from en-core-web-md==3.1.0) (3.1.3)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (21.0)\n",
"Requirement already satisfied: typer<0.5.0,>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (0.4.0)\n",
"Requirement already satisfied: pydantic!=1.8,!=1.8.1,<1.9.0,>=1.7.4 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (1.8.2)\n",
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.0.5)\n",
"Requirement already satisfied: blis<0.8.0,>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (0.4.1)\n",
"Requirement already satisfied: wasabi<1.1.0,>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (0.8.2)\n",
"Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (1.19.5)\n",
"Requirement already satisfied: pathy>=0.3.5 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (0.6.0)\n",
"Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (4.62.2)\n",
"Requirement already satisfied: thinc<8.1.0,>=8.0.9 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (8.0.10)\n",
"Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.8 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (3.0.8)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.11.3)\n",
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (3.0.5)\n",
"Requirement already satisfied: srsly<3.0.0,>=2.4.1 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.4.1)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (57.4.0)\n",
"Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.0.6)\n",
"Requirement already satisfied: typing-extensions<4.0.0.0,>=3.7.4 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (3.7.4.3)\n",
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.23.0)\n",
"Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (1.0.5)\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from catalogue<2.1.0,>=2.0.6->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (3.5.0)\n",
"Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.4.7)\n",
"Requirement already satisfied: smart-open<6.0.0,>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from pathy>=0.3.5->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (5.2.1)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (1.25.11)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2021.5.30)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.10)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (3.0.4)\n",
"Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.7/dist-packages (from typer<0.5.0,>=0.3.0->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (7.1.2)\n",
"Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from jinja2->spacy<3.2.0,>=3.1.0->en-core-web-md==3.1.0) (2.0.1)\n",
"\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
"You can now load the package via spacy.load('en_core_web_md')\n",
"\u001b[31mDeprecationWarning: The command link is deprecated.\u001b[0m\n",
"\u001b[38;5;3m⚠ As of spaCy v3.0, model symlinks are not supported anymore. You can\n",
"load trained pipeline packages using their full names or from a directory\n",
"path.\u001b[0m\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "K7ipay4bguxX",
"outputId": "a54705c6-fea4-4533-a4c9-638e46d49c89"
},
"source": [
"nltk.download('punkt')"
],
"execution_count": 22,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[nltk_data] Downloading package punkt to /root/nltk_data...\n",
"[nltk_data] Unzipping tokenizers/punkt.zip.\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {},
"execution_count": 22
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fhLYC0ivc0A_",
"outputId": "a48b2a6c-f1f9-4ab3-e14f-f777c295901c"
},
"source": [
"import spacy\n",
"# nlp = spacy.load('en_core_web_md')\n",
"nlp = English() \n",
"nlp.add_pipe(\"sentencizer\")"
],
"execution_count": 61,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<spacy.pipeline.sentencizer.Sentencizer at 0x7f4ea4ea7640>"
]
},
"metadata": {},
"execution_count": 61
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZkXrmWA1ViH-",
"outputId": "bf834313-8fff-49bf-9cce-cab06f0f7d1d"
},
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "S8XOUzQpwLw9"
},
"source": [
"with open('drive/MyDrive/qa_test.json', \"r\") as f:\n",
" qa = json.loads(f.read())['data']\n",
"\n",
"df = pd.read_csv('drive/MyDrive/ex-QA.csv', index_col=0)\n",
"df = df.replace(r'\\n',' ', regex=True) "
],
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "fbsJedAQ1mBF"
},
"source": [
"titles = list(df[\"title\"].values)\n",
"texts = list(df[\"text\"].values)\n",
"documents: List[Document] = []\n",
" \n",
"for title, text in zip(titles, texts):\n",
" documents.append(\n",
" Document(\n",
" text=text,\n",
" meta={\n",
" \"name\": title or \"\"\n",
" }\n",
" )\n",
" )"
],
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 108,
"referenced_widgets": [
"adb6799fa25e4bc7adf1212f63f04973",
"d79b5216eea84819adca4b226b63cdf2",
"a31c486541af49d29a33720649112c4d",
"daf91846ad164d55a741bf9c7894d7de",
"6614bcfbf2e64961b8d4a36a8fa0942f",
"5295e24c499249ee814331ddd6d0f984",
"8eb53342f51c4307b5c2f98de2baf3e0",
"1e85d3a19e6941c08a43fb531cf407dc",
"982358c88a9a44b48ae915b4e3b9b46b",
"1c518d3f76424c609b6f085ae15de0c8",
"b2aaec2787ff4e17bddb01515d5d7068"
]
},
"id": "pDemzc4-kPww",
"outputId": "85a2e13c-4c01-4238-ce7b-576247e989a6"
},
"source": [
"document_store = FAISSDocumentStore(\n",
" faiss_index_factory_str=\"Flat\",\n",
" return_embedding=True\n",
")\n",
"\n",
"retriever = DensePassageRetriever(\n",
" document_store=document_store,\n",
" query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n",
" passage_embedding_model=\"facebook/dpr-ctx_encoder-single-nq-base\",\n",
" use_gpu=True,\n",
" embed_title=True,\n",
")\n",
"\n",
"# retriever = DensePassageRetriever(\n",
"# document_store=document_store,\n",
"# query_embedding_model=\"drive/MyDrive/bert-large-finetuned\",\n",
"# passage_embedding_model=\"drive/MyDrive/bert-large-finetuned\",\n",
"# use_gpu=True,\n",
"# embed_title=True,\n",
"# )\n",
"\n",
"document_store.delete_documents()\n",
"document_store.write_documents(documents)\n",
"document_store.update_embeddings(\n",
" retriever=retriever\n",
")"
],
"execution_count": 85,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"09/24/2021 21:45:40 - WARNING - farm.utils - Failed to log params: Changing param values is not allowed. Param with key='lm1_name' was already logged with value='drive/MyDrive/squeeze-bert-finetuned' for run ID='57f858c4db2047c3a6013ad6593d2c9b'. Attempted logging new value 'facebook/dpr-question_encoder-single-nq-base'.\n",
"09/24/2021 21:46:02 - INFO - haystack.document_store.faiss - Updating embeddings for 7330 docs...\n",
"Updating Embedding: 0%| | 0/7330 [00:00<?, ? docs/s]"
]
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "adb6799fa25e4bc7adf1212f63f04973",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Create embeddings: 0%| | 0/7344 [00:00<?, ? Docs/s]"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Documents Processed: 10000 docs [03:58, 41.92 docs/s]\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "TifruTY50Lsn"
},
"source": [
"# reader = TransformersReader(model_name_or_path=\"ahotrod/albert_xxlargev1_squad2_512\", use_gpu=0)\n",
"# reader = TransformersReader(model_name_or_path=\"bert-large-uncased-whole-word-masking-finetuned-squad\", use_gpu=0)\n",
"\n",
"reader = TransformersReader(model_name_or_path=\"ktrapeznikov/albert-xlarge-v2-squad-v2\", \n",
" context_window_size=70,\n",
" max_seq_len=256,\n",
" doc_stride=128,\n",
" use_gpu=0)\n",
"\n",
"# reader = TransformersReader(model_name_or_path=\"drive/MyDrive/bert_basefi_qafi\", \n",
"# context_window_size=70,\n",
"# max_seq_len=256,\n",
"# doc_stride=128,\n",
"# use_gpu=0)\n",
"\n",
"pipe = ExtractiveQAPipeline(reader, retriever)"
],
"execution_count": 165,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "aPhP2TZdA_5l"
},
"source": [
"# Answers Bleu and Rouge"
]
},
{
"cell_type": "code",
"metadata": {
"id": "nuH1yj2f5rer",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "88e86ed4-15f2-4f6e-af14-1ffed31c3500"
},
"source": [
"bleu_scores = []\n",
"rouge1_scores = []\n",
"rouge2_scores = []\n",
"rougel_scores = []\n",
"context_detection = []\n",
"\n",
"rouge = Rouge()\n",
"smoothie = SmoothingFunction().method4\n",
"\n",
"for data in tqdm(qa):\n",
" true_context = data['context']\n",
" true_context = true_context.replace('\\n', ' ')\n",
"\n",
" for q_a in data['qas']:\n",
" question = q_a['question']\n",
" reference = \" \".join(q_a['answers'])\n",
" preds = pipe.run(\n",
" query=question,\n",
" params={\"Retriever\": {\"top_k\": 3}, \"Reader\": {\"top_k\": 2}}\n",
" )\n",
" \n",
" candidate_sent_list = []\n",
" pred_context_list = [pred.to_dict()['text'] for pred in preds['documents']]\n",
" pred_context = ' '.join(pred_context_list)\n",
" pred_context = pred_context.replace('\\n', ' ')\n",
" doc = nlp(pred_context)\n",
" pred_context_sents = list(doc.sents)\n",
"\n",
" for pred_co in pred_context_list:\n",
" pred_co = \"\".join(pred_co.rstrip().lstrip())\n",
" if pred_co in true_context:\n",
" context_detection.append(1)\n",
" else:\n",
" context_detection.append(0)\n",
"\n",
" for pred in preds['answers']:\n",
" pred_answer = pred['answer']\n",
"\n",
" if pred_answer is not None:\n",
" pred_answer = pred_answer.replace('\\n', ' ')\n",
" doc = nlp(pred_answer)\n",
" pred_answer_sents = list(doc.sents)\n",
"\n",
" for pred_context_sent in pred_context_sents:\n",
" for pred_answer_sent in pred_answer_sents:\n",
" pred_answer_sent = \"\".join(pred_answer_sent.text.rstrip().lstrip())\n",
"\n",
" if pred_answer_sent in pred_context_sent.text:\n",
" candidate_sent_list.append(pred_context_sent.text)\n",
"\n",
"\n",
" candidate_sent_set = set(candidate_sent_list)\n",
" candidate = \" \".join(candidate_sent_set)\n",
" token_reference = nltk.word_tokenize(reference)\n",
" token_candidate = nltk.word_tokenize(candidate)\n",
"\n",
" bleu_score = sentence_bleu(token_reference, \n",
" token_candidate, \n",
" smoothing_function=smoothie, \n",
" weights=(1, 0, 0, 0))\n",
" rouge_score = rouge.get_scores(candidate, reference)\n",
"\n",
" bleu_scores.append(bleu_score)\n",
" rouge1_scores.append(rouge_score[0]['rouge-1']['f'])\n",
" rouge2_scores.append(rouge_score[0]['rouge-2']['f'])\n",
" rougel_scores.append(rouge_score[0]['rouge-l']['f'])"
],
"execution_count": 166,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 38/38 [10:04<00:00, 15.90s/it]\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fbNvejJDz_Pf",
"outputId": "a94abda7-6172-4bd8-efcf-806f12647fda"
},
"source": [
"context_detection.count(1) / len(context_detection)"
],
"execution_count": 167,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.046413502109704644"
]
},
"metadata": {},
"execution_count": 167
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "TJoBQVepyXVo",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "91927d4e-aad4-4f83-c0f8-393b213d0274"
},
"source": [
"print(\"bleu -->\", sum(bleu_scores)/len(bleu_scores))\n",
"print(\"rouge1 -->\", sum(rouge1_scores)/len(rouge1_scores))\n",
"print(\"rouge2 -->\", sum(rouge2_scores)/len(rouge2_scores))\n",
"print(\"rougel -->\", sum(rougel_scores)/len(rougel_scores))"
],
"execution_count": 168,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"bleu --> 0.0700302475143219\n",
"rouge1 --> 0.2025044353996794\n",
"rouge2 --> 0.11069682602623314\n",
"rougel --> 0.1903201590307057\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aa7Wjif33yoC"
},
"source": [
"# facebook/dpr-question_encoder-single-nq-base + Fine Tuned Bert on (Squad + Our Dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SJqLRURR5Nwu"
},
"source": [
"# 3, 2\n",
"\n",
"```\n",
"bleu --> 0.06931287859597637\n",
"rouge1 --> 0.19821744629020724\n",
"rouge2 --> 0.10658866102635696\n",
"rougel --> 0.1868117643779736\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-3mFCty4EXQ6"
},
"source": [
"# 3, 5\n",
"\n",
"```\n",
"bleu --> 0.03326668172125407\n",
"rouge1 --> 0.20946994043485553\n",
"rouge2 --> 0.10167689243447016\n",
"rougel --> 0.19906621627604376\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "isTlazUFE1V8"
},
"source": [
"# 3, 7\n",
"\n",
"```\n",
"bleu --> 0.02560506763859031\n",
"rouge1 --> 0.19673322385299405\n",
"rouge2 --> 0.08888356388695941\n",
"rougel --> 0.18626961745270976\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "g1wJpZ8r4GlW"
},
"source": [
"# 5, 2\n",
"\n",
"```\n",
"bleu --> 0.06345328647077997\n",
"rouge1 --> 0.20406716478695625\n",
"rouge2 --> 0.1060285107744637\n",
"rougel --> 0.19283290551462437\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vz93n0QzGKeF"
},
"source": [
"# 5, 3\n",
"\n",
"```\n",
"bleu --> 0.04911886384092217\n",
"rouge1 --> 0.21052577428485436\n",
"rouge2 --> 0.10274851606324212\n",
"rougel --> 0.19822657706745186\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bQlxsR8cFfDL"
},
"source": [
"# 5, 5\n",
"\n",
"```\n",
"bleu --> 0.03170644661963682\n",
"rouge1 --> 0.21191987555043731\n",
"rouge2 --> 0.1037352111344695\n",
"rougel --> 0.20209905658726562\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LYliZ6kU4_IP"
},
"source": [
"#10, 2\n",
"\n",
"```\n",
"bleu --> 0.057428091668584556\n",
"rouge1 --> 0.20646246870472995\n",
"rouge2 --> 0.10519838762813453\n",
"rougel --> 0.1950256050028359\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wkQtE2rArKnM"
},
"source": [
"# 10, 5\n",
"\n",
"```\n",
"bleu --> 0.028692153647818627\n",
"rouge1 --> 0.21279947143169525\n",
"rouge2 --> 0.1004407817858144\n",
"rougel --> 0.20142881253757677\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Nvfq3joB6AgT"
},
"source": [
"# ktrapeznikov/albert-xlarge-v2-squad-v2 + Fine Tuned Bert on (Squad + Our Dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cvTGLrvqARNL"
},
"source": [
"# 3, 2\n",
"\n",
"```\n",
"bleu --> 0.0700302475143219\n",
"rouge1 --> 0.2025044353996794\n",
"rouge2 --> 0.11069682602623314\n",
"rougel --> 0.1903201590307057\n",
"```"
]
},
{
"cell_type": "code",
"metadata": {
"id": "NsMM-t20CsPz"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}