nlp_pubmed_rct_classification.ipynb


{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "nlp-pubmed-rct-classification.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "mount_file_id": "1TGU5mtAosYsHxXnaw9_uAP91Pg9czPX1",
      "authorship_tag": "ABX9TyOt8iFW5V7RdJakaMJ73Mwq",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/hecshzye/nlp-medical-abstract-pubmed-rct/blob/main/nlp_pubmed_rct_classification.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Medical Abstract Classification using Natural Language Processing"
      ],
      "metadata": {
        "id": "Uq_ssycxR8Ar"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### The objective is to build a deep learning model which makes medical research paper abstract easier to read.\n",
        "  - Dataset used in this project is the `PubMed 200k RCT Dataset for Sequential Sentence Classification in Medical Abstract`: https://arxiv.org/abs/1710.06071\n",
        "  - The initial deep learning research paper was built with the PubMed 200k RCT.\n",
        "  - Dataset has about `200,000 labelled Randomized Control Trial abstracts`.\n",
        "  - The goal of the project was build NLP models with the dataset to classify sentences in sequential order."
      ],
      "metadata": {
        "id": "bOGykK3Ad-wp"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        " - As the RCT research papers with unstructured abstracts slows down researchers navigating the literature. \n",
        " - The unstructured abstracts are sometimes hard to read and understand especially when it can disrupt time management and deadlines.\n",
        " - This NLP model can classify the abstract sentences into its respective roles:\n",
        "    - Such as `Objective`, `Methods`, `Results` and `Conclusions`.\n",
        "    "
      ],
      "metadata": {
        "id": "RgYVNMflgo86"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### The PubMed 200k RCT Dataset - https://github.com/Franck-Dernoncourt/pubmed-rct"
      ],
      "metadata": {
        "id": "1Lmr_h4FipFw"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Similar projects using the dataset:\n",
        "   - Claim Extraction for Scientific Publications 2018: https://github.com/titipata/detecting-scientific-claim"
      ],
      "metadata": {
        "id": "N2d_bTBIi6of"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Abstract** \n",
        "\n",
        "PubMed 200k RCT is new dataset based on PubMed for sequential sentence classification. The dataset consists of approximately 200,000 abstracts of randomized controlled trials, totaling 2.3 million sentences. Each sentence of each abstract is labeled with their role in the abstract using one of the following classes: background, objective, method, result, or conclusion. The purpose of releasing this dataset is twofold. First, the majority of datasets for sequential short-text classification (i.e., classification of short texts that appear in sequences) are small: we hope that releasing a new large dataset will help develop more accurate algorithms for this task. Second, from an application perspective, researchers need better tools to efficiently skim through the literature. Automatically classifying each sentence in an abstract would help researchers read abstracts more efficiently, especially in fields where abstracts may be long, such as the medical field."
      ],
      "metadata": {
        "id": "WXb2Qlu8k2Fk"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Data Dictionary**\n",
        "\n",
        "- `PubMed 20k` is a subset of `PubMed 200k`. I.e., any abstract present in `PubMed 20k` is also present in `PubMed 200k`.\n",
        "- `PubMed_200k_RCT` is the same as `PubMed_200k_RCT_numbers_replaced_with_at_sign`, except that in the latter all numbers had been replaced by `@`. (same for `PubMed_20k_RCT` vs. `PubMed_20k_RCT_numbers_replaced_with_at_sign``).\n",
        "- Since Github file size limit is 100 MiB, we had to compress `PubMed_200k_RCT\\train.7z` and `PubMed_200k_RCT_numbers_replaced_with_at_sign\\train.zip`. \n",
        "- To uncompress `train.7z`, you may use `7-Zip` on Windows, `Keka` on Mac OS X, or `p7zip` on Linux."
      ],
      "metadata": {
        "id": "MvAjGdF0lACB"
      }
    },
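    {
      "cell_type": "markdown",
      "source": [
        "The `PubMed_20k` subset used below ships as plain `.txt` files, so no decompression is needed. If you work with the full `PubMed_200k_RCT` variant instead, its `train.7z` has to be extracted first. The next cell is a minimal sketch for a Colab/Linux session; it assumes `apt` is available and that the archive sits at the path produced by the `git clone` in the next section."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Sketch only: extracting the full 200k training set (not required for the 20k subset used below).\n",
        "# Assumes a Colab/Linux environment where apt can install p7zip-full, which provides the `7z` binary.\n",
        "!apt-get install -y -qq p7zip-full\n",
        "!7z x pubmed-rct/PubMed_200k_RCT/train.7z -opubmed-rct/PubMed_200k_RCT/"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },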
    {
      "cell_type": "markdown",
      "source": [
        "# Importing data and EDA"
      ],
      "metadata": {
        "id": "T1JweqbejVX4"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!git clone https://github.com/Franck-Dernoncourt/pubmed-rct.git\n",
        "!ls pubmed-rct"
      ],
      "metadata": {
        "id": "zM12MSlAnKCo",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bb4adf23-506b-4a92-8e86-010efcf4f61f"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'pubmed-rct'...\n",
            "remote: Enumerating objects: 33, done.\u001b[K\n",
            "remote: Counting objects: 100% (3/3), done.\u001b[K\n",
            "remote: Compressing objects: 100% (3/3), done.\u001b[K\n",
            "remote: Total 33 (delta 0), reused 0 (delta 0), pack-reused 30\u001b[K\n",
            "Unpacking objects: 100% (33/33), done.\n",
            "Checking out files: 100% (13/13), done.\n",
            "PubMed_200k_RCT\n",
            "PubMed_200k_RCT_numbers_replaced_with_at_sign\n",
            "PubMed_20k_RCT\n",
            "PubMed_20k_RCT_numbers_replaced_with_at_sign\n",
            "README.md\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Initial Data exploration and modelling with PubMed_20k dataset"
      ],
      "metadata": {
        "id": "FFpPwS_SqS48"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!ls pubmed-rct/PubMed_20k_RCT_numbers_replaced_with_at_sign"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "I3mLk-7tqk3K",
        "outputId": "9ceb3826-9977-4664-bf2d-de3927555bdf"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "dev.txt  test.txt  train.txt\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# imports\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import os\n",
        "import random\n",
        "import tensorflow as tf"
      ],
      "metadata": {
        "id": "TZiWPqoaqzBL"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# functions pre-written for workflow\n",
        "!wget https://raw.githubusercontent.com/hecshzye/nlp-medical-abstract-pubmed-rct/main/helper_functions.py"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wRZJgoCeFBsX",
        "outputId": "0619c26c-096f-410b-f9c6-aae7ec239b53"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--2022-01-17 02:17:42--  https://raw.githubusercontent.com/hecshzye/nlp-medical-abstract-pubmed-rct/main/helper_functions.py\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 6442 (6.3K) [text/plain]\n",
            "Saving to: ‘helper_functions.py’\n",
            "\n",
            "helper_functions.py 100%[===================>]   6.29K  --.-KB/s    in 0s      \n",
            "\n",
            "2022-01-17 02:17:42 (59.4 MB/s) - ‘helper_functions.py’ saved [6442/6442]\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from helper_functions import create_tensorboard_callback, calculate_results, plot_loss_curves"
      ],
      "metadata": {
        "id": "i6KrTZnaFKtu"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Function for reading the document\n",
        "def get_doc(filename):\n",
        "  with open(filename, \"r\") as f:\n",
        "    return f.readlines()"
      ],
      "metadata": {
        "id": "JMifgqkxsC5w"
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "data_dir = \"pubmed-rct/PubMed_20k_RCT_numbers_replaced_with_at_sign/\"\n",
        "filenames = [data_dir + filename for filename in os.listdir(data_dir)]\n",
        "filenames"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RGZi11CDrVxa",
        "outputId": "6a971c96-a1c9-43af-9936-98f0dfc75436"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['pubmed-rct/PubMed_20k_RCT_numbers_replaced_with_at_sign/dev.txt',\n",
              " 'pubmed-rct/PubMed_20k_RCT_numbers_replaced_with_at_sign/test.txt',\n",
              " 'pubmed-rct/PubMed_20k_RCT_numbers_replaced_with_at_sign/train.txt']"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Preprocessing \n",
        "train_lines = get_doc(data_dir+\"train.txt\")\n",
        "train_lines[:30]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3ZBDfw7gsX-8",
        "outputId": "ffc2d907-985f-4a5f-9659-f03d3126390d"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['###24293578\\n',\n",
              " 'OBJECTIVE\\tTo investigate the efficacy of @ weeks of daily low-dose oral prednisolone in improving pain , mobility , and systemic low-grade inflammation in the short term and whether the effect would be sustained at @ weeks in older adults with moderate to severe knee osteoarthritis ( OA ) .\\n',\n",
              " 'METHODS\\tA total of @ patients with primary knee OA were randomized @:@ ; @ received @ mg/day of prednisolone and @ received placebo for @ weeks .\\n',\n",
              " 'METHODS\\tOutcome measures included pain reduction and improvement in function scores and systemic inflammation markers .\\n',\n",
              " 'METHODS\\tPain was assessed using the visual analog pain scale ( @-@ mm ) .\\n',\n",
              " 'METHODS\\tSecondary outcome measures included the Western Ontario and McMaster Universities Osteoarthritis Index scores , patient global assessment ( PGA ) of the severity of knee OA , and @-min walk distance ( @MWD ) .\\n',\n",
              " 'METHODS\\tSerum levels of interleukin @ ( IL-@ ) , IL-@ , tumor necrosis factor ( TNF ) - , and high-sensitivity C-reactive protein ( hsCRP ) were measured .\\n',\n",
              " 'RESULTS\\tThere was a clinically relevant reduction in the intervention group compared to the placebo group for knee pain , physical function , PGA , and @MWD at @ weeks .\\n',\n",
              " 'RESULTS\\tThe mean difference between treatment arms ( @ % CI ) was @ ( @-@ @ ) , p < @ ; @ ( @-@ @ ) , p < @ ; @ ( @-@ @ ) , p < @ ; and @ ( @-@ @ ) , p < @ , respectively .\\n',\n",
              " 'RESULTS\\tFurther , there was a clinically relevant reduction in the serum levels of IL-@ , IL-@ , TNF - , and hsCRP at @ weeks in the intervention group when compared to the placebo group .\\n',\n",
              " 'RESULTS\\tThese differences remained significant at @ weeks .\\n',\n",
              " 'RESULTS\\tThe Outcome Measures in Rheumatology Clinical Trials-Osteoarthritis Research Society International responder rate was @ % in the intervention group and @ % in the placebo group ( p < @ ) .\\n',\n",
              " 'CONCLUSIONS\\tLow-dose oral prednisolone had both a short-term and a longer sustained effect resulting in less knee pain , better physical function , and attenuation of systemic inflammation in older patients with knee OA ( ClinicalTrials.gov identifier NCT@ ) .\\n',\n",
              " '\\n',\n",
              " '###24854809\\n',\n",
              " 'BACKGROUND\\tEmotional eating is associated with overeating and the development of obesity .\\n',\n",
              " 'BACKGROUND\\tYet , empirical evidence for individual ( trait ) differences in emotional eating and cognitive mechanisms that contribute to eating during sad mood remain equivocal .\\n',\n",
              " 'OBJECTIVE\\tThe aim of this study was to test if attention bias for food moderates the effect of self-reported emotional eating during sad mood ( vs neutral mood ) on actual food intake .\\n',\n",
              " 'OBJECTIVE\\tIt was expected that emotional eating is predictive of elevated attention for food and higher food intake after an experimentally induced sad mood and that attentional maintenance on food predicts food intake during a sad versus a neutral mood .\\n',\n",
              " 'METHODS\\tParticipants ( N = @ ) were randomly assigned to one of the two experimental mood induction conditions ( sad/neutral ) .\\n',\n",
              " 'METHODS\\tAttentional biases for high caloric foods were measured by eye tracking during a visual probe task with pictorial food and neutral stimuli .\\n',\n",
              " 'METHODS\\tSelf-reported emotional eating was assessed with the Dutch Eating Behavior Questionnaire ( DEBQ ) and ad libitum food intake was tested by a disguised food offer .\\n',\n",
              " 'RESULTS\\tHierarchical multivariate regression modeling showed that self-reported emotional eating did not account for changes in attention allocation for food or food intake in either condition .\\n',\n",
              " 'RESULTS\\tYet , attention maintenance on food cues was significantly related to increased intake specifically in the neutral condition , but not in the sad mood condition .\\n',\n",
              " 'CONCLUSIONS\\tThe current findings show that self-reported emotional eating ( based on the DEBQ ) might not validly predict who overeats when sad , at least not in a laboratory setting with healthy women .\\n',\n",
              " 'CONCLUSIONS\\tResults further suggest that attention maintenance on food relates to eating motivation when in a neutral affective state , and might therefore be a cognitive mechanism contributing to increased food intake in general , but maybe not during sad mood .\\n',\n",
              " '\\n',\n",
              " '###25165090\\n',\n",
              " 'BACKGROUND\\tAlthough working smoke alarms halve deaths in residential fires , many households do not keep alarms operational .\\n',\n",
              " 'BACKGROUND\\tWe tested whether theory-based education increases alarm operability .\\n']"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### **Data dictionary**\n",
        "\n",
        "`\\t` = tab seperator\n",
        "\n",
        "`\\n` = new line\n",
        "\n",
        "`###` = abstract ID\n",
        "\n",
        "`\"line_number\"` = line position\n",
        "\n",
        "`\"text\"` = text line\n",
        "\n",
        "`\"total_lines\"` = total number of lines in one abstract\n",
        "\n",
        "`\"target\"` = objective of the abstract\n",
        "\n"
      ],
      "metadata": {
        "id": "6ZOKbxQAsyNT"
      }
    },
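    {
      "cell_type": "markdown",
      "source": [
        "As a quick illustration of the layout above (not part of the original workflow), the next cell splits a single raw `train.txt` line on the tab separator to recover the `target` and `text` fields; `line_number` and `total_lines` come from a line's position inside its abstract and are derived by the parsing function that follows."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Illustrative sketch: one raw line (taken from the train.txt sample printed above)\n",
        "# follows the \"TARGET\\ttext\\n\" layout described in the data dictionary.\n",
        "sample_line = \"METHODS\\tPain was assessed using the visual analog pain scale ( @-@ mm ) .\\n\"\n",
        "target, text = sample_line.strip().split(\"\\t\")\n",
        "{\"target\": target, \"text\": text.lower()}"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },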
    {
      "cell_type": "code",
      "source": [
        "# Function for preprocessing the data\n",
        "\n",
        "def preprocessing_text_with_line_number(filename):\n",
        "  input_lines = get_doc(filename)\n",
        "  abstract_lines = \"\"\n",
        "  abstract_samples = []\n",
        "\n",
        "  for line in input_lines:\n",
        "    if line.startswith(\"###\"):\n",
        "      abstract_id = line\n",
        "      abstract_lines = \"\"\n",
        "    elif line.isspace():\n",
        "      abstract_line_split = abstract_lines.splitlines()\n",
        "       \n",
        "      for abstract_line_number, abstract_line in enumerate(abstract_line_split):\n",
        "        line_data = {}\n",
        "        target_text_split = abstract_line.split(\"\\t\")\n",
        "        line_data[\"target\"] = target_text_split[0]\n",
        "        line_data[\"text\"] = target_text_split[1].lower()\n",
        "        line_data[\"line_number\"] = abstract_line_number\n",
        "        line_data[\"total_lines\"] = len(abstract_line_split) - 1\n",
        "        abstract_samples.append(line_data)\n",
        "\n",
        "    else:\n",
        "      abstract_lines += line \n",
        "  return abstract_samples    "
      ],
      "metadata": {
        "id": "AEl-XZPQttN7"
      },
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Extracting data using the function\n",
        "train_samples = preprocessing_text_with_line_number(data_dir + \"train.txt\")\n",
        "val_samples = preprocessing_text_with_line_number(data_dir + \"dev.txt\")\n",
        "test_samples = preprocessing_text_with_line_number(data_dir + \"test.txt\")\n",
        "len(train_samples), len(test_samples), len(val_samples)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SCD9DYIU17GH",
        "outputId": "252fdd28-96fa-4cee-d94f-9065a1aef0b3"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(180040, 30135, 30212)"
            ]
          },
          "metadata": {},
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_samples[:10]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Hacxnrvq2r7n",
        "outputId": "f756d753-5bd9-401f-db37-f4359c104962"
      },
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[{'line_number': 0,\n",
              "  'target': 'OBJECTIVE',\n",
              "  'text': 'to investigate the efficacy of @ weeks of daily low-dose oral prednisolone in improving pain , mobility , and systemic low-grade inflammation in the short term and whether the effect would be sustained at @ weeks in older adults with moderate to severe knee osteoarthritis ( oa ) .',\n",
              "  'total_lines': 11},\n",
              " {'line_number': 1,\n",
              "  'target': 'METHODS',\n",
              "  'text': 'a total of @ patients with primary knee oa were randomized @:@ ; @ received @ mg/day of prednisolone and @ received placebo for @ weeks .',\n",
              "  'total_lines': 11},\n",
              " {'line_number': 2,\n",
              "  'target': 'METHODS',\n",
              "  'text': 'outcome measures included pain reduction and improvement in function scores and systemic inflammation markers .',\n",
              "  'total_lines': 11},\n",
              " {'line_number': 3,\n",
              "  'target': 'METHODS',\n",
              "  'text': 'pain was assessed using the visual analog pain scale ( @-@ mm ) .',\n",
              "  'total_lines': 11},\n",
              " {'line_number': 4,\n",
              "  'target': 'METHODS',\n",
              "  'text': 'secondary outcome measures included the western ontario and mcmaster universities osteoarthritis index scores , patient global assessment ( pga ) of the severity of knee oa , and @-min walk distance ( @mwd ) .',\n",
              "  'total_lines': 11},\n",
              " {'line_number': 5,\n",
              "  'target': 'METHODS',\n",
              "  'text': 'serum levels of interleukin @ ( il-@ ) , il-@ , tumor necrosis factor ( tnf ) - , and high-sensitivity c-reactive protein ( hscrp ) were measured .',\n",
              "  'total_lines': 11},\n",
              " {'line_number': 6,\n",
              "  'target': 'RESULTS',\n",
              "  'text': 'there was a clinically relevant reduction in the intervention group compared to the placebo group for knee pain , physical function , pga , and @mwd at @ weeks .',\n",
              "  'total_lines': 11},\n",
              " {'line_number': 7,\n",
              "  'target': 'RESULTS',\n",
              "  'text': 'the mean difference between treatment arms ( @ % ci ) was @ ( @-@ @ ) , p < @ ; @ ( @-@ @ ) , p < @ ; @ ( @-@ @ ) , p < @ ; and @ ( @-@ @ ) , p < @ , respectively .',\n",
              "  'total_lines': 11},\n",
              " {'line_number': 8,\n",
              "  'target': 'RESULTS',\n",
              "  'text': 'further , there was a clinically relevant reduction in the serum levels of il-@ , il-@ , tnf - , and hscrp at @ weeks in the intervention group when compared to the placebo group .',\n",
              "  'total_lines': 11},\n",
              " {'line_number': 9,\n",
              "  'target': 'RESULTS',\n",
              "  'text': 'these differences remained significant at @ weeks .',\n",
              "  'total_lines': 11}]"
            ]
          },
          "metadata": {},
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Creating a DataFrame\n",
        "train_df = pd.DataFrame(train_samples)\n",
        "val_df = pd.DataFrame(val_samples)\n",
        "test_df = pd.DataFrame(test_samples)\n",
        "train_df.head()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "id": "gEQRm8jZ40lD",
        "outputId": "dd84beae-2b55-49ad-a5a8-9692705471a6"
      },
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "\n",
              "  <div id=\"df-ac58f3f7-1d14-4554-b843-b9194f6498e0\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>target</th>\n",
              "      <th>text</th>\n",
              "      <th>line_number</th>\n",
              "      <th>total_lines</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>OBJECTIVE</td>\n",
              "      <td>to investigate the efficacy of @ weeks of dail...</td>\n",
              "      <td>0</td>\n",
              "      <td>11</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>METHODS</td>\n",
              "      <td>a total of @ patients with primary knee oa wer...</td>\n",
              "      <td>1</td>\n",
              "      <td>11</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>METHODS</td>\n",
              "      <td>outcome measures included pain reduction and i...</td>\n",
              "      <td>2</td>\n",
              "      <td>11</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>METHODS</td>\n",
              "      <td>pain was assessed using the visual analog pain...</td>\n",
              "      <td>3</td>\n",
              "      <td>11</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>METHODS</td>\n",
              "      <td>secondary outcome measures included the wester...</td>\n",
              "      <td>4</td>\n",
              "      <td>11</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-ac58f3f7-1d14-4554-b843-b9194f6498e0')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-ac58f3f7-1d14-4554-b843-b9194f6498e0 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-ac58f3f7-1d14-4554-b843-b9194f6498e0');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ],
            "text/plain": [
              "      target  ... total_lines\n",
              "0  OBJECTIVE  ...          11\n",
              "1    METHODS  ...          11\n",
              "2    METHODS  ...          11\n",
              "3    METHODS  ...          11\n",
              "4    METHODS  ...          11\n",
              "\n",
              "[5 rows x 4 columns]"
            ]
          },
          "metadata": {},
          "execution_count": 12
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_df.target.value_counts()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nAydvGtt7lHx",
        "outputId": "2075c73b-3536-4b90-ac71-2c5082d373b9"
      },
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "METHODS        59353\n",
              "RESULTS        57953\n",
              "CONCLUSIONS    27168\n",
              "BACKGROUND     21727\n",
              "OBJECTIVE      13839\n",
              "Name: target, dtype: int64"
            ]
          },
          "metadata": {},
          "execution_count": 13
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_df.total_lines.plot.hist();"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 267
        },
        "id": "GHq-mW6b7tet",
        "outputId": "eadcaaee-4d6c-4c31-9d65-cac20587bb6d"
      },
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAD6CAYAAABgZXp6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAXpUlEQVR4nO3df7BfdX3n8efLRCpSkVDSLJNgg21Gl7r+gCvg1HatjCHg1tBdl4WtS5ZhiDNgV8f9QXQ6i8Uyk+5spdJatqlkTVwV8SfZEppGxHb7Bz+CIAjo5IqwJAJJDRDRFhZ97x/fz5Wv4ebyzbn53i/35vmY+c49530+55zPZ74TXpxzPt/vN1WFJEldvGjUHZAkzV6GiCSpM0NEktSZISJJ6swQkSR1ZohIkjobWogkeVWSO/tee5O8L8nRSbYm2d7+Lmjtk+TKJONJ7kpyYt+xVrX225Os6quflOTuts+VSTKs8UiSnisz8TmRJPOAncApwMXAnqpam2QNsKCqLklyJvC7wJmt3Uer6pQkRwPbgDGggNuBk6rqsSS3Av8BuAXYDFxZVTdM1Zdjjjmmli5dOpRxStJcdPvtt/99VS2cbNv8GerDacB3qurBJCuBt7T6BuBrwCXASmBj9VLt5iRHJTm2td1aVXsAkmwFViT5GnBkVd3c6huBs4ApQ2Tp0qVs27bt4I5OkuawJA/ub9tMPRM5B/hMW15UVQ+35UeARW15MfBQ3z47Wm2q+o5J6pKkGTL0EElyGPAO4HP7bmtXHUO/n5ZkdZJtSbbt3r172KeTpEPGTFyJnAF8vaoebeuPtttUtL+7Wn0ncFzffktabar6kknqz1FV66pqrKrGFi6c9LaeJKmDmQiRc3n2VhbAJmBihtUq4Lq++nltltapwBPtttcWYHmSBW0m13JgS9u2N8mpbVbWeX3HkiTNgKE+WE9yBPA24N195bXAtUkuAB4Ezm71zfRmZo0DPwLOB6iqPUk+DNzW2l028ZAduAj4BHA4vQfqUz5UlyQdXDMyxfeFZGxsrJydJUmDS3J7VY1Nts1PrEuSOjNEJEmdGSKSpM5m6hPrmqWWrrl+JOd9YO3bR3JeSQfGKxFJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSps6GGSJKjknw+ybeS3JfkTUmOTrI1yfb2d0FrmyRXJhlPcleSE/uOs6q1355kVV/9pCR3t32uTJJhjkeS9LOGfSXyUeCvqurVwOuA+4A1wI1VtQy4sa0DnAEsa6/VwFUASY4GLgVOAU4GLp0Intbmwr79Vgx5PJKkPkMLkSQvB34DuBqgqp6uqseBlcCG1mwDcFZbXglsrJ6bgaOSHAucDmytqj1V9RiwFVjRth1ZVTdXVQEb+44lSZoBw7wSOR7YDfzPJHck+XiSI4BFVfVwa/MIsKgtLwYe6tt/R6tNVd8xSV2SNEOGGSLzgROBq6rqDcAPefbWFQDtCqKG2AcAkqxOsi3Jtt27dw/7dJJ0yBhmiOwAdlTVLW398/RC5dF2K4r2d1fbvhM4rm//Ja02VX3JJPXnqKp1VTVWVWMLFy6c1qAkSc8aWohU1SPAQ0le1UqnAfcCm4CJGVargOva8ibgvDZL61TgiXbbawuwPMmC9kB9ObClbdub5NQ2K+u8vmNJkmbA/CEf/3eBTyU5DLgfOJ9ecF2b5ALgQeDs1nYzcCYwDvyotaWq9iT5MHBba3dZVe1pyxcBnwAOB25oL0nSDBlqiFTVncDYJJtOm6RtARfv5zjrgfWT1LcBr5lmNyVJHfmJdUlSZ4aIJKkzQ0SS1JkhIknqzBCRJHVmiEiSOjNEJEmdGSKSpM4MEUlSZ4aIJKkzQ0SS1JkhIknqzBCRJHVmiEiSOjNEJEmdGSKSpM4MEUlSZ4aIJKkzQ0SS1JkhIknqzBCRJHVmiEiSOhtqiCR5IMndSe5Msq3Vjk6yNcn29ndBqyfJlUnGk9yV5MS+46xq7bcnWdVXP6kdf7ztm2GOR5L0s2biSuQ3q+r1VTXW1tcAN1bVMuDGtg5wBrCsvVYDV0EvdIBLgVOAk4FLJ4Kntbmwb78Vwx+OJGnCKG5nrQQ2tOUNwFl99Y3VczNwVJJjgdOBrVW1p6oeA7YCK9q2I6vq5qoqYGPfsSRJM2DYIVLAXye5PcnqVltUVQ+35UeARW15MfBQ3747Wm2q+o5J6s+RZHWSbUm27d69ezrjkST1mT/k47+5qnYm+UVga5Jv9W+sqkpSQ+4DVbUOWAcwNjY29PNJ0qFiqFciVbWz/d0FfIneM41H260o2t9drflO4Li+3Ze02lT1JZPUJUkzZGghkuSIJC+bWAaWA98ENgETM6xWAde15U3AeW2W1qnAE+221xZgeZIF7YH6cmBL27Y3yaltVtZ5fceSJM2AYd7OWgR8qc26nQ98uqr+KsltwLVJLgAeBM5u7TcDZwLjwI+A8wGqak+SDwO3tXaXVdWetnwR8AngcOCG9pIkzZChhUhV3Q+8bpL694HTJqkXcPF+jrUeWD9JfRvwmml3VpLUiZ9YlyR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSpM0NEktTZQCGS5J8NuyOSpNln0CuRP0tya5KLkrx8qD2SJM0aA4VIVf068DvAccDtST6d5G1D7Zkk6QVv4GciVbUd+D3gEuCfA1cm+VaSfzmszkmSXtgGfSby2iRXAPcBbwV+q6r+aVu+Yoj9kyS9gM0fsN2fAB8HPlhV/zBRrKrvJfm9ofRMkvSCN+jtrLcDn54IkCQvSvJSgKr65FQ7JpmX5I4kf9nWj09yS5LxJJ9Nclir/1xbH2/bl/Yd4wOt/u0kp/fVV7TaeJI1BzJwSdL0DRoiXwEO71t/aasN4r30boNN+EPgiqr6FeAx4IJWvwB4rNWvaO1IcgJwDvCrwAp6M8XmJZkHfAw4AzgBOLe1lSTNkEFvZ72kqp6cWKmqJyeuRKaSZAm9q5jLgfcnCb3nKP+2NdkAfAi4CljZlgE+D/xpa78SuKaqngK+m2QcOLm1G6+q+9u5rmlt7x1wTHoBW7rm+pGd+4G1bx/ZuaXZZtArkR8mOXFiJclJwD9M0X7CHwP/BfhJW/8F4PGqeqat7wAWt+XFwEMAbfsTrf1P6/vss7+6JGmGDHol8j7gc0m+BwT4J8C/mWqHJP8C2FVVtyd5y7R6OU1JVgOrAV7xileMsiuSNKcMFCJVdVuSVwOvaqVvV9X/e57dfg14R5IzgZcARwIfBY5KMr9dbSwBdrb2O+l9mHFHkvnAy4Hv99Un9O+zv/q+/V8HrAMYGxur5+m3JGlAB/IFjG8EXgucSO8h9nlTNa6qD1TVkqpaSu/B+Fer6neAm4B3tmargOva8qa2Ttv+1
aqqVj+nzd46HlgG3ArcBixrs70Oa+fYdADjkSRN00BXIkk+CfwycCfw41YuYGOHc14CXJPkD4A7gKtb/Wrgk+3B+R56oUBV3ZPkWnoPzJ8BLq6qH7d+vQfYAswD1lfVPR36I0nqaNBnImPACe3K4IBV1deAr7Xl+3l2dlV/m38E/vV+9r+c3gyvfeubgc1d+iRJmr5Bb2d9k97DdEmSfmrQK5FjgHuT3Ao8NVGsqncMpVeSpFlh0BD50DA7IUmanQad4vs3SX4JWFZVX2mfVp833K5Jkl7oBv0q+AvpfRXJn7fSYuDLw+qUJGl2GPTB+sX0Pjy4F376A1W/OKxOSZJmh0FD5KmqenpipX2i3E9+S9IhbtAQ+ZskHwQOb7+t/jngfw+vW5Kk2WDQEFkD7AbuBt5N7wN+/qKhJB3iBp2d9RPgL9pLkiRg8O/O+i6TPAOpqlce9B5JkmaNA/nurAkvofcdV0cf/O5IkmaTgZ6JVNX3+147q+qP6f3srSTpEDbo7awT+1ZfRO/KZNCrGEnSHDVoEPxR3/IzwAPA2Qe9N5KkWWXQ2Vm/OeyOSJJmn0FvZ71/qu1V9ZGD0x1J0mxyILOz3sizv2H+W/R+53z7MDoljdLSNdeP5LwPrHWuimafQUNkCXBiVf0AIMmHgOur6l3D6pgk6YVv0K89WQQ83bf+dKtJkg5hg16JbARuTfKltn4WsGE4XZIkzRaDzs66PMkNwK+30vlVdcfwuiVJmg0GvZ0F8FJgb1V9FNiR5PipGid5SZJbk3wjyT1Jfr/Vj09yS5LxJJ9Nclir/1xbH2/bl/Yd6wOt/u0kp/fVV7TaeJI1BzAWSdJBMOjP414KXAJ8oJVeDPyv59ntKeCtVfU64PXAiiSnAn8IXFFVvwI8BlzQ2l8APNbqV7R2JDkBOAf4VWAF8GdJ5iWZB3wMOAM4ATi3tZUkzZBBr0R+G3gH8EOAqvoe8LKpdqieJ9vqi9urgLfS+7126D1XOastr+TZ5yyfB05Lkla/pqqeqqrvAuPAye01XlX3t19dvKa1lSTNkEFD5OmqKtrXwSc5YpCd2hXDncAuYCvwHeDxqnqmNdkBLG7Li4GHANr2J4Bf6K/vs8/+6pKkGTJoiFyb5M+Bo5JcCHyFAX6gqqp+XFWvp/c5k5OBV3fu6TQkWZ1kW5Jtu3fvHkUXJGlOet7ZWe2W0mfpBcBe4FXAf62qrYOepKoeT3IT8CZ6QTS/XW0sAXa2ZjuB4+g9tJ8PvBz4fl99Qv8++6vve/51wDqAsbGx5/y4liSpm+e9Emm3sTZX1daq+s9V9Z8GCZAkC5Mc1ZYPB94G3AfcBLyzNVsFXNeWN7V12vavtnNvAs5ps7eOB5bR+8qV24BlbbbXYfQevk98LYskaQYM+mHDryd5Y1XddgDHPhbY0GZRvQi4tqr+Msm9wDVJ/gC4A7i6tb8a+GSScWAPvVCgqu5Jci1wL72vob+4qn4MkOQ9wBZgHrC+qu45gP5JkqZp0BA5BXhXkgfozdAKvYuU1+5vh6q6C3jDJPX76T0f2bf+j/R+dneyY10OXD5JfTOwebAhSJIOtilDJMkrqur/AqdP1U6SdGh6viuRL9P79t4Hk3yhqv7VTHRKkjQ7PN+D9fQtv3KYHZEkzT7PFyK1n2VJkp73dtbrkuyld0VyeFuGZx+sHznU3kmSXtCmDJGqmjdTHZEkzT4H8lXwkiT9DENEktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSpM0NEktTZoD9KpRFauub6UXdBkibllYgkqTNDRJLUmSEiSerMEJEkdWaISJI6G1qIJDkuyU1J7k1yT5L3tvrRSbYm2d7+Lmj1JLkyyXiSu5Kc2HesVa399iSr+uonJbm77XNlkjy3J5KkYRnmlcgzwH+sqhOAU4GLk5wArAFurKplwI1tHeAMYFl7rQaugl7oAJcCpwAnA5dOBE9rc2HffiuGOB5J0j6GFiJV9XBVfb0t/wC4D1gMrAQ2tGYbgLPa8kpgY/XcDByV5FjgdGBrVe2pqseArcCKtu3Iqrq5qgrY2HcsSdIMmJFnIkmWAm8AbgEWVdXDbdMjwKK2vBh4qG+3Ha02VX3HJPXJzr86ybYk23bv3j2tsUiSnjX0EEny88AXgPdV1d7+be0Koobdh6paV1VjVTW2cOHCYZ9Okg4ZQw2RJC+mFyCfqqovtvKj7VYU7e+uVt8JHNe3+5JWm6q+ZJK6JGmGDHN2VoCrgfuq6iN9mzYBEzOsVgHX9dXPa7O0TgWeaLe9tgDLkyxoD9SXA1vatr1JTm3nOq/vWJKkGTDML2D8NeDfAXcnubPVPgisBa5NcgHwIHB227YZOBMYB34EnA9QVXuSfBi4rbW7rKr2tOWLgE8AhwM3tJckaYYMLUSq6u+A/X1u47RJ2hdw8X6OtR5YP0l9G/CaaXRTkjQNfmJdktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjozRCRJnQ0tRJKsT7IryTf7akcn2Zpke/u7oNWT5Mok40nuSnJi3z6rWvvtSVb11U9Kcnfb58okGdZYJEmTmz/EY38C+FNgY19tDXBjVa1NsqatXwKcASxrr1OAq4BTkhwNXAqMAQXcnmRTVT3W2lwI3AJsBlYANwxxPNJQLV1z/UjO+8Dat4/kvJobhnYlUlV/C+zZp7wS2NCWNwBn9dU3Vs/NwFFJjgVOB7ZW1Z4WHFuBFW3bkVV1c1UVvaA6C0nSjJrpZyKLqurhtvwIsKgtLwYe6mu3o9Wmqu+YpC5JmkEje7DeriBqJs6VZHWSbUm27d69eyZOKUmHhJkOkUfbrSja312tvhM4rq/dklabqr5kkvqkqmpdVY1V1djChQunPQhJUs9Mh8gmYGKG1Srgur76eW2W1qnAE+221xZgeZIFbSbXcmBL27Y3yaltVtZ5fceSJM2Qoc3OSvIZ4C3AMUl20JtltRa4NskFwIPA2a35ZuBMYBz4EXA+QFXtSfJh4LbW7rKqmnhYfxG9GWCH05uV5cwsSZphQwuRqjp3P5tOm6RtARfv5zjrgfWT1LcBr5lOHyVJ0+Mn1iVJnRkikqTODBFJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSps/mj7oCk0Vq65vqRnfuBtW8f2bl1cHglIknqbNZfiSRZAXwUmAd8vKrWDutco/w/NmkuGtW/Ka+ADp5ZfSWSZB7wMeAM4ATg3CQnjLZXknTomNUhApwMjFfV/VX1NHANsHLEfZKkQ8Zsv521GHiob30HcMqI+iJplnAywcEz20NkIElWA6vb6pNJvj3K/kziGODvR92JIZvrY3R8s9+MjDF/OOwz7Nd0xvdL+9sw20NkJ3Bc3/qSVvsZ
VbUOWDdTnTpQSbZV1dio+zFMc32Mjm/2m+tjHNb4ZvszkduAZUmOT3IYcA6wacR9kqRDxqy+EqmqZ5K8B9hCb4rv+qq6Z8TdkqRDxqwOEYCq2gxsHnU/pukFe6vtIJrrY3R8s99cH+NQxpeqGsZxJUmHgNn+TESSNEKGyIgleSDJ3UnuTLJt1P05GJKsT7IryTf7akcn2Zpke/u7YJR9nI79jO9DSXa29/HOJGeOso/TkeS4JDcluTfJPUne2+pz4j2cYnxz6T18SZJbk3yjjfH3W/34JLckGU/y2TYhaXrn8nbWaCV5ABirqjkzBz/JbwBPAhur6jWt9t+APVW1NskaYEFVXTLKfna1n/F9CHiyqv77KPt2MCQ5Fji2qr6e5GXA7cBZwL9nDryHU4zvbObOexjgiKp6MsmLgb8D3gu8H/hiVV2T5H8A36iqq6ZzLq9EdNBV1d8Ce/YprwQ2tOUN9P7Rzkr7Gd+cUVUPV9XX2/IPgPvofTvEnHgPpxjfnFE9T7bVF7dXAW8FPt/qB+U9NERGr4C/TnJ7+2T9XLWoqh5uy48Ai0bZmSF5T5K72u2uWXmrZ19JlgJvAG5hDr6H+4wP5tB7mGRekjuBXcBW4DvA41X1TGuyg4MQnobI6L25qk6k903EF7dbJXNa9e6hzrX7qFcBvwy8HngY+KPRdmf6kvw88AXgfVW1t3/bXHgPJxnfnHoPq+rHVfV6et/kcTLw6mGcxxAZsara2f7uAr5E782eix5t96In7knvGnF/DqqqerT9o/0J8BfM8vex3Uf/AvCpqvpiK8+Z93Cy8c2193BCVT0O3AS8CTgqycTnAyf9mqgDZYiMUJIj2oM9khwBLAe+OfVes9YmYFVbXgVcN8K+HHQT/3FtfptZ/D62h7JXA/dV1Uf6Ns2J93B/45tj7+HCJEe15cOBt9F79nMT8M7W7KC8h87OGqEkr6R39QG9bw/4dFVdPsIuHRRJPgO8hd63hj4KXAp8GbgWeAXwIHB2Vc3Kh9P7Gd9b6N0GKeAB4N19zw9mlSRvBv4PcDfwk1b+IL3nBrP+PZxifOcyd97D19J7cD6P3sXCtVV1WftvzjXA0cAdwLuq6qlpncsQkSR15e0sSVJnhogkqTNDRJLUmSEiSerMEJEkdWaISJI6M0QkSZ0ZIpKkzv4/2LyLCkd/AwYAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_df.line_number.plot.hist();"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 265
        },
        "id": "9aujM0Sa72NG",
        "outputId": "b5deb884-94c6-4396-8af9-43df4d7f8852"
      },
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAD4CAYAAAAtrdtxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAASwElEQVR4nO3df9CdZX3n8ffHAAVtFShZliHQYM3UTV2rGIGO7a6LIwZphXbVwtQ16zCmM+KMTveH0eks1pYZ3NkWS0fd0pJpcNtGqlayBYeNiv3xBz+CoAiU8hTDkoiQGhCpFjb43T/O9cAxPnlyciXnOc/J837NnHnu+3tf97mva+7kfOb+ce6TqkKSpB7Pm3QHJEnTyxCRJHUzRCRJ3QwRSVI3Q0SS1O2ISXdgoZ1wwgm1cuXKSXdDkqbG7bff/o9VtXyuZUsuRFauXMm2bdsm3Q1JmhpJHtzXMk9nSZK6GSKSpG6GiCSpmyEiSepmiEiSuhkikqRuhogkqZshIknqZohIkrotuW+sH4yVG66fdBcW3PbLz5t0FyQtYh6JSJK6GSKSpG6GiCSpmyEiSepmiEiSuhkikqRuhogkqZshIknqZohIkroZIpKkboaIJKmbz87SvCb1vDCf2SVNB49EJEndDBFJUjdDRJLUzRCRJHUzRCRJ3QwRSVI3Q0SS1G3sIZJkWZI7kvxlmz8tyS1JZpJ8MslRrf4jbX6mLV859B7vb/X7krxhqL621WaSbBj3WCRJP2ghjkTeA9w7NP9h4IqqegnwGHBxq18MPNbqV7R2JFkNXAj8NLAW+FgLpmXAR4FzgdXARa2tJGmBjDVEkqwAzgP+qM0HOBv4VGuyCbigTZ/f5mnLX9fanw9srqqnqurrwAxwRnvNVNUDVfU0sLm1lSQtkHEfiXwE+K/A99v8jwOPV9WeNr8DOLlNnww8BNCWf7u1f7a+1zr7qv+QJOuTbEuybdeuXQc7JklSM7YQSfILwKNVdfu4tjGqqrqqqtZU1Zrly5dPujuSdNgY5wMYXwO8KckbgaOBFwK/Bxyb5Ih2tLEC2Nna7wROAXYkOQJ4EfCtofqs4XX2VZckLYCxHYlU1furakVVrWRwYfyLVfWrwE3Am1uzdcB1bXpLm6ct/2JVVatf2O7eOg1YBdwK3Aasand7HdW2sWVc45Ek/bBJPAr+fcDmJL8N3AFc3epXA59IMgPsZhAKVNXdSa4F7gH2AJdU1TMASd4N3AgsAzZW1d0LOhJJWuIWJESq6kvAl9r0AwzurNq7zT8Db9nH+pcBl81RvwG44RB2VZJ0APzGuiSpmyEiSepmiEiSuhkikqRuhogkqZshIknqZohIkroZIpKkboaIJKmbISJJ6maISJK6GSKSpG6GiCSpmyEiSepmiEiSuhkikqRuhogkqZshIknqZohIkroZIpKkboaIJKmbISJJ6maISJK6GSKSpG6GiCSpmyEiSepmiEiSuhkikqRuhogkqZshIknqZohIkroZIpKkboaIJKmbISJJ6maISJK6GSKSpG6GiCSp29hCJMnRSW5N8pUkdyf5zVY/LcktSWaSfDLJUa3+I21+pi1fOfRe72/1+5K8Yai+ttVmkmwY11gkSXMb55HIU8DZVfUzwCuAtUnOAj4MXFFVLwEeAy5u7S8GHmv1K1o7kqwGLgR+GlgLfCzJsiTLgI8C5wKrgYtaW0nSAhlbiNTAk232yPYq4GzgU62+CbigTZ/f5mnLX5ckrb65qp6qqq8DM8AZ7TVTVQ9U1dPA5tZWkrRAjhjnm7ejhduBlzA4avgH4PGq2tOa7ABObtMnAw8BVNWeJN8GfrzVbx562+F1HtqrfuY++rEeWA9w6qmnHtygtCBWbrh+Ytvefvl5E9u2NG3GemG9qp6pqlcAKxgcObx0nNubpx9XVdWaqlqzfPnySXRBkg5LC3J3VlU9DtwE/CxwbJLZI6AVwM42vRM4BaAtfxHwreH6Xuvsqy5JWiDjvDtreZJj2/QxwOuBexmEyZtbs3XAdW16S5unLf9iVVWrX9ju3joNWAXcCtwGrGp3ex3F4OL7lnGNR5L0w8Z5TeQkYFO7LvI84Nqq+ssk9wCbk/w2cAdwdWt/NfCJJDPAbgahQFXdneRa4B5gD3BJVT0DkOTdwI3AMmBjVd09xvFIkvYythCpqq8Cr5yj/gCD6yN71/8ZeMs+3usy4LI56jcANxx0ZyVJXUY6nZXkX4+7I5Kk6TPqNZGPtW+fvyvJi8baI0nS1BgpRKrq54FfZXA31O1J/jTJ68faM0nSojfy3VlVdT/wG8D7gH8LXJnk75L88rg6J0la3Ea9JvLyJFcwuEX3bOAXq+pftekrxtg/SdIiNurdWb8P/BHwgar63myxqr6R5DfG0jNJ0qI3aoicB3xv6PsZzwOOrqrvVtUnxtY7SdKiNuo1kc8DxwzNP7/VJElL2KghcvTQY91p088fT5ckSdNi1BD5pySnz84keRXwvXnaS5KWgFGvibwX+PMk3wAC/EvgV8bWK0nSVBgpRKrqtiQvBX6qle6rqv83vm5JkqbBgTyA8dXAyrbO6UmoqmvG0itJ0lQYKUSSfAL4SeBO4JlWLsAQkaQlbNQjkTXA6vYjUZIkAaPfnfU1BhfTJUl61qhHIicA9yS5FXhqtlhVbxpLryRJU2HUEPngODshSZpOo97i+1dJfgJYVVWfT/J8Br9rLklawkZ9FPw7gU8Bf9BKJwOfHVenJEnTYdQL65cArwGegGd/oOpfjKtTkqTpMGqIPFVVT8/OJDmCwfdEJElL2Kgh8ldJPgAc035b/c+B/z2+bkmSpsGoIbIB2AXcBfwacAOD31uXJC1ho96d9X3gD9tLkiRg9GdnfZ05roFU1YsPeY8kSVPjQJ6dNeto4C3A8Ye+O5KkaTLSNZGq+tbQa2dVfQQ4b8x9kyQtcqOezjp9aPZ5DI5MDuS3SCRJh6FRg+B3hqb3ANuBtx7y3kiSpsqod2f9u3F3RJI0fUY9nfXr8y2vqt89NN2RJE2TA7k769XAljb/i8CtwP3j6JQkaTqMGiIrgNOr6jsAST4IXF9VbxtXxyRJi9+ojz05EXh6aP7pVpMkLWGjHolcA9ya5C/a/AXApvF0SZI0LUa9O+uyJJ8Dfr6V3lFVd4yvW5KkaTDq6SyA5wNPVNXvATuSnDZf4ySnJLkpyT1J7k7ynlY/PsnWJPe3v8e1epJcmWQmyVeHv+CYZF1rf3+SdUP1VyW5q61zZZIc0OglSQdl1J/HvRR4H/D+VjoS+F/7WW0P8J+qajVwFnBJktUMHiv/hapaBXyhzQOcC6xqr/XAx9u2jwcuBc4EzgAunQ2e1uadQ+utHWU8kqRDY9QjkV8C3gT8E0BVfQP4sflWqKqHq+rLbfo7wL0Mfpv9fJ67nrKJwfUVWv2aGrgZODbJScAbgK1VtbuqHgO2AmvbshdW1c1VVQyu28y+lyRpAYwaIk+3D+oCSPKCA9lIkpXAK4FbgBOr6uG26Js8d
5fXycBDQ6vtaLX56jvmqM+1/fVJtiXZtmvXrgPpuiRpHqOGyLVJ/oDB0cE7gc8z4g9UJflR4NPAe6vqieFlw8E0TlV1VVWtqao1y5cvH/fmJGnJ2O/dWe1i9SeBlwJPAD8F/Leq2jrCukcyCJA/qarPtPIjSU6qqofbKalHW30ncMrQ6itabSfw2r3qX2r1FXO0lyQtkP0eibSjhRuqamtV/Zeq+s8jBkiAq4F793q21hZg9g6rdcB1Q/W3t7u0zgK+3U573Qick+S4dkH9HODGtuyJJGe1bb196L0kSQtg1C8bfjnJq6vqtgN479cA/wG4K8mdrfYB4HIGp8cuBh7kuUfK3wC8EZgBvgu8A6Cqdif5LWB22x+qqt1t+l3AHwPHAJ9rL0nSAhk1RM4E3pZkO4M7tMLgIOXl+1qhqv62tZvL6+ZoX8Al+3ivjcDGOerbgJftr/OSpPGYN0SSnFpV/5fBbbaSJP2A/R2JfJbB03sfTPLpqvr3C9EpSdJ02N+F9eHTUS8eZ0ckSdNnfyFS+5iWJGm/p7N+JskTDI5IjmnT8NyF9ReOtXeSpEVt3hCpqmUL1RFJ0vQ5kEfBS5L0AwwRSVI3Q0SS1M0QkSR1M0QkSd0MEUlSN0NEktTNEJEkdTNEJEndDBFJUjdDRJLUzRCRJHUzRCRJ3QwRSVI3Q0SS1M0QkSR1M0QkSd0MEUlSN0NEktTNEJEkdTNEJEndjph0B6TFZuWG6yey3e2XnzeR7UoHwyMRSVI3Q0SS1M0QkSR1M0QkSd0MEUlSN0NEktTNEJEkdTNEJEndDBFJUrexhUiSjUkeTfK1odrxSbYmub/9Pa7Vk+TKJDNJvprk9KF11rX29ydZN1R/VZK72jpXJsm4xiJJmts4j0T+GFi7V20D8IWqWgV8oc0DnAusaq/1wMdhEDrApcCZwBnApbPB09q8c2i9vbclSRqzsYVIVf01sHuv8vnApja9CbhgqH5NDdwMHJvkJOANwNaq2l1VjwFbgbVt2Qur6uaqKuCaofeSJC2Qhb4mcmJVPdymvwmc2KZPBh4aarej1ear75ijPqck65NsS7Jt165dBzcCSdKzJnZhvR1B1AJt66qqWlNVa5YvX74Qm5SkJWGhQ+SRdiqK9vfRVt8JnDLUbkWrzVdfMUddkrSAFjpEtgCzd1itA64bqr+93aV1FvDtdtrrRuCcJMe1C+rnADe2ZU8kOavdlfX2ofeSJC2Qsf0oVZI/A14LnJBkB4O7rC4Hrk1yMfAg8NbW/AbgjcAM8F3gHQBVtTvJbwG3tXYfqqrZi/XvYnAH2DHA59pLkrSAxhYiVXXRPha9bo62BVyyj/fZCGyco74NeNnB9FGSdHD8xrokqZshIknqZohIkroZIpKkboaIJKmbISJJ6maISJK6GSKSpG6GiCSpmyEiSepmiEiSuhkikqRuhogkqZshIknqZohIkroZIpKkboaIJKmbISJJ6maISJK6GSKSpG6GiCSpmyEiSep2xKQ7IGlg5YbrJ7Ld7ZefN5Ht6vDgkYgkqZshIknqZohIkroZIpKkboaIJKmbISJJ6maISJK6GSKSpG6GiCSpmyEiSepmiEiSuhkikqRuhogkqZtP8ZWWuEk9PRh8gvDhYOqPRJKsTXJfkpkkGybdH0laSqY6RJIsAz4KnAusBi5KsnqyvZKkpWPaT2edAcxU1QMASTYD5wP3TLRXkkbiD3FNv2kPkZOBh4bmdwBn7t0oyXpgfZt9Msl9nds7AfjHznUXm8NlLIfLOMCxLJh8eOSmi3ocB+hgxvIT+1ow7SEykqq6CrjqYN8nybaqWnMIujRxh8tYDpdxgGNZjA6XccD4xjLV10SAncApQ/MrWk2StACmPURuA1YlOS3JUcCFwJYJ90mSloypPp1VVXuSvBu4EVgGbKyqu8e4yYM+JbaIHC5jOVzGAY5lMTpcxgFjGkuqahzvK0laAqb9dJYkaYIMEUlSN0NkBIfTo1WSbE9yV5I7k2ybdH8ORJKNSR5N8rWh2vFJtia5v/09bpJ9HNU+xvLBJDvbvrkzyRsn2cdRJDklyU1J7klyd5L3tPrU7Zd5xjKN++XoJLcm+Uoby2+2+mlJbmmfZZ9sNyQd3La8JjK/9miVvwdez+DLjLcBF1XVVH4rPsl2YE1VTd0XqJL8G+BJ4Jqqelmr/Xdgd1Vd3gL+uKp63yT7OYp9jOWDwJNV9T8m2bcDkeQk4KSq+nKSHwNuBy4A/iNTtl/mGctbmb79EuAFVfVkkiOBvwXeA/w68Jmq2pzkfwJfqaqPH8y2PBLZv2cfrVJVTwOzj1bRAquqvwZ271U+H9jUpjcx+E+/6O1jLFOnqh6uqi+36e8A9zJ4ksTU7Zd5xjJ1auDJNntkexVwNvCpVj8k+8UQ2b+5Hq0ylf+wmgL+T5Lb2+Ngpt2JVfVwm/4mcOIkO3MIvDvJV9vprkV/CmhYkpXAK4FbmPL9stdYYAr3S5JlSe4EHgW2Av8APF5Ve1qTQ/JZZogsPT9XVaczePLxJe20ymGhBudmp/n87MeBnwReATwM/M5kuzO6JD8KfBp4b1U9Mbxs2vbLHGOZyv1SVc9U1SsYPMnjDOCl49iOIbJ/h9WjVapqZ/v7KPAXDP5xTbNH2rns2XPaj064P92q6pH2H//7wB8yJfumnXP/NPAnVfWZVp7K/TLXWKZ1v8yqqseBm4CfBY5NMvsl80PyWWaI7N9h82iVJC9oFwxJ8gLgHOBr86+16G0B1rXpdcB1E+zLQZn90G1+iSnYN+0C7tXAvVX1u0OLpm6/7GssU7pflic5tk0fw+DGoHsZhMmbW7NDsl+8O2sE7Za+j/Dco1Uum3CXuiR5MYOjDxg88uZPp2ksSf4MeC2DR1o/AlwKfBa4FjgVeBB4a1Ut+gvW+xjLaxmcMilgO/BrQ9cVFqUkPwf8DXAX8P1W/gCDawlTtV/mGctFTN9+eTmDC+fLGBwsXFtVH2qfAZuB44E7gLdV1VMHtS1DRJLUy9NZkqRuhogkqZshIknqZohIkroZIpKkboaIJKmbISJJ6vb/AVwSphAAsBgmAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# List of sentences (abstract text lines -> lists)\n",
        "train_sentences = train_df[\"text\"].tolist()\n",
        "val_sentences = val_df[\"text\"].tolist()\n",
        "test_sentences = test_df[\"text\"].tolist()\n",
        "len(train_sentences), len(val_sentences), len(test_sentences)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Ia2ch3u_8CrC",
        "outputId": "5372f558-2967-43eb-c03b-2eaee125fbfe"
      },
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(180040, 30212, 30135)"
            ]
          },
          "metadata": {},
          "execution_count": 16
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_sentences[:20]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "l6V8HT4a8OX_",
        "outputId": "37ced6ff-c580-4e8f-9c0e-987115fb5a74"
      },
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['to investigate the efficacy of @ weeks of daily low-dose oral prednisolone in improving pain , mobility , and systemic low-grade inflammation in the short term and whether the effect would be sustained at @ weeks in older adults with moderate to severe knee osteoarthritis ( oa ) .',\n",
              " 'a total of @ patients with primary knee oa were randomized @:@ ; @ received @ mg/day of prednisolone and @ received placebo for @ weeks .',\n",
              " 'outcome measures included pain reduction and improvement in function scores and systemic inflammation markers .',\n",
              " 'pain was assessed using the visual analog pain scale ( @-@ mm ) .',\n",
              " 'secondary outcome measures included the western ontario and mcmaster universities osteoarthritis index scores , patient global assessment ( pga ) of the severity of knee oa , and @-min walk distance ( @mwd ) .',\n",
              " 'serum levels of interleukin @ ( il-@ ) , il-@ , tumor necrosis factor ( tnf ) - , and high-sensitivity c-reactive protein ( hscrp ) were measured .',\n",
              " 'there was a clinically relevant reduction in the intervention group compared to the placebo group for knee pain , physical function , pga , and @mwd at @ weeks .',\n",
              " 'the mean difference between treatment arms ( @ % ci ) was @ ( @-@ @ ) , p < @ ; @ ( @-@ @ ) , p < @ ; @ ( @-@ @ ) , p < @ ; and @ ( @-@ @ ) , p < @ , respectively .',\n",
              " 'further , there was a clinically relevant reduction in the serum levels of il-@ , il-@ , tnf - , and hscrp at @ weeks in the intervention group when compared to the placebo group .',\n",
              " 'these differences remained significant at @ weeks .',\n",
              " 'the outcome measures in rheumatology clinical trials-osteoarthritis research society international responder rate was @ % in the intervention group and @ % in the placebo group ( p < @ ) .',\n",
              " 'low-dose oral prednisolone had both a short-term and a longer sustained effect resulting in less knee pain , better physical function , and attenuation of systemic inflammation in older patients with knee oa ( clinicaltrials.gov identifier nct@ ) .',\n",
              " 'emotional eating is associated with overeating and the development of obesity .',\n",
              " 'yet , empirical evidence for individual ( trait ) differences in emotional eating and cognitive mechanisms that contribute to eating during sad mood remain equivocal .',\n",
              " 'the aim of this study was to test if attention bias for food moderates the effect of self-reported emotional eating during sad mood ( vs neutral mood ) on actual food intake .',\n",
              " 'it was expected that emotional eating is predictive of elevated attention for food and higher food intake after an experimentally induced sad mood and that attentional maintenance on food predicts food intake during a sad versus a neutral mood .',\n",
              " 'participants ( n = @ ) were randomly assigned to one of the two experimental mood induction conditions ( sad/neutral ) .',\n",
              " 'attentional biases for high caloric foods were measured by eye tracking during a visual probe task with pictorial food and neutral stimuli .',\n",
              " 'self-reported emotional eating was assessed with the dutch eating behavior questionnaire ( debq ) and ad libitum food intake was tested by a disguised food offer .',\n",
              " 'hierarchical multivariate regression modeling showed that self-reported emotional eating did not account for changes in attention allocation for food or food intake in either condition .']"
            ]
          },
          "metadata": {},
          "execution_count": 17
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Preprocessing for the modelling "
      ],
      "metadata": {
        "id": "_-gGcC0b_HXF"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from sklearn.preprocessing import OneHotEncoder\n",
        "one_hot_encoder = OneHotEncoder(sparse=False)\n",
        "train_labels_one_hot = one_hot_encoder.fit_transform(train_df[\"target\"].to_numpy().reshape(-1, 1))\n",
        "val_labels_one_hot = one_hot_encoder.transform(val_df[\"target\"].to_numpy().reshape(-1, 1))\n",
        "test_labels_one_hot = one_hot_encoder.transform(test_df[\"target\"].to_numpy().reshape(-1, 1))"
      ],
      "metadata": {
        "id": "m1xPlRMc_UCT"
      },
      "execution_count": 18,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "train_labels_one_hot"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "k2yHpXjEA4Ox",
        "outputId": "09492b8c-2a3c-48c3-9ac0-05608633135b"
      },
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([[0., 0., 0., 1., 0.],\n",
              "       [0., 0., 1., 0., 0.],\n",
              "       [0., 0., 1., 0., 0.],\n",
              "       ...,\n",
              "       [0., 0., 0., 0., 1.],\n",
              "       [0., 1., 0., 0., 0.],\n",
              "       [0., 1., 0., 0., 0.]])"
            ]
          },
          "metadata": {},
          "execution_count": 19
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Labelling\n",
        "from sklearn.preprocessing import LabelEncoder\n",
        "label_encoder = LabelEncoder()\n",
        "train_labels_encoded = label_encoder.fit_transform(train_df[\"target\"].to_numpy())\n",
        "val_labels_encoded = label_encoder.transform(val_df[\"target\"].to_numpy())\n",
        "test_labels_encoded = label_encoder.transform(test_df[\"target\"].to_numpy())"
      ],
      "metadata": {
        "id": "vX3HUoF5BCYE"
      },
      "execution_count": 20,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "train_labels_encoded"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eLUKrkfeCSkW",
        "outputId": "6578552a-19ee-4060-c28e-e378285f9106"
      },
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([3, 2, 2, ..., 4, 1, 1])"
            ]
          },
          "metadata": {},
          "execution_count": 21
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Defining the classes \n",
        "num_classes = len(label_encoder.classes_)\n",
        "class_names = label_encoder.classes_\n",
        "num_classes, class_names"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "iRZkLYj5CWRL",
        "outputId": "6d13fc99-f905-40fa-91ca-36a41f72c9c0"
      },
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(5, array(['BACKGROUND', 'CONCLUSIONS', 'METHODS', 'OBJECTIVE', 'RESULTS'],\n",
              "       dtype=object))"
            ]
          },
          "metadata": {},
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Modellling"
      ],
      "metadata": {
        "id": "JCz8bulnCtUc"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### model_1"
      ],
      "metadata": {
        "id": "4A0WsFG8CyyW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# model_1\n",
        "from sklearn.feature_extraction.text import TfidfVectorizer\n",
        "from sklearn.pipeline import Pipeline\n",
        "from sklearn.naive_bayes import MultinomialNB\n",
        "\n",
        "model_1 = Pipeline([\n",
        "                    (\"tf-idf\", TfidfVectorizer()),\n",
        "                    (\"clf\", MultinomialNB())\n",
        "])\n",
        "\n",
        "model_1.fit(X=train_sentences,\n",
        "            y=train_labels_encoded)\n",
        "\n",
        "model_1.score(X=val_sentences,\n",
        "              y=val_labels_encoded)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "M2vpt7MsC3CT",
        "outputId": "96ac5a64-8d1d-42fa-9488-1982ce22f4a8"
      },
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.7218323844829869"
            ]
          },
          "metadata": {},
          "execution_count": 23
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Prediction on model_1\n",
        "model_1_preds = model_1.predict(val_sentences)\n",
        "model_1_preds"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "n6U5fsL9D98A",
        "outputId": "2dff9fd0-e408-47b9-dffe-537a11bcbaf4"
      },
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([4, 1, 3, ..., 4, 4, 1])"
            ]
          },
          "metadata": {},
          "execution_count": 24
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Evaluation \n",
        "model_1_results = calculate_results(y_true=val_labels_encoded,\n",
        "                                  y_pred=model_1_preds)\n",
        "model_1_results"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "w52SI9K4Eiwd",
        "outputId": "19558167-3f45-4522-e067-9ec9691ef5bd"
      },
      "execution_count": 25,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'accuracy': 72.1832384482987,\n",
              " 'f1': 0.6989250353450294,\n",
              " 'precision': 0.7186466952323352,\n",
              " 'recall': 0.7218323844829869}"
            ]
          },
          "metadata": {},
          "execution_count": 25
        }
      ]
    },
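    {
      "cell_type": "markdown",
      "source": [
        "`calculate_results` is the evaluation helper defined earlier in the notebook. For reference, a minimal sketch of such a helper is shown below, assuming it wraps scikit-learn's `accuracy_score` and `precision_recall_fscore_support` with weighted averaging (which would match the dictionary returned above); the notebook's own helper remains the one actually used."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Sketch only - the calculate_results helper defined earlier in the notebook is what is actually used\n",
        "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
        "\n",
        "def calculate_results_sketch(y_true, y_pred):\n",
        "  accuracy = accuracy_score(y_true, y_pred) * 100  # accuracy as a percentage\n",
        "  precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=\"weighted\")\n",
        "  return {\"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1\": f1}"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },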
    {
      "cell_type": "markdown",
      "source": [
        "## model_2 (with sequencing)"
      ],
      "metadata": {
        "id": "bkMEnDJZF5L2"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# model_2 data preprocessing\n",
        "\n",
        "from tensorflow.keras import layers\n",
        "\n",
        "# average length of sentence\n",
        "sentence_len = [len(sentence.split()) for sentence in train_sentences]\n",
        "avg_sentence_len = np.mean(sentence_len)\n",
        "avg_sentence_len"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Hl8tIxGMGAsw",
        "outputId": "7b4f3a97-4d83-4910-ae48-e5fd6d3bdd52"
      },
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "26.338269273494777"
            ]
          },
          "metadata": {},
          "execution_count": 26
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "plt.hist(sentence_len, bins=10);"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 265
        },
        "id": "rlcxMwdAGwvA",
        "outputId": "13bfa169-557c-477a-ef68-92d9b70c061c"
      },
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAD4CAYAAAAZ1BptAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAUAUlEQVR4nO3df6ye5X3f8fdndiA0aTCEI8RsNDuL1cpBXUMs4ipVVMUbGFLVTCKR0TS8zIrVBrp02tSaRRpdEiTYj7IiESpaezFRFMNoKqwG5npAFe0PfhwCAQwlnAIptgCfYn60ixLq9Ls/nsvJs8O5ju3zmPPDvF/So+e+v/d13/d1cR+fD/eP5zmpKiRJms4/mO8OSJIWLkNCktRlSEiSugwJSVKXISFJ6lo63x040c4666xauXLlfHdDkhaVhx9++K+ramxq/aQLiZUrVzI+Pj7f3ZCkRSXJ96ere7lJktRlSEiSugwJSVKXISFJ6jIkJEldhoQkqcuQkCR1GRKSpC5DQpLUddJ94noUK7d9a972/fx1n5y3fUtSj2cSkqQuQ0KS1GVISJK6DAlJUpchIUnqOmpIJNmR5GCSJ4Zq/yXJXyR5LMmfJFk2tOzqJBNJnk5y0VB9Q6tNJNk2VF+V5IFWvy3JKa1+apufaMtXnqhBS5KOzbGcSXwV2DClthc4r6p+AfgecDVAkjXAJuBDbZ2vJFmSZAlwE3AxsAa4vLUFuB64oao+CLwKbGn1LcCrrX5DaydJmkNHDYmq+jZwaErtz6rqcJu9H1jRpjcCu6rqR1X1HDABXNBeE1X1bFW9CewCNiYJ8Angjrb+TuDSoW3tbNN3AOtbe0nSHDkR9yT+NXB3m14OvDC0bH+r9ervB14bCpwj9f9vW2356639WyTZmmQ8yfjk5OTIA5IkDYwUEkm+ABwGvn5iujM7VXVLVa2tqrVjY2/5O96SpFma9ddyJPlXwK8C66uqWvkAcO5QsxWtRqf+CrAsydJ2tjDc/si29idZCpze2kuS5sisziSSbAB+G/i1qvrB0KLdwKb2ZNIqYDXwIPAQsLo9yXQKg5vbu1u43Adc1tbfDNw5tK3Nbfoy4N6hMJIkzYGjnkkk+QbwK8BZSfYD1zB4mulUYG+7l3x/Vf16Ve1LcjvwJIPLUFdW1Y/bdq4C9gBLgB1Vta/t4neAXUm+DDwCbG/17cDXkkwwuHG+6QSMV5J0HI4aElV1+TTl7dPUjrS/Frh2mvpdwF3T1J9l8PTT1PoPgU8drX+SpLePn7iWJHUZEpKkLkNCktRlSEiSugwJSVKXISFJ6jIkJEldhoQkqcuQkCR1GRKSpC5DQpLUZUhIkroMCUlSlyEhSeoyJCRJXYaEJKnLkJAkdRkSkqQuQ0KS1GVISJK6DAlJUpchIUnqMiQkSV2GhCSp66ghkWRHkoNJnhiqnZlkb5Jn2vsZrZ4kNyaZSPJYkvOH1tnc2j+TZPNQ/SNJHm/r3JgkM+1DkjR3juVM4qvAhim1bcA9VbUauKfNA1wMrG6vrcDNMPiFD1wDfBS4ALhm6Jf+zcBnh9bbcJR9SJLmyFFDoqq+DRyaUt4I7GzTO4FLh+q31sD9wLIk5wAXAXur6lBVvQrsBTa0Ze+rqvurqoBbp2xrun1IkubIbO9JnF1VL7bpl4Cz2/Ry4IWhdvtbbab6/mnqM+3jLZJsTTKeZHxycnIWw5EkTWfkG9ftDKBOQF9mvY+quqWq1lbV2rGxsbezK5L0jjLbkHi5XSqivR9s9QPAuUPtVrTaTPUV09Rn2ockaY7MNiR2A0eeUNoM3DlUv6I95bQOeL1dMtoDXJjkjHbD+kJgT1v2RpJ17ammK6Zsa7p9SJLmyNKjNUjyDeBXgLOS7GfwlNJ1wO1JtgDfBz7dmt8FXAJMAD8APgNQVYeSfAl4qLX7YlUduRn+OQZPUJ0G3N1ezLAPSdIcOWpIVNXlnUXrp2lbwJWd7ewAdkxTHwfOm6b+ynT7kCTNHT9xLUnqMiQkSV2GhCSpy5CQJHUZEpKkLkNCktRlSEiSugwJSVKXISFJ6jIkJEldhoQkqcuQkCR1GRKSpC5DQpLUZUhIkroMCUlSlyEhSeoyJCRJXYaEJKnLkJAkdRkSkqQuQ0KS1GVISJK6DAlJUtdIIZHk3ybZl+SJJN9I8u4kq5I8kGQiyW1JTmltT23zE235yqHtXN3qTye5aKi+odUmkmwbpa+SpOM365BIshz4N8DaqjoPWAJsAq4HbqiqDwKvAlvaKluAV1v9htaOJGvaeh8CNgBfSbIkyRLgJuBiYA1weWsrSZojo15uWgqclmQp8DPAi8AngDva8p3ApW16Y5unLV+fJK2+q6p+VFXPARPABe01UVXPVtWbwK7WVpI0R2YdElV1APivwF8xCIfXgYeB16rqcGu2H1jeppcDL7R1D7f27x+uT1mnV3+LJFuTjCcZn5ycnO2QJElTjHK56QwG/2e/CviHwHsYXC6ac1V1S1Wtraq1Y2Nj89EFSTopjXK56Z8Cz1XVZFX9HfBN4GPAsnb5CWAFcKBNHwDOBWjLTwdeGa5PWadXlyTNkVFC4q+AdUl+pt1bWA88CdwHXNbabAbubNO72zxt+b1VVa2+qT39tApYDTwIPASsbk9LncLg5vbuEforSTpOS4/eZHpV9UCSO4DvAIeBR4BbgG8Bu5J8udW2t1W2A19LMgEcYvBLn6ral+R2BgFzGLiyqn4MkOQqYA+DJ6d2VNW+2fZXknT8Zh0SAFV1DXDNlPKzDJ5Mmtr2h8CnOtu5Frh2mvpdwF2j9FGSNHt+4lqS1GVISJK6DAlJUpchIUnqMiQkSV2GhCSpy5CQJHUZEpKkLkNCktRlSEiSugwJSVKXISFJ6jIkJEldhoQkqcuQkCR1GRKSpC5DQpLUZUhIkroMCUlSlyEhSeoyJCRJXYaEJKnLkJAkdRkSkqSukUIiybIkdyT5iyRPJfmlJGcm2ZvkmfZ+RmubJDcmmUjyWJLzh7azubV/JsnmofpHkjze1rkxSUbpryTp+Ix6JvH7wP+qqp8H/gnwFLANuKeqVgP3tHmAi4HV7bUVuBkgyZnANcBHgQuAa44ES2vz2aH1NozYX0nScZh1SCQ5Hfg4sB2gqt6sqteAjcDO1mwncGmb3gjcWgP3A8uSnANcBOytqkNV9SqwF9jQlr2vqu6vqgJuHdqWJGkOjHImsQqYBP5HkkeS/FGS9wBnV9WLrc1LwNltejnwwtD6+1ttpvr+aepvkWRrkvEk45OTkyMMSZI0bJSQWAqcD9xcVR8G/i8/vbQEQDsDqBH2cUyq6paqWltVa8fGxt7u3UnSO8YoIbEf2F9VD7T5OxiExsvtUhHt/WBbfgA4d2j9Fa02U33FNHVJ0hyZdUhU1UvAC0l+rpXWA08Cu4EjTyhtBu5s07uBK9pTTuuA19tlqT3AhUnOaDesLwT2tGVvJFnXnmq6YmhbkqQ5sHTE9X8T+HqSU4Bngc8wCJ7bk2wBvg98urW9C
7gEmAB+0NpSVYeSfAl4qLX7YlUdatOfA74KnAbc3V6SpDkyUkhU1aPA2mkWrZ+mbQFXdrazA9gxTX0cOG+UPkqSZs9PXEuSugwJSVKXISFJ6jIkJEldhoQkqcuQkCR1GRKSpC5DQpLUZUhIkroMCUlSlyEhSeoyJCRJXYaEJKnLkJAkdY369yR0gqzc9q152e/z131yXvYraXHwTEKS1GVISJK6DAlJUpchIUnqMiQkSV2GhCSpy5CQJHUZEpKkLkNCktQ1ckgkWZLkkSR/2uZXJXkgyUSS25Kc0uqntvmJtnzl0DaubvWnk1w0VN/QahNJto3aV0nS8TkRZxKfB54amr8euKGqPgi8Cmxp9S3Aq61+Q2tHkjXAJuBDwAbgKy14lgA3ARcDa4DLW1tJ0hwZKSSSrAA+CfxRmw/wCeCO1mQncGmb3tjmacvXt/YbgV1V9aOqeg6YAC5or4mqeraq3gR2tbaSpDky6pnEfwd+G/j7Nv9+4LWqOtzm9wPL2/Ry4AWAtvz11v4n9Snr9OpvkWRrkvEk45OTkyMOSZJ0xKxDIsmvAger6uET2J9ZqapbqmptVa0dGxub7+5I0kljlK8K/xjwa0kuAd4NvA/4fWBZkqXtbGEFcKC1PwCcC+xPshQ4HXhlqH7E8Dq9uiRpDsz6TKKqrq6qFVW1ksGN53ur6l8A9wGXtWabgTvb9O42T1t+b1VVq29qTz+tAlYDDwIPAavb01KntH3snm1/JUnH7+34o0O/A+xK8mXgEWB7q28HvpZkAjjE4Jc+VbUvye3Ak8Bh4Mqq+jFAkquAPcASYEdV7Xsb+itJ6jghIVFVfw78eZt+lsGTSVPb/BD4VGf9a4Frp6nfBdx1IvooSTp+fuJaktRlSEiSugwJSVKXISFJ6jIkJEldhoQkqcuQkCR1GRKSpC5DQpLUZUhIkroMCUlSlyEhSeoyJCRJXYaEJKnLkJAkdRkSkqQuQ0KS1GVISJK6DAlJUpchIUnqMiQkSV2GhCSpy5CQJHUZEpKkrlmHRJJzk9yX5Mkk+5J8vtXPTLI3yTPt/YxWT5Ibk0wkeSzJ+UPb2tzaP5Nk81D9I0keb+vcmCSjDFaSdHxGOZM4DPy7qloDrAOuTLIG2AbcU1WrgXvaPMDFwOr22grcDINQAa4BPgpcAFxzJFham88OrbdhhP5Kko7TrEOiql6squ+06b8BngKWAxuBna3ZTuDSNr0RuLUG7geWJTkHuAjYW1WHqupVYC+woS17X1XdX1UF3Dq0LUnSHDgh9ySSrAQ+DDwAnF1VL7ZFLwFnt+nlwAtDq+1vtZnq+6epT7f/rUnGk4xPTk6ONBZJ0k+NHBJJ3gv8MfBbVfXG8LJ2BlCj7uNoquqWqlpbVWvHxsbe7t1J0jvGSCGR5F0MAuLrVfXNVn65XSqivR9s9QPAuUOrr2i1meorpqlLkubIKE83BdgOPFVVvze0aDdw5AmlzcCdQ/Ur2lNO64DX22WpPcCFSc5oN6wvBPa0ZW8kWdf2dcXQtiRJc2DpCOt+DPiXwONJHm21/wBcB9yeZAvwfeDTbdldwCXABPAD4DMAVXUoyZeAh1q7L1bVoTb9OeCrwGnA3e0lSZojsw6Jqvo/QO9zC+unaV/AlZ1t7QB2TFMfB86bbR8lSaPxE9eSpC5DQpLUZUhIkroMCUlSlyEhSeoyJCRJXYaEJKnLkJAkdRkSkqQuQ0KS1GVISJK6DAlJUpchIUnqMiQkSV2GhCSpy5CQJHUZEpKkLkNCktQ1yt+41klg5bZvzct+n7/uk/OyX0nHxzMJSVKXISFJ6jIkJEldhoQkqcuQkCR1LfiQSLIhydNJJpJsm+/+SNI7yYIOiSRLgJuAi4E1wOVJ1sxvryTpnWOhf07iAmCiqp4FSLIL2Ag8Oa+90sjm6/MZ4Gc0pOOx0ENiOfDC0Px+4KNTGyXZCmxts3+b5OlZ7Oss4K9nsd5CdTKN54SOJdefqC3Nmsdm4Xonj+cfTVdc6CFxTKrqFuCWUbaRZLyq1p6gLs27k2k8J9NY4OQaz8k0FnA801nQ9ySAA8C5Q/MrWk2SNAcWekg8BKxOsirJKcAmYPc890mS3jEW9OWmqjqc5CpgD7AE2FFV+96m3Y10uWoBOpnGczKNBU6u8ZxMYwHH8xapqhPREUnSSWihX26SJM0jQ0KS1GVIsPi/+iPJ80keT/JokvFWOzPJ3iTPtPcz5rufPUl2JDmY5Imh2rT9z8CN7Vg9luT8+ev59Drj+d0kB9oxejTJJUPLrm7jeTrJRfPT6+klOTfJfUmeTLIvyedbfdEdnxnGsliPzbuTPJjku208/6nVVyV5oPX7tvbQD0lObfMTbfnKY9pRVb2jXwxuiP8l8AHgFOC7wJr57tdxjuF54Kwptf8MbGvT24Dr57ufM/T/48D5wBNH6z9wCXA3EGAd8MB89/8Yx/O7wL+fpu2a9jN3KrCq/Swume8xDPXvHOD8Nv2zwPdanxfd8ZlhLIv12AR4b5t+F/BA+29+O7Cp1f8A+I02/TngD9r0JuC2Y9mPZxJDX/1RVW8CR776Y7HbCOxs0zuBS+exLzOqqm8Dh6aUe/3fCNxaA/cDy5KcMzc9PTad8fRsBHZV1Y+q6jlggsHP5IJQVS9W1Xfa9N8ATzH4JoRFd3xmGEvPQj82VVV/22bf1V4FfAK4o9WnHpsjx+wOYH2SHG0/hsT0X/0x0w/OQlTAnyV5uH1FCcDZVfVim34JOHt+ujZrvf4v5uN1VbsEs2Po8t+iGU+7PPFhBv/HuqiPz5SxwCI9NkmWJHkUOAjsZXC281pVHW5Nhvv8k/G05a8D7z/aPgyJk8MvV9X5DL4t98okHx9eWIPzy0X7rPNi739zM/CPgV8EXgT+2/x25/gkeS/wx8BvVdUbw8sW2/GZZiyL9thU1Y+r6hcZfBvFBcDPn+h9GBInwVd/VNWB9n4Q+BMGPywvHznNb+8H56+Hs9Lr/6I8XlX1cvsH/ffAH/LTyxYLfjxJ3sXgl+rXq+qbrbwoj890Y1nMx+aIqnoNuA/4JQaX+I58UHq4zz8ZT1t+OvDK0bZtSCzyr/5I8p4kP3tkGrgQeILBGDa3ZpuBO+enh7PW6/9u4Ir2FM064PWhyx4L1pTr8v+cwTGCwXg2tSdPVgGrgQfnun897Zr1duCpqvq9oUWL7vj0xrKIj81YkmVt+jTgnzG4z3IfcFlrNvXYHDlmlwH3trPAmc33HfqF8GLwRMb3GFzP+8J89+c4+/4BBk9gfBfYd6T/DK413gM8A/xv4Mz57usMY/gGg9P8v2NwDXVLr/8Mnui4qR2rx4G1893/YxzP11p/H2v/WM8Zav+FNp6ngYvnu/9TxvLLDC4lPQY82l6XLMbjM8NYFuux+QXgkdbvJ4D/2OofYBBmE8D/BE5t9Xe3+Ym2/APHsh+/lkOS1OXlJklSlyEhSeoyJCRJXYaEJKnLkJAkdRkSkqQuQ0KS1PX/AAAFpXNjsjAuAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# check the % of sentences length between 0-50 & max sentence length\n",
        "output_seq_len = int(np.percentile(sentence_len, 95))\n",
        "output_seq_len, max(sentence_len)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_vY1_UbTG6c9",
        "outputId": "a173adfe-2695-4856-ad34-35603f578725"
      },
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(55, 296)"
            ]
          },
          "metadata": {},
          "execution_count": 28
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Vectorization"
      ],
      "metadata": {
        "id": "GvtJEpFQHx4T"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Section 3.2 states that the vocabulary size is 68,000 - https://arxiv.org/pdf/1710.06071.pdf\n",
        "max_tokens = 68000 "
      ],
      "metadata": {
        "id": "flq878KaH8sa"
      },
      "execution_count": 29,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# text vectorization\n",
        "from tensorflow.keras.layers.experimental.preprocessing import TextVectorization\n",
        "text_vectorizer = TextVectorization(max_tokens=max_tokens,\n",
        "                                    output_sequence_length=55)\n",
        "# adapting text vectorizer to training sentences\n",
        "text_vectorizer.adapt(train_sentences)\n",
        "\n",
        "# testing the text vectorizer\n",
        "import random \n",
        "target_sentence = random.choice(train_sentences)\n",
        "print(f\"Text:\\n{target_sentence}\")\n",
        "print(f\"\\nLength of text: {len(target_sentence.split())}\")\n",
        "print(f\"\\nVectorized text:\\n{text_vectorizer([target_sentence])}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Tm2-BY0yQKZF",
        "outputId": "74e73b10-c1d0-43fc-ef30-a4fbdaa13b8f"
      },
      "execution_count": 30,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Text:\n",
            "the authors conducted a randomized crossover trial in one internal medicine and one family medicine residency program between january @ and july @ .\n",
            "\n",
            "Length of text: 24\n",
            "\n",
            "Vectorized text:\n",
            "[[   2 1473  198    8   29  484   32    5   88 1333  941    3   88  791\n",
            "   941 8376  256   30 1174    3 1570    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0]]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Number of words in the training vocabulary\n",
        "rct_20k_text_vocab = text_vectorizer.get_vocabulary()\n",
        "print(f\"Total number of words in vocabulary: {len(rct_20k_text_vocab)}\"),\n",
        "print(f\"Most common words: {rct_20k_text_vocab[:5]}\")\n",
        "print(f\"Least common words: {rct_20k_text_vocab[-5:]}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3uFCGlvbR9OG",
        "outputId": "5edfaccf-e07f-4763-c9c9-1d342d07671b"
      },
      "execution_count": 31,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Total number of words in vocabulary: 64841\n",
            "Most common words: ['', '[UNK]', 'the', 'and', 'of']\n",
            "Least common words: ['aainduced', 'aaigroup', 'aachener', 'aachen', 'aaacp']\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Configuration \n",
        "text_vectorizer.get_config()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lYF2APNjTLs2",
        "outputId": "9ad3c255-b947-4399-c88a-983513fb4353"
      },
      "execution_count": 32,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'batch_input_shape': (None,),\n",
              " 'dtype': 'string',\n",
              " 'idf_weights': None,\n",
              " 'max_tokens': 68000,\n",
              " 'name': 'text_vectorization',\n",
              " 'ngrams': None,\n",
              " 'output_mode': 'int',\n",
              " 'output_sequence_length': 55,\n",
              " 'pad_to_max_tokens': False,\n",
              " 'ragged': False,\n",
              " 'sparse': False,\n",
              " 'split': 'whitespace',\n",
              " 'standardize': 'lower_and_strip_punctuation',\n",
              " 'trainable': True,\n",
              " 'vocabulary': None}"
            ]
          },
          "metadata": {},
          "execution_count": 32
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Embedding token layer\n",
        "token_embed = layers.Embedding(input_dim=len(rct_20k_text_vocab),\n",
        "                              output_dim=128,\n",
        "                              mask_zero=True,\n",
        "                              name=\"token_embedding\")\n",
        "\n",
        "# View embedded token layer\n",
        "print(f\"before vectorization:\\n{target_sentence}\\n\")\n",
        "vectorized__sentence = text_vectorizer([target_sentence])\n",
        "print(f\"after vectorization:\\n{vectorized__sentence}\\n\")\n",
        "embedded_sentence = token_embed(vectorized__sentence)\n",
        "print(f\"after embedding:\\n{embedded_sentence}\\n\")\n",
        "print(f\"shape after embedding: {embedded_sentence.shape}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "LAQmPECiTVH9",
        "outputId": "2a4fa514-9bee-427b-a111-f174fd6452b9"
      },
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "before vectorization:\n",
            "the authors conducted a randomized crossover trial in one internal medicine and one family medicine residency program between january @ and july @ .\n",
            "\n",
            "after vectorization:\n",
            "[[   2 1473  198    8   29  484   32    5   88 1333  941    3   88  791\n",
            "   941 8376  256   30 1174    3 1570    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0]]\n",
            "\n",
            "after embedding:\n",
            "[[[-0.02017461 -0.03074959  0.00926958 ... -0.02336853  0.01859171\n",
            "    0.00190452]\n",
            "  [-0.04895055 -0.01997001  0.04958674 ... -0.00581187  0.01886975\n",
            "   -0.03286358]\n",
            "  [ 0.03682723  0.02068727  0.01939244 ...  0.03774397  0.0362787\n",
            "    0.04843071]\n",
            "  ...\n",
            "  [ 0.02833715 -0.00203717  0.03626789 ... -0.0118333   0.00203039\n",
            "    0.00635558]\n",
            "  [ 0.02833715 -0.00203717  0.03626789 ... -0.0118333   0.00203039\n",
            "    0.00635558]\n",
            "  [ 0.02833715 -0.00203717  0.03626789 ... -0.0118333   0.00203039\n",
            "    0.00635558]]]\n",
            "\n",
            "shape after embedding: (1, 55, 128)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Using Tensorflow dataset API for fast processing\n",
        "train_dataset = tf.data.Dataset.from_tensor_slices((train_sentences, train_labels_one_hot))\n",
        "valid_dataset = tf.data.Dataset.from_tensor_slices((val_sentences, val_labels_one_hot))\n",
        "test_dataset = tf.data.Dataset.from_tensor_slices((test_sentences, test_labels_one_hot))\n",
        "train_dataset"
      ],
      "metadata": {
        "id": "n0kzy2EKVJ2_",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "9c0d33c2-4c35-40a8-d5d4-de501821125e"
      },
      "execution_count": 34,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<TensorSliceDataset shapes: ((), (5,)), types: (tf.string, tf.float64)>"
            ]
          },
          "metadata": {},
          "execution_count": 34
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Convert the data into batches\n",
        "train_dataset = train_dataset.batch(32).prefetch(tf.data.AUTOTUNE)\n",
        "valid_dataset = valid_dataset.batch(32).prefetch(tf.data.AUTOTUNE)\n",
        "test_dataset = test_dataset.batch(32).prefetch(tf.data.AUTOTUNE)\n",
        "train_dataset"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HNe8Bg9YI1jN",
        "outputId": "22b11043-9ace-40b3-d629-7590f51504ae"
      },
      "execution_count": 35,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<PrefetchDataset shapes: ((None,), (None, 5)), types: (tf.string, tf.float64)>"
            ]
          },
          "metadata": {},
          "execution_count": 35
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Modelling Conv1D model_2\n",
        "\n",
        "inputs= layers.Input(shape=(1,), dtype=tf.string)\n",
        "text_vectors = text_vectorizer(inputs)\n",
        "token_embeddings = token_embed(text_vectors)\n",
        "x = layers.Conv1D(64, kernel_size=5, padding=\"same\", activation=\"relu\")(token_embeddings)\n",
        "x = layers.GlobalAveragePooling1D()(x)\n",
        "outputs = layers.Dense(num_classes, activation=\"softmax\")(x)\n",
        "model_2 = tf.keras.Model(inputs, outputs)\n",
        "\n",
        "model_2.compile(loss=\"categorical_crossentropy\",\n",
        "                optimizer=tf.keras.optimizers.Adam(),\n",
        "                metrics=[\"accuracy\"])\n",
        "model_2_history = model_2.fit(train_dataset,\n",
        "                              steps_per_epoch=int(0.1 * len(train_dataset)),\n",
        "                              epochs=3,\n",
        "                              validation_data=valid_dataset,\n",
        "                              validation_steps=int(0.1 * len(valid_dataset)))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rg0oLAf4JdWn",
        "outputId": "f3380fd9-6d93-40c3-f35e-44238d665ea6"
      },
      "execution_count": 36,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1/3\n",
            "562/562 [==============================] - 58s 101ms/step - loss: 0.9061 - accuracy: 0.6422 - val_loss: 0.6805 - val_accuracy: 0.7394\n",
            "Epoch 2/3\n",
            "562/562 [==============================] - 57s 101ms/step - loss: 0.6568 - accuracy: 0.7568 - val_loss: 0.6270 - val_accuracy: 0.7729\n",
            "Epoch 3/3\n",
            "562/562 [==============================] - 56s 100ms/step - loss: 0.6181 - accuracy: 0.7766 - val_loss: 0.5951 - val_accuracy: 0.7856\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "model_2.summary()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "c6PN_DzCLIjd",
        "outputId": "22fa07ec-231b-402f-c2ca-ce4b22cb17f4"
      },
      "execution_count": 37,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model: \"model\"\n",
            "_________________________________________________________________\n",
            " Layer (type)                Output Shape              Param #   \n",
            "=================================================================\n",
            " input_1 (InputLayer)        [(None, 1)]               0         \n",
            "                                                                 \n",
            " text_vectorization (TextVec  (None, 55)               0         \n",
            " torization)                                                     \n",
            "                                                                 \n",
            " token_embedding (Embedding)  (None, 55, 128)          8299648   \n",
            "                                                                 \n",
            " conv1d (Conv1D)             (None, 55, 64)            41024     \n",
            "                                                                 \n",
            " global_average_pooling1d (G  (None, 64)               0         \n",
            " lobalAveragePooling1D)                                          \n",
            "                                                                 \n",
            " dense (Dense)               (None, 5)                 325       \n",
            "                                                                 \n",
            "=================================================================\n",
            "Total params: 8,340,997\n",
            "Trainable params: 8,340,997\n",
            "Non-trainable params: 0\n",
            "_________________________________________________________________\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Evaluation\n",
        "model_2.evaluate(valid_dataset)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fmQ7FzkVMCMh",
        "outputId": "9d22d571-fe8c-4563-9876-a74980eba76c"
      },
      "execution_count": 38,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "945/945 [==============================] - 5s 6ms/step - loss: 0.5989 - accuracy: 0.7868\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[0.598865807056427, 0.7867734432220459]"
            ]
          },
          "metadata": {},
          "execution_count": 38
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Predictions\n",
        "model_2_pred_probs = model_2.predict(valid_dataset)\n",
        "model_2_pred_probs"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2gnFLbCYMObh",
        "outputId": "b42ce495-d5ca-405c-ec78-2a5b65caa627"
      },
      "execution_count": 39,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([[4.40852791e-01, 1.58087149e-01, 1.08493336e-01, 2.59328723e-01,\n",
              "        3.32380310e-02],\n",
              "       [4.07725662e-01, 2.96386003e-01, 1.62761211e-02, 2.70840019e-01,\n",
              "        8.77214782e-03],\n",
              "       [1.41440153e-01, 9.46754497e-03, 1.91584730e-03, 8.47138643e-01,\n",
              "        3.78725817e-05],\n",
              "       ...,\n",
              "       [3.81689779e-06, 5.15244086e-04, 5.68084070e-04, 1.65876520e-06,\n",
              "        9.98911142e-01],\n",
              "       [5.77327423e-02, 4.30542588e-01, 1.11858152e-01, 7.07353279e-02,\n",
              "        3.29131275e-01],\n",
              "       [1.59427017e-01, 6.66728795e-01, 3.24941985e-02, 5.08395135e-02,\n",
              "        9.05105695e-02]], dtype=float32)"
            ]
          },
          "metadata": {},
          "execution_count": 39
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Turning the prediction probabilities into classes\n",
        "model_2_preds = tf.argmax(model_2_pred_probs, axis=1)\n",
        "model_2_preds"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "awfCkxO6MdCx",
        "outputId": "68b26bb0-fd9b-410a-c7dc-aab38a9dd386"
      },
      "execution_count": 40,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<tf.Tensor: shape=(30212,), dtype=int64, numpy=array([0, 0, 3, ..., 4, 1, 1])>"
            ]
          },
          "metadata": {},
          "execution_count": 40
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Evaluating the results \n",
        "model_2_results = calculate_results(y_true=val_labels_encoded,\n",
        "                                    y_pred=model_2_preds)\n",
        "model_2_results"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "u24eJPbeM0d9",
        "outputId": "172ffac9-27a1-451b-963f-637b24aeb9dc"
      },
      "execution_count": 41,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'accuracy': 78.6773467496359,\n",
              " 'f1': 0.7843933683333598,\n",
              " 'precision': 0.7834033897818605,\n",
              " 'recall': 0.786773467496359}"
            ]
          },
          "metadata": {},
          "execution_count": 41
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## model_3 using feature extraction"
      ],
      "metadata": {
        "id": "FrNvx0_ONNlP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Pretrained model universal-sentence-encoder from hub\n",
        "import tensorflow_hub as hub\n",
        "tf_hub_embedding_layer = hub.KerasLayer(\"https://tfhub.dev/google/universal-sentence-encoder/4\",\n",
        "                                        trainable=False,\n",
        "                                        name=\"universal_sentence_encoder\")\n",
        "\n",
        "# Testing \n",
        "sample_sentence_for_training = random.choice(train_sentences)\n",
        "print(f\"Sample training sentence:\\n{sample_sentence_for_training}\\n\")\n",
        "use_embedded_sentence = tf_hub_embedding_layer([sample_sentence_for_training])\n",
        "print(f\"Sample sentence after embedding:\\n{use_embedded_sentence[0][:30]} (truncated_output)...\\n\")\n",
        "print(f\"Length of the embedded sentence:\\n{len(use_embedded_sentence[0])}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4CUqA1AMNYdS",
        "outputId": "743428f5-9c44-4fd3-b419-4efdfb351573"
      },
      "execution_count": 42,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Sample training sentence:\n",
            "there was no significant difference in gender , age , disease duration , side , and affected segment between @ groups ( p > @ ) .\n",
            "\n",
            "Sample sentence after embedding:\n",
            "[-0.01892477 -0.00388603 -0.00958671  0.04508729  0.02372267 -0.00646324\n",
            "  0.01541883 -0.04411809  0.02103367  0.07683356  0.08414494  0.04277237\n",
            " -0.01566472  0.05267743  0.01859346 -0.0111114  -0.08790275  0.01383738\n",
            "  0.00347    -0.05370739  0.00376647  0.06776437 -0.04726404 -0.00242291\n",
            " -0.03000989 -0.00198491  0.00077366  0.04609337 -0.01708829 -0.02211344] (truncated_output)...\n",
            "\n",
            "Length of the embedded sentence:\n",
            "512\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Modelling model_3\n",
        "inputs = layers.Input(shape=[], dtype=tf.string)\n",
        "pretrained_embedding = tf_hub_embedding_layer(inputs)\n",
        "x = layers.Dense(128, activation=\"relu\")(pretrained_embedding)\n",
        "outputs = layers.Dense(5, activation=\"softmax\")(x)\n",
        "model_3 = tf.keras.Model(inputs=inputs,\n",
        "                         outputs=outputs)\n",
        "model_3.compile(loss=\"categorical_crossentropy\",\n",
        "                optimizer=tf.keras.optimizers.Adam(),\n",
        "                metrics=[\"accuracy\"])\n",
        "model_3.fit(train_dataset,\n",
        "            epochs=3,\n",
        "            steps_per_epoch=int(0.1 * len(train_dataset)),\n",
        "            validation_data=valid_dataset,\n",
        "            validation_steps=int(0.1 * len(valid_dataset)))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "A0auhibLPYdI",
        "outputId": "1324d5b2-8f38-4932-9a39-1959b0f3747f"
      },
      "execution_count": 43,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1/3\n",
            "562/562 [==============================] - 10s 14ms/step - loss: 0.9173 - accuracy: 0.6494 - val_loss: 0.7942 - val_accuracy: 0.6895\n",
            "Epoch 2/3\n",
            "562/562 [==============================] - 7s 13ms/step - loss: 0.7682 - accuracy: 0.7022 - val_loss: 0.7527 - val_accuracy: 0.7068\n",
            "Epoch 3/3\n",
            "562/562 [==============================] - 7s 13ms/step - loss: 0.7523 - accuracy: 0.7138 - val_loss: 0.7385 - val_accuracy: 0.7128\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<keras.callbacks.History at 0x7f1d8facd110>"
            ]
          },
          "metadata": {},
          "execution_count": 43
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "model_3.summary()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "tRGNsCb6R_CQ",
        "outputId": "2e040f58-09d2-4e70-83ac-a1a78c5415ae"
      },
      "execution_count": 44,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model: \"model_1\"\n",
            "_________________________________________________________________\n",
            " Layer (type)                Output Shape              Param #   \n",
            "=================================================================\n",
            " input_2 (InputLayer)        [(None,)]                 0         \n",
            "                                                                 \n",
            " universal_sentence_encoder   (None, 512)              256797824 \n",
            " (KerasLayer)                                                    \n",
            "                                                                 \n",
            " dense_1 (Dense)             (None, 128)               65664     \n",
            "                                                                 \n",
            " dense_2 (Dense)             (None, 5)                 645       \n",
            "                                                                 \n",
            "=================================================================\n",
            "Total params: 256,864,133\n",
            "Trainable params: 66,309\n",
            "Non-trainable params: 256,797,824\n",
            "_________________________________________________________________\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Evaluate model_3\n",
        "model_3.evaluate(valid_dataset)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "waZmi-ynSKHi",
        "outputId": "55f4d352-07e4-40cb-f344-de8a22332507"
      },
      "execution_count": 45,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "945/945 [==============================] - 11s 11ms/step - loss: 0.7419 - accuracy: 0.7144\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[0.7419232130050659, 0.7143518924713135]"
            ]
          },
          "metadata": {},
          "execution_count": 45
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Predictions \n",
        "model_3_pred_probs = model_3.predict(valid_dataset)\n",
        "model_3_pred_probs"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "sX4jowPMSVnL",
        "outputId": "33592219-02ca-4f51-d77a-faa3ca82e3b4"
      },
      "execution_count": 46,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([[4.1173980e-01, 3.6531892e-01, 2.2412178e-03, 2.1321088e-01,\n",
              "        7.4891304e-03],\n",
              "       [3.2291305e-01, 5.3647935e-01, 3.0751117e-03, 1.3476585e-01,\n",
              "        2.7666835e-03],\n",
              "       [2.5439259e-01, 1.6618186e-01, 1.7330565e-02, 5.1821136e-01,\n",
              "        4.3883622e-02],\n",
              "       ...,\n",
              "       [1.4093145e-03, 5.4870546e-03, 5.6798019e-02, 7.7662029e-04,\n",
              "        9.3552899e-01],\n",
              "       [3.7208851e-03, 5.5228781e-02, 1.8772416e-01, 1.3641112e-03,\n",
              "        7.5196213e-01],\n",
              "       [1.9264714e-01, 2.7150142e-01, 4.7964740e-01, 8.6932555e-03,\n",
              "        4.7510855e-02]], dtype=float32)"
            ]
          },
          "metadata": {},
          "execution_count": 46
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Convert the pred_probs to classes\n",
        "model_3_preds = tf.argmax(model_3_pred_probs, axis=1)\n",
        "model_3_preds"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZsJDtbB5SiWa",
        "outputId": "dc48c730-adfa-453d-8b99-19d4800263af"
      },
      "execution_count": 47,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<tf.Tensor: shape=(30212,), dtype=int64, numpy=array([0, 1, 3, ..., 4, 4, 2])>"
            ]
          },
          "metadata": {},
          "execution_count": 47
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Evaluating the results\n",
        "model_3_results = calculate_results(y_true=val_labels_encoded,\n",
        "                                    y_pred=model_3_preds)\n",
        "model_3_results"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "KsqM_CmUSxXz",
        "outputId": "86cc4ff4-674b-4b25-8ce4-646d9590b139"
      },
      "execution_count": 48,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'accuracy': 71.43519131470939,\n",
              " 'f1': 0.7112630554846233,\n",
              " 'precision': 0.7148152528054331,\n",
              " 'recall': 0.7143519131470939}"
            ]
          },
          "metadata": {},
          "execution_count": 48
        }
      ]
    },
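    {
      "cell_type": "markdown",
      "source": [
        "With three baselines evaluated, their result dictionaries can be gathered into one table for a side-by-side comparison. A minimal sketch is shown below (it assumes pandas is available and uses only the dictionaries already computed; the row labels are illustrative)."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Sketch: compare the baseline results collected so far\n",
        "import pandas as pd\n",
        "\n",
        "all_model_results = pd.DataFrame({\"model_1_tfidf_nb\": model_1_results,\n",
        "                                  \"model_2_conv1d_token\": model_2_results,\n",
        "                                  \"model_3_use_features\": model_3_results}).transpose()\n",
        "all_model_results"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },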
    {
      "cell_type": "markdown",
      "source": [
        "# model_4 with character embeddings"
      ],
      "metadata": {
        "id": "1gmVF8QeTBkm"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Function for character splitting\n",
        "def split_character(text):\n",
        "  return \" \".join(list(text))\n",
        "\n",
        "split_character(sample_sentence_for_training)  "
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 53
        },
        "id": "lCDKdUmZTPFu",
        "outputId": "a4524a87-93a0-43e1-d2b8-878abcd40310"
      },
      "execution_count": 49,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'t h e r e   w a s   n o   s i g n i f i c a n t   d i f f e r e n c e   i n   g e n d e r   ,   a g e   ,   d i s e a s e   d u r a t i o n   ,   s i d e   ,   a n d   a f f e c t e d   s e g m e n t   b e t w e e n   @   g r o u p s   (   p   >   @   )   .'"
            ]
          },
          "metadata": {},
          "execution_count": 49
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# split the sequence \n",
        "train_char = [split_character(sentence) for sentence in train_sentences]\n",
        "val_char = [split_character(sentence) for sentence in val_sentences]\n",
        "test_char = [split_character(sentence) for sentence in test_sentences]\n",
        "print(train_char[0])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Wz0MVrtRTqWd",
        "outputId": "d92666b0-81fb-4852-d656-2e4a7bdfabb0"
      },
      "execution_count": 50,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "t o   i n v e s t i g a t e   t h e   e f f i c a c y   o f   @   w e e k s   o f   d a i l y   l o w - d o s e   o r a l   p r e d n i s o l o n e   i n   i m p r o v i n g   p a i n   ,   m o b i l i t y   ,   a n d   s y s t e m i c   l o w - g r a d e   i n f l a m m a t i o n   i n   t h e   s h o r t   t e r m   a n d   w h e t h e r   t h e   e f f e c t   w o u l d   b e   s u s t a i n e d   a t   @   w e e k s   i n   o l d e r   a d u l t s   w i t h   m o d e r a t e   t o   s e v e r e   k n e e   o s t e o a r t h r i t i s   (   o a   )   .\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Average length of the character\n",
        "character_length = [len(sentence) for sentence in train_sentences]\n",
        "mean_character_length = np.mean(character_length)\n",
        "mean_character_length"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "74WL9rjmUYuk",
        "outputId": "d9f5a840-f7b4-426f-c7d5-bd59473f62a0"
      },
      "execution_count": 51,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "149.3662574983337"
            ]
          },
          "metadata": {},
          "execution_count": 51
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "plt.hist(character_length, bins=10);"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 265
        },
        "id": "Q-NFspvNUxzJ",
        "outputId": "6d12697a-ecc8-42f5-dc85-48636d5a8508"
      },
      "execution_count": 52,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAQ60lEQVR4nO3df6zddX3H8edrreWXkxbpGLbNbp2NSzVxYIMlLGYRBwWMZQmaEjOqYzaZuKkzcUWTkakksBkREn8RqIJhFlbZaADXMMA/9geViyhQsHLlh7QBuVp+bBp/VN/743wuHMu9vefCveec0ucjObnf7/v7+Z7zPp/ce173+z3fc2+qCknSwe33Bt2AJGnwDANJkmEgSTIMJEkYBpIkYP6gG3ixjj766BoZGRl0G5J0wLjrrrt+UlWLJ9t2wIbByMgIo6Ojg25Dkg4YSR6dapuniSRJhoEkyTCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CSxAH8CeSXYmTjTQN53EcuOmMgjytJ0/HIQJJkGEiSDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkugxDJJ8JMmOJPcl+XqSQ5MsT7I9yViSa5MsaGMPaetjbftI1/2c3+o7k5zaVV/TamNJNs72k5Qk7d+0YZBkCfD3wKqqeiMwD1gHXAxcUlWvA54Czm27nAs81eqXtHEkWdn2ewOwBvhCknlJ5gGfB04DVgJnt7GSpD7p9TTRfOCwJPOBw4HHgbcBW9r2q4Az2/Latk7bfnKStPrmqvplVT0MjAEntNtYVT1UVb8CNrexkqQ+mTYMqmo38BngR3RC4BngLuDpqtrbhu0ClrTlJcBjbd+9bfyru+v77DNV/QWSbEgymmR0fHy8l+cnSepBL6eJFtH5TX058BrgCDqnefquqi6vqlVVtWrx4sWDaEGSXpZ6OU30duDhqhqvql8D1wMnAQvbaSOApcDutrwbWAbQth8J/LS7vs8+U9UlSX3SSxj8CFid5PB27v9k4H7gduCsNmY9cENb3trWadtvq6pq9XXtaqPlwArg28CdwIp2ddICOm8yb33pT02S1Kv50w2oqu1JtgDfAfYCdwOXAzcBm5N8utWubLtcCXwtyRiwh86LO1W1I8l1dIJkL3BeVf0GIMkHgW10rlTaVFU7Zu8pSpKmM20YAFTVBcAF+5QfonMl0L5jfwG8a4r7uRC4cJL6zcDNvfQiSZp9fgJZkmQYSJJ6PE2k2TGy8aaBPfYjF50xsMeWNPw8MpAkGQaSJMNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEj2GQZKFSbYk+X6SB5KcmOSoJLckebB9XdTGJsllScaS3JPk+K77Wd/GP5hkfVf9zUnubftcliSz/1QlSVPp9cjgUuC/qupPgDcBDwAbgVuragVwa1sHOA1Y0W4bgC8CJDkKuAB4C3ACcMFEgLQx7+/ab81Le1qSpJmYNgySHAm8FbgSoKp+VVVPA2uBq9qwq4Az2/Ja4OrquANYmORY4FTglqraU1VPAbcAa9q2V1XVHVVVwNVd9yVJ6oNejgyWA+PAV5LcneSKJEcAx1TV423ME8AxbXkJ8FjX/rtabX/1XZPUXyDJhiSjSUbHx8d7aF2S1ItewmA+cDzwxao6DvgZz58SAqD9Rl+z397vqqrLq2pVVa1avHjxXD+cJB00egmDXcCuqtre1rfQCYcft1M8tK9Ptu27gWVd+y9ttf3Vl05SlyT1ybRhUFVPAI8leX0rnQzcD2wFJq4IWg/c0Ja3Aue0q4pWA8+000nbgFOSLGpvHJ8CbGvbnk2yul1FdE7XfUmS+mB+j+P+DrgmyQLgIeB9dILkuiTnAo8C725jbwZOB8aAn7exVNWeJJ8C7mzjPllVe9ryB4CvAocB32w3SVKf9BQGVfVdYNUkm06eZGwB501xP5uATZPUR4E39tKLJGn2+QlkSZJhIEkyDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkScwgDJLMS3J3khvb+vIk25OMJbk2yYJWP6Stj7XtI133cX6r70xyald9TauNJdk4e09PktSLmRwZfAh4oGv9YuCSqnod8BRwbqufCzzV6pe0cSRZCawD3gCsAb7QAmYe8HngNGAlcHYbK0nqk57CIMlS4AzgirYe4G3AljbkKuDMtry2rdO2n9zGrwU2V9Uvq+phYAw4od3GquqhqvoVsLmNlST1Sa9HBp8DPgb8tq2/Gni6qva29V3Akra8BHgMoG1/po1/rr7PPlPVXyDJhiSjSUbHx8d7bF2SNJ1pwyDJO4Anq+quPvSzX1V1eVWtqqpVixcvHnQ7kvSyMb+HMScB70xyOnAo8CrgUmBhkvntt/+lwO42fjewDNiVZD5wJPDTrvqE7n2mqkuS+mDaI4OqOr+qllbVCJ03gG+rqvcAtwNntWHrgRva8ta2Ttt+W1VVq69rVxstB1YA3wbuBFa0q5MWtMfYOivPTpLUk16ODKbyj8DmJJ8G7gaubPUrga8lGQP20Hlxp6p2JLkOuB/YC5xXVb8BSPJBYBswD9hUVTteQl+SpBmaURhU1beAb7Xlh+hcCbTvmF8A75pi/wuBCyep3wzcPJNeJEmzx08gS5IMA0mSYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCTRQxgkWZbk9iT3J9mR5EOtflSSW5I82L4uavUkuSzJWJJ7khzfdV/r2/gHk6zvqr85yb1tn8uSZC6erCRpcr0cGewFPlpVK4HVwHlJVgIbgVuragVwa1sHOA1Y0W4bgC9CJzyAC4C3ACcAF0wESBvz/q791rz0pyZJ6tW0YVBVj1fVd9ry/wIPAEuAtcBVbdhVwJlteS1wdXXcASxMcixwKnBLVe2pqqeAW4A1bdurquqOqirg6q77kiT1wYzeM0gyAhwHbAeOqarH26YngGPa8hLgsa7ddrXa/uq7JqlP9vgbkowmGR0fH59J65Kk/eg5DJK8EvgG8OGqerZ7W/uNvma5txeoqsuralVVrVq8ePFcP
5wkHTR6CoMkr6ATBNdU1fWt/ON2iof29clW3w0s69p9aavtr750krokqU96uZoowJXAA1X12a5NW4GJK4LWAzd01c9pVxWtBp5pp5O2AackWdTeOD4F2Na2PZtkdXusc7ruS5LUB/N7GHMS8FfAvUm+22ofBy4CrktyLvAo8O627WbgdGAM+DnwPoCq2pPkU8Cdbdwnq2pPW/4A8FXgMOCb7SZJ6pNpw6Cq/geY6rr/kycZX8B5U9zXJmDTJPVR4I3T9SJJmht+AlmSZBhIkgwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEn09m8v9TIwsvGmgTzuIxedMZDHlTQzHhlIkgwDSZJhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CSxBD928ska4BLgXnAFVV10YBb0iwY1L/bBP/lpjQTQ3FkkGQe8HngNGAlcHaSlYPtSpIOHsNyZHACMFZVDwEk2QysBe4faFc6oA3qqMQjEh2IhiUMlgCPda3vAt6y76AkG4ANbfX/kux8kY93NPCTF7nvIBxI/R5IvcIc9JuLZ/PeXuCgn9859nLv94+m2jAsYdCTqrocuPyl3k+S0apaNQst9cWB1O+B1CvY71yz37k1m/0OxXsGwG5gWdf60laTJPXBsITBncCKJMuTLADWAVsH3JMkHTSG4jRRVe1N8kFgG51LSzdV1Y45fMiXfKqpzw6kfg+kXsF+55r9zq1Z6zdVNVv3JUk6QA3LaSJJ0gAZBpKkgysMkqxJsjPJWJKNg+4HIMmyJLcnuT/JjiQfavWjktyS5MH2dVGrJ8ll7Tnck+T4AfU9L8ndSW5s68uTbG99XdsuBCDJIW19rG0fGUCvC5NsSfL9JA8kOXGY5zfJR9r3wn1Jvp7k0GGa3ySbkjyZ5L6u2oznM8n6Nv7BJOv73O+/tu+He5L8R5KFXdvOb/3uTHJqV70vrx+T9du17aNJKsnRbX325reqDoobnTemfwi8FlgAfA9YOQR9HQsc35Z/H/gBnT/J8S/AxlbfCFzclk8HvgkEWA1sH1Df/wD8G3BjW78OWNeWvwT8bVv+APCltrwOuHYAvV4F/E1bXgAsHNb5pfMBzIeBw7rm9b3DNL/AW4Hjgfu6ajOaT+Ao4KH2dVFbXtTHfk8B5rfli7v6XdleGw4BlrfXjHn9fP2YrN9WX0bnIptHgaNne377+kM5yBtwIrCta/184PxB9zVJnzcAfwHsBI5ttWOBnW35y8DZXeOfG9fHHpcCtwJvA25s34g/6frhem6u2zfviW15fhuXPvZ6ZHtxzT71oZxfnv80/lFtvm4ETh22+QVG9nlxndF8AmcDX+6q/864ue53n21/CVzTln/ndWFifvv9+jFZv8AW4E3AIzwfBrM2vwfTaaLJ/uTFkgH1Mql2iH8csB04pqoeb5ueAI5py8PwPD4HfAz4bVt/NfB0Ve2dpKfn+m3bn2nj+2U5MA58pZ3WuiLJEQzp/FbVbuAzwI+Ax+nM110M7/xOmOl8DsP38YS/pvPbNQxpv0nWArur6nv7bJq1fg+mMBhqSV4JfAP4cFU9272tOtE+FNcAJ3kH8GRV3TXoXno0n84h9xer6jjgZ3ROYzxnyOZ3EZ0/0rgceA1wBLBmoE3N0DDN53SSfALYC1wz6F6mkuRw4OPAP83l4xxMYTC0f/IiySvoBME1VXV9K/84ybFt+7HAk60+6OdxEvDOJI8Am+mcKroUWJhk4kOM3T0912/bfiTw0z72uwvYVVXb2/oWOuEwrPP7duDhqhqvql8D19OZ82Gd3wkznc9BzzNJ3gu8A3hPCzD209cg+/1jOr8cfK/93C0FvpPkD/fT14z7PZjCYCj/5EWSAFcCD1TVZ7s2bQUmrgBYT+e9hIn6Oe0qgtXAM12H53Ouqs6vqqVVNUJnDm+rqvcAtwNnTdHvxPM4q43v22+NVfUE8FiS17fSyXT+NPpQzi+d00Orkxzevjcm+h3K+e0y0/ncBpySZFE7Gjql1foinX+m9THgnVX1865NW4F17Sqt5cAK4NsM8PWjqu6tqj+oqpH2c7eLzkUnTzCb8ztXb4AM443OO+8/oHNVwCcG3U/r6c/oHFLfA3y33U6nc973VuBB4L+Bo9r40PlHQD8E7gVWDbD3P+f5q4leS+eHZgz4d+CQVj+0rY+17a8dQJ9/Coy2Of5POldXDO38Av8MfB+4D/ganStbhmZ+ga/TeT/j1+2F6dwXM590ztWPtdv7+tzvGJ1z6hM/c1/qGv+J1u9O4LSuel9ePybrd5/tj/D8G8izNr/+OQpJ0kF1mkiSNAXDQJJkGEiSDANJEoaBJAnDQJKEYSBJAv4fwP79pB1y7u4AAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "From the plot, we can observe that most of the sequences are between 0-200 characters long."
      ],
      "metadata": {
        "id": "gPoTKROlU4A2"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# for 95% of character length of sequences\n",
        "output_seq_character_length = int(np.percentile(character_length, 95))\n",
        "output_seq_character_length"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OwnweXyiVOxY",
        "outputId": "b661edeb-f61d-44e2-b78b-d35df343df47"
      },
      "execution_count": 53,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "290"
            ]
          },
          "metadata": {},
          "execution_count": 53
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Vectorizing and character embeddings\n",
        "import string\n",
        "alphabet = string.ascii_lowercase + string.digits + string.punctuation\n",
        "alphabet"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "id": "Q2J_YsdgVqJu",
        "outputId": "3fd163e0-6fbd-48dd-cf12-ed3676b2efea"
      },
      "execution_count": 54,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'abcdefghijklmnopqrstuvwxyz0123456789!\"#$%&\\'()*+,-./:;<=>?@[\\\\]^_`{|}~'"
            ]
          },
          "metadata": {},
          "execution_count": 54
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# tokenizing the character\n",
        "NUM_CHAR_TOKENS = len(alphabet) + 2\n",
        "character_vectorizer = TextVectorization(max_tokens=NUM_CHAR_TOKENS,\n",
        "                                         output_sequence_length=output_seq_character_length,\n",
        "                                         standardize=\"lower_and_strip_punctuation\",\n",
        "                                         name=\"character_vectorizer\")\n",
        "character_vectorizer.adapt(train_char)\n",
        "\n",
        "# Sample of character vocabulary\n",
        "character_vocab = character_vectorizer.get_vocabulary()\n",
        "print(f\"Total number of unique characters: {len(character_vocab)}\")\n",
        "print(f\"10 most common characters: {character_vocab[:10]}\")\n",
        "print(f\"10 least common characters: {character_vocab[-10:]}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XAoLqbnqWHTd",
        "outputId": "a652f7df-8ff6-44a9-c768-00bb34ab2c4e"
      },
      "execution_count": 55,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Total number of unique characters: 28\n",
            "10 most common characters: ['', '[UNK]', 'e', 't', 'i', 'a', 'n', 'o', 'r', 's']\n",
            "10 least common characters: ['g', 'y', 'w', 'v', 'b', 'k', 'x', 'z', 'q', 'j']\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Sample of character vectorizer\n",
        "import random\n",
        "sample_train_character = random.choice(train_char)\n",
        "print(f\" Text (in characters):\\n{sample_train_character}\")\n",
        "print(f\"\\nCharacter length: {len(sample_train_character.split())}\")\n",
        "vectorized__character = character_vectorizer([sample_train_character])\n",
        "print(f\"\\nCharacter Vectorized:\\n{vectorized__character}\")\n",
        "print(f\"\\nVectorized character length: {len(vectorized__character[0])}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "e68OOa3GX2sx",
        "outputId": "c7b80a93-5e2a-4919-9c6d-114cab2612d5"
      },
      "execution_count": 56,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            " Text (in characters):\n",
            "p r o b a b l y   ,   m o r e   t i m e   w i l l   b e   r e q u i r e d   t o   a s s e s s   c h a n g e s   i n   q u a l i t y   o f   l i f e   .\n",
            "\n",
            "Character length: 62\n",
            "\n",
            "Character Vectorized:\n",
            "[[14  8  7 22  5 22 12 19 15  7  8  2  3  4 15  2 20  4 12 12 22  2  8  2\n",
            "  26 16  4  8  2 10  3  7  5  9  9  2  9  9 11 13  5  6 18  2  9  4  6 26\n",
            "  16  5 12  4  3 19  7 17 12  4 17  2  0  0  0  0  0  0  0  0  0  0  0  0\n",
            "   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0\n",
            "   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0\n",
            "   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0\n",
            "   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0\n",
            "   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0\n",
            "   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0\n",
            "   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0\n",
            "   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0\n",
            "   0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0\n",
            "   0  0]]\n",
            "\n",
            "Vectorized character length: 290\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Character embedding layer\n",
        "character_embed = layers.Embedding(input_dim=NUM_CHAR_TOKENS,\n",
        "                                   output_dim=25,\n",
        "                                   mask_zero=False,\n",
        "                                   name=\"character_embed\")\n",
        "\n",
        "# Sample\n",
        "print(f\"Character text before embeddings:\\n{sample_train_character}\\n\")\n",
        "character_embed_example = character_embed(character_vectorizer([sample_train_character]))\n",
        "print(f\"Embedded characters:\\n{character_embed_example}\\n\")\n",
        "print(f\"Shape of embedded character: {character_embed_example.shape}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "sRrNfFJ_Zyeo",
        "outputId": "61f465ad-c1a5-4321-a27e-764e0987e4a0"
      },
      "execution_count": 57,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Character text before embeddings:\n",
            "p r o b a b l y   ,   m o r e   t i m e   w i l l   b e   r e q u i r e d   t o   a s s e s s   c h a n g e s   i n   q u a l i t y   o f   l i f e   .\n",
            "\n",
            "Embedded characters:\n",
            "[[[ 0.04213986  0.03748259  0.04720943 ...  0.01618828  0.00654129\n",
            "    0.04588342]\n",
            "  [-0.01738371 -0.03987639 -0.0444983  ... -0.00542194 -0.0436322\n",
            "   -0.02087811]\n",
            "  [ 0.04768093  0.03476207  0.03694859 ...  0.04528984  0.01601002\n",
            "   -0.01545056]\n",
            "  ...\n",
            "  [-0.04275223  0.01708314 -0.02383623 ...  0.04501892 -0.03379728\n",
            "   -0.03942678]\n",
            "  [-0.04275223  0.01708314 -0.02383623 ...  0.04501892 -0.03379728\n",
            "   -0.03942678]\n",
            "  [-0.04275223  0.01708314 -0.02383623 ...  0.04501892 -0.03379728\n",
            "   -0.03942678]]]\n",
            "\n",
            "Shape of embedded character: (1, 290, 25)\n"
          ]
        }
      ]
    },
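    {
      "cell_type": "markdown",
      "source": [
        "The embedding layer's parameters live in a single weight matrix with one 25-dimensional row per character token. A minimal check of that table (a sketch added for clarity, not part of the original run) follows."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Sketch: inspect the character embedding table (built by the call above)\n",
        "# Expected shape: (NUM_CHAR_TOKENS, 25), i.e. one 25-dim vector per character token\n",
        "print(character_embed.weights[0].shape)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },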
    {
      "cell_type": "code",
      "source": [
        "# Modelling model_4 with Conv1D\n",
        "inputs = layers.Input(shape=(1,), dtype=\"string\")\n",
        "character_vectors = character_vectorizer(inputs)\n",
        "character_embeddings = character_embed(character_vectors)\n",
        "x = layers.Conv1D(63, kernel_size=5, padding=\"same\", activation=\"relu\")(character_embeddings)\n",
        "x = layers.GlobalMaxPool1D()(x)\n",
        "outputs = layers.Dense(num_classes, activation=\"softmax\")(x)\n",
        "model_4 = tf.keras.Model(inputs=inputs,\n",
        "                         outputs=outputs,\n",
        "                         name=\"model_4_conv1D_character_embedding\")\n",
        "model_4.compile(loss=\"categorical_crossentropy\",\n",
        "                optimizer=tf.keras.optimizers.Adam(),\n",
        "                metrics=[\"accuracy\"])\n",
        "\n",
        "# Character dataset\n",
        "train_character_dataset = tf.data.Dataset.from_tensor_slices((train_char, train_labels_one_hot)).batch(32).prefetch(tf.data.AUTOTUNE)\n",
        "val_character_dataset = tf.data.Dataset.from_tensor_slices((val_char, val_labels_one_hot)).batch(32).prefetch(tf.data.AUTOTUNE)\n",
        "\n",
        "model_4_history = model_4.fit(train_character_dataset,\n",
        "                              epochs=3,\n",
        "                              steps_per_epoch=int(0.1 * len(train_character_dataset)),\n",
        "                              validation_data=val_character_dataset,\n",
        "                              validation_steps=int(0.1 * len(val_character_dataset)))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Iq1zFCxnbIio",
        "outputId": "f5240f0f-795b-41af-8733-0c12b30fc251"
      },
      "execution_count": 58,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1/3\n",
            "562/562 [==============================] - 14s 23ms/step - loss: 1.2650 - accuracy: 0.4895 - val_loss: 1.0490 - val_accuracy: 0.5928\n",
            "Epoch 2/3\n",
            "562/562 [==============================] - 13s 23ms/step - loss: 1.0083 - accuracy: 0.6011 - val_loss: 0.9411 - val_accuracy: 0.6436\n",
            "Epoch 3/3\n",
            "562/562 [==============================] - 13s 22ms/step - loss: 0.9173 - accuracy: 0.6428 - val_loss: 0.8664 - val_accuracy: 0.6702\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "model_4.summary()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Jb9oYWOeduEf",
        "outputId": "e6d72db8-1f04-4ac5-bb4e-6138c96da84f"
      },
      "execution_count": 59,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model: \"model_4_conv1D_character_embedding\"\n",
            "_________________________________________________________________\n",
            " Layer (type)                Output Shape              Param #   \n",
            "=================================================================\n",
            " input_3 (InputLayer)        [(None, 1)]               0         \n",
            "                                                                 \n",
            " character_vectorizer (TextV  (None, 290)              0         \n",
            " ectorization)                                                   \n",
            "                                                                 \n",
            " character_embed (Embedding)  (None, 290, 25)          1750      \n",
            "                                                                 \n",
            " conv1d_1 (Conv1D)           (None, 290, 63)           7938      \n",
            "                                                                 \n",
            " global_max_pooling1d (Globa  (None, 63)               0         \n",
            " lMaxPooling1D)                                                  \n",
            "                                                                 \n",
            " dense_3 (Dense)             (None, 5)                 320       \n",
            "                                                                 \n",
            "=================================================================\n",
            "Total params: 10,008\n",
            "Trainable params: 10,008\n",
            "Non-trainable params: 0\n",
            "_________________________________________________________________\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# model_4 evaluation\n",
        "model_4.evaluate(val_character_dataset)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QEPQWPz3efVD",
        "outputId": "626828ce-3f10-4f6e-a15d-9a61c404fddd"
      },
      "execution_count": 60,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "945/945 [==============================] - 7s 7ms/step - loss: 0.8812 - accuracy: 0.6614\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[0.8812289237976074, 0.6614259481430054]"
            ]
          },
          "metadata": {},
          "execution_count": 60
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Predictions\n",
        "model_4_pred_probs = model_4.predict(val_character_dataset)\n",
        "# Predictions to classes\n",
        "model_4_preds = tf.argmax(model_4_pred_probs, axis=1)\n",
        "# Results\n",
        "model_4_results = calculate_results(y_true=val_labels_encoded,\n",
        "                                    y_pred=model_4_preds)\n",
        "model_4_results"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Qa64jOuefKOK",
        "outputId": "884a5aea-3399-4905-e28c-83b5d0389ee9"
      },
      "execution_count": 61,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'accuracy': 66.14259234741162,\n",
              " 'f1': 0.6526517671886255,\n",
              " 'precision': 0.656632048558338,\n",
              " 'recall': 0.6614259234741162}"
            ]
          },
          "metadata": {},
          "execution_count": 61
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Not good enough at 66% accuracy, compared to model_1, model_2 & model_3."
      ],
      "metadata": {
        "id": "Zs8Jc3OwgIkh"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# model_5 with pretrained tokens and character embeddings"
      ],
      "metadata": {
        "id": "LJZOTdedgW94"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Tokens inputs\n",
        "token_inputs = layers.Input(shape=[], dtype=tf.string, name=\"token_input\")\n",
        "token_embeddings = tf_hub_embedding_layer(token_inputs)\n",
        "token_output = layers.Dense(128, activation=\"relu\")(token_embeddings)\n",
        "token_model = tf.keras.Model(inputs=token_inputs,\n",
        "                             outputs=token_output)\n",
        "# Character inputs\n",
        "character_inputs = layers.Input(shape=(1,), dtype=tf.string, name=\"character_input\")\n",
        "character_vectors = character_vectorizer(character_inputs)\n",
        "character_embeddings = character_embed(character_vectors)\n",
        "character_biLSTM = layers.Bidirectional(layers.LSTM(25))(character_embeddings)\n",
        "character_model = tf.keras.Model(inputs=character_inputs,\n",
        "                                 outputs=character_biLSTM)\n",
        "# token & character concatenation \n",
        "token_character_concat = layers.Concatenate(name=\"token_character_hybrid\")([token_model.output,\n",
        "                                                                            character_model.output])\n",
        "# Output layers\n",
        "combined_dropout = layers.Dropout(0.5)(token_character_concat)\n",
        "combined_dense = layers.Dense(200, activation=\"relu\")(combined_dropout)\n",
        "final_dropout = layers.Dropout(0.5)(combined_dense)\n",
        "output_layer = layers.Dense(num_classes, activation=\"softmax\")(final_dropout)\n",
        "\n",
        "model_5 = tf.keras.Model(inputs=[token_model.input, character_model.input],\n",
        "                         outputs=output_layer,\n",
        "                         name=\"model_5_hybrid\")\n",
        "model_5.compile(loss=\"categorical_crossentropy\",\n",
        "                optimizer=tf.keras.optimizers.Adam(),\n",
        "                metrics=[\"accuracy\"])\n",
        "\n",
        "# Combining both of the training dataset & batching\n",
        "train_character_token_data = tf.data.Dataset.from_tensor_slices((train_sentences, train_char))\n",
        "train_character_token_labels = tf.data.Dataset.from_tensor_slices(train_labels_one_hot)\n",
        "train_character_token_dataset = tf.data.Dataset.zip((train_character_token_data, train_character_token_labels))\n",
        "train_character_token_dataset = train_character_token_dataset.batch(32).prefetch(tf.data.AUTOTUNE)\n",
        "\n",
        "# Combining both of the validation dataset & batching\n",
        "val_character_token_data = tf.data.Dataset.from_tensor_slices((val_sentences, val_char))\n",
        "val_character_token_labels = tf.data.Dataset.from_tensor_slices(val_labels_one_hot)\n",
        "val_character_token_dataset = tf.data.Dataset.zip((val_character_token_data, val_character_token_labels))\n",
        "val_character_token_dataset = val_character_token_dataset.batch(32).prefetch(tf.data.AUTOTUNE)\n",
        "\n",
        "# Fit\n",
        "model_5_history = model_5.fit(train_character_token_dataset,                           \n",
        "                              steps_per_epoch=int(0.1 * len(train_character_token_dataset)),\n",
        "                              epochs=3,\n",
        "                              validation_data=val_character_token_dataset,\n",
        "                              validation_steps=int(0.1 * len(val_character_token_dataset)))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_SK8rkO3glWp",
        "outputId": "8c5edea8-2c99-44b1-df81-0e762d8a98c9"
      },
      "execution_count": 62,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1/3\n",
            "562/562 [==============================] - 137s 236ms/step - loss: 0.9741 - accuracy: 0.6118 - val_loss: 0.7921 - val_accuracy: 0.6865\n",
            "Epoch 2/3\n",
            "562/562 [==============================] - 112s 198ms/step - loss: 0.7974 - accuracy: 0.6895 - val_loss: 0.7197 - val_accuracy: 0.7251\n",
            "Epoch 3/3\n",
            "562/562 [==============================] - 115s 205ms/step - loss: 0.7714 - accuracy: 0.7062 - val_loss: 0.6938 - val_accuracy: 0.7400\n"
          ]
        }
      ]
    },
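    {
      "cell_type": "markdown",
      "source": [
        "A quick sanity check on the zipped two-input dataset (a sketch added for clarity, not part of the original run): each batch yields a `(token_sentences, character_sequences)` tuple plus one-hot labels, and the tuple order is matched positionally to `model_5`'s two inputs."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Sketch: inspect one batch of the hybrid dataset (shapes assume batch_size=32 and 5 classes)\n",
        "for (batch_tokens, batch_chars), batch_labels in train_character_token_dataset.take(1):\n",
        "    print(batch_tokens.shape, batch_tokens.dtype)  # e.g. (32,) <dtype: 'string'>\n",
        "    print(batch_chars.shape, batch_chars.dtype)    # e.g. (32,) <dtype: 'string'>\n",
        "    print(batch_labels.shape)                      # e.g. (32, 5)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },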
    {
      "cell_type": "code",
      "source": [
        "model_5.summary()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JVi6YtEIpTCe",
        "outputId": "b5c0ab67-f516-4ff8-a514-ec02b62c27a0"
      },
      "execution_count": 63,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model: \"model_5_hybrid\"\n",
            "__________________________________________________________________________________________________\n",
            " Layer (type)                   Output Shape         Param #     Connected to                     \n",
            "==================================================================================================\n",
            " character_input (InputLayer)   [(None, 1)]          0           []                               \n",
            "                                                                                                  \n",
            " token_input (InputLayer)       [(None,)]            0           []                               \n",
            "                                                                                                  \n",
            " character_vectorizer (TextVect  (None, 290)         0           ['character_input[0][0]']        \n",
            " orization)                                                                                       \n",
            "                                                                                                  \n",
            " universal_sentence_encoder (Ke  (None, 512)         256797824   ['token_input[0][0]']            \n",
            " rasLayer)                                                                                        \n",
            "                                                                                                  \n",
            " character_embed (Embedding)    (None, 290, 25)      1750        ['character_vectorizer[1][0]']   \n",
            "                                                                                                  \n",
            " dense_4 (Dense)                (None, 128)          65664       ['universal_sentence_encoder[1][0\n",
            "                                                                 ]']                              \n",
            "                                                                                                  \n",
            " bidirectional (Bidirectional)  (None, 50)           10200       ['character_embed[1][0]']        \n",
            "                                                                                                  \n",
            " token_character_hybrid (Concat  (None, 178)         0           ['dense_4[0][0]',                \n",
            " enate)                                                           'bidirectional[0][0]']          \n",
            "                                                                                                  \n",
            " dropout (Dropout)              (None, 178)          0           ['token_character_hybrid[0][0]'] \n",
            "                                                                                                  \n",
            " dense_5 (Dense)                (None, 200)          35800       ['dropout[0][0]']                \n",
            "                                                                                                  \n",
            " dropout_1 (Dropout)            (None, 200)          0           ['dense_5[0][0]']                \n",
            "                                                                                                  \n",
            " dense_6 (Dense)                (None, 5)            1005        ['dropout_1[0][0]']              \n",
            "                                                                                                  \n",
            "==================================================================================================\n",
            "Total params: 256,912,243\n",
            "Trainable params: 114,419\n",
            "Non-trainable params: 256,797,824\n",
            "__________________________________________________________________________________________________\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# model_5 evaluation\n",
        "model_5.evaluate(val_character_token_dataset)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fyJJjn4YsswX",
        "outputId": "4f92a509-d899-4dda-8d1a-f62f62eb5f26"
      },
      "execution_count": 64,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "945/945 [==============================] - 40s 43ms/step - loss: 0.6966 - accuracy: 0.7342\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[0.6966075301170349, 0.7342115640640259]"
            ]
          },
          "metadata": {},
          "execution_count": 64
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# model_5 predictions\n",
        "model_5_pred_probs = model_5.predict(val_character_token_dataset)\n",
        "model_5_preds = tf.argmax(model_5_pred_probs, axis=1)\n",
        "model_5_results = calculate_results(y_true=val_labels_encoded,\n",
        "                                    y_pred=model_5_preds)\n",
        "model_5_results"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OayKaA3Rs68f",
        "outputId": "9176bc15-6a10-41b3-bed2-142422b76e8d"
      },
      "execution_count": 65,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'accuracy': 73.42115715609691,\n",
              " 'f1': 0.7316927135268578,\n",
              " 'precision': 0.7361133381605339,\n",
              " 'recall': 0.7342115715609692}"
            ]
          },
          "metadata": {},
          "execution_count": 65
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "model_5 performs better than model_4 "
      ],
      "metadata": {
        "id": "ia_CZHCYuEtp"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# model_6 with transfer learning and positional embeddings"
      ],
      "metadata": {
        "id": "yvwernTtufoD"
      }
    },
    {
      "cell_type": "code",
      "source": [
        " # EDA before modelling\n",
        " train_df[\"line_number\"].value_counts()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4WR--gjqupKK",
        "outputId": "bc150f2c-e89a-41c4-b9c3-eeafed3c68f1"
      },
      "execution_count": 66,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0     15000\n",
              "1     15000\n",
              "2     15000\n",
              "3     15000\n",
              "4     14992\n",
              "5     14949\n",
              "6     14758\n",
              "7     14279\n",
              "8     13346\n",
              "9     11981\n",
              "10    10041\n",
              "11     7892\n",
              "12     5853\n",
              "13     4152\n",
              "14     2835\n",
              "15     1861\n",
              "16     1188\n",
              "17      751\n",
              "18      462\n",
              "19      286\n",
              "20      162\n",
              "21      101\n",
              "22       66\n",
              "23       33\n",
              "24       22\n",
              "25       14\n",
              "26        7\n",
              "27        4\n",
              "28        3\n",
              "29        1\n",
              "30        1\n",
              "Name: line_number, dtype: int64"
            ]
          },
          "metadata": {},
          "execution_count": 66
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_df.line_number.plot.hist();"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 265
        },
        "id": "GnNywG80u5_d",
        "outputId": "eb63de0c-9196-4df4-cb48-3a4e64a59a29"
      },
      "execution_count": 67,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAD4CAYAAAAtrdtxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAASwElEQVR4nO3df9CdZX3n8ffHAAVtFShZliHQYM3UTV2rGIGO7a6LIwZphXbVwtQ16zCmM+KMTveH0eks1pYZ3NkWS0fd0pJpcNtGqlayBYeNiv3xBz+CoAiU8hTDkoiQGhCpFjb43T/O9cAxPnlyciXnOc/J837NnHnu+3tf97mva+7kfOb+ce6TqkKSpB7Pm3QHJEnTyxCRJHUzRCRJ3QwRSVI3Q0SS1O2ISXdgoZ1wwgm1cuXKSXdDkqbG7bff/o9VtXyuZUsuRFauXMm2bdsm3Q1JmhpJHtzXMk9nSZK6GSKSpG6GiCSpmyEiSepmiEiSuhkikqRuhogkqZshIknqZohIkrotuW+sH4yVG66fdBcW3PbLz5t0FyQtYh6JSJK6GSKSpG6GiCSpmyEiSepmiEiSuhkikqRuhogkqZshIknqZohIkroZIpKkboaIJKmbz87SvCb1vDCf2SVNB49EJEndDBFJUjdDRJLUzRCRJHUzRCRJ3QwRSVI3Q0SS1G3sIZJkWZI7kvxlmz8tyS1JZpJ8MslRrf4jbX6mLV859B7vb/X7krxhqL621WaSbBj3WCRJP2ghjkTeA9w7NP9h4IqqegnwGHBxq18MPNbqV7R2JFkNXAj8NLAW+FgLpmXAR4FzgdXARa2tJGmBjDVEkqwAzgP+qM0HOBv4VGuyCbigTZ/f5mnLX9fanw9srqqnqurrwAxwRnvNVNUDVfU0sLm1lSQtkHEfiXwE+K/A99v8jwOPV9WeNr8DOLlNnww8BNCWf7u1f7a+1zr7qv+QJOuTbEuybdeuXQc7JklSM7YQSfILwKNVdfu4tjGqqrqqqtZU1Zrly5dPujuSdNgY5wMYXwO8KckbgaOBFwK/Bxyb5Ih2tLEC2Nna7wROAXYkOQJ4EfCtofqs4XX2VZckLYCxHYlU1furakVVrWRwYfyLVfWrwE3Am1uzdcB1bXpLm6ct/2JVVatf2O7eOg1YBdwK3Aasand7HdW2sWVc45Ek/bBJPAr+fcDmJL8N3AFc3epXA59IMgPsZhAKVNXdSa4F7gH2AJdU1TMASd4N3AgsAzZW1d0LOhJJWuIWJESq6kvAl9r0AwzurNq7zT8Db9nH+pcBl81RvwG44RB2VZJ0APzGuiSpmyEiSepmiEiSuhkikqRuhogkqZshIknqZohIkroZIpKkboaIJKmbISJJ6maISJK6GSKSpG6GiCSpmyEiSepmiEiSuhkikqRuhogkqZshIknqZohIkroZIpKkboaIJKmbISJJ6maISJK6GSKSpG6GiCSpmyEiSepmiEiSuhkikqRuhogkqZshIknqZohIkroZIpKkboaIJKmbISJJ6maISJK6GSKSpG6GiCSp29hCJMnRSW5N8pUkdyf5zVY/LcktSWaSfDLJUa3+I21+pi1fOfRe72/1+5K8Yai+ttVmkmwY11gkSXMb55HIU8DZVfUzwCuAtUnOAj4MXFFVLwEeAy5u7S8GHmv1K1o7kqwGLgR+GlgLfCzJsiTLgI8C5wKrgYtaW0nSAhlbiNTAk232yPYq4GzgU62+CbigTZ/f5mnLX5ckrb65qp6qqq8DM8AZ7TVTVQ9U1dPA5tZWkrRAjhjnm7ejhduBlzA4avgH4PGq2tOa7ABObtMnAw8BVNWeJN8GfrzVbx562+F1HtqrfuY++rEeWA9w6qmnHtygtCBWbrh+Ytvefvl5E9u2NG3GemG9qp6pqlcAKxgcObx0nNubpx9XVdWaqlqzfPnySXRBkg5LC3J3VlU9DtwE/CxwbJLZI6AVwM42vRM4BaAtfxHwreH6Xuvsqy5JWiDjvDtreZJj2/QxwOuBexmEyZtbs3XAdW16S5unLf9iVVWrX9ju3joNWAXcCtwGrGp3ex3F4OL7lnGNR5L0w8Z5TeQkYFO7LvI84Nqq+ssk9wCbk/w2cAdwdWt/NfCJJDPAbgahQFXdneRa4B5gD3BJVT0DkOTdwI3AMmBjVd09xvFIkvYythCpqq8Cr5yj/gCD6yN71/8ZeMs+3usy4LI56jcANxx0ZyVJXUY6nZXkX4+7I5Kk6TPqNZGPtW+fvyvJi8baI0nS1BgpRKrq54FfZXA31O1J/jTJ68faM0nSojfy3VlVdT/wG8D7gH8LXJnk75L88rg6J0la3Ea9JvLyJFcwuEX3bOAXq+pftekrxtg/SdIiNurdWb8P/BHwgar63myxqr6R5DfG0jNJ0qI3aoicB3xv6PsZzwOOrqrvVtUnxtY7SdKiNuo1kc8DxwzNP7/VJElL2KghcvTQY91p088fT5ckSdNi1BD5pySnz84keRXwvXnaS5KWgFGvibwX+PMk3wAC/EvgV8bWK0nSVBgpRKrqtiQvBX6qle6rqv83vm5JkqbBgTyA8dXAyrbO6UmoqmvG0itJ0lQYKUSSfAL4SeBO4JlWLsAQkaQlbNQjkTXA6vYjUZIkAaPfnfU1BhfTJUl61qhHIicA9yS5FXhqtlhVbxpLryRJU2HUEPngODshSZpOo97i+1dJfgJYVVWfT/J8Br9rLklawkZ9FPw7gU8Bf9BKJwOfHVenJEnTYdQL65cArwGegGd/oOpfjKtTkqTpMGqIPFVVT8/OJDmCwfdEJElL2Kgh8ldJPgAc035b/c+B/z2+bkmSpsGoIbIB2AXcBfwacAOD31uXJC1ho96d9X3gD9tLkiRg9GdnfZ05roFU1YsPeY8kSVPjQJ6dNeto4C3A8Ye+O5KkaTLSNZGq+tbQa2dVfQQ4b8x9kyQtcqOezjp9aPZ5DI5MDuS3SCRJh6FRg+B3hqb3ANuBtx7y3kiSpsqod2f9u3F3RJI0fUY9nfXr8y2vqt89NN2RJE2TA7k769XAljb/i8CtwP3j6JQkaTqMGiIrgNOr6jsAST4IXF9VbxtXxyRJi9+ojz05EXh6aP7pVpMkLWGjHolcA9ya5C/a/AXApvF0SZI0LUa9O+uyJJ8Dfr6V3lFVd4yvW5KkaTDq6SyA5wNPVNXvATuSnDZf4ySnJLkpyT1J7k7ynlY/PsnWJPe3v8e1epJcmWQmyVeHv+CYZF1rf3+SdUP1VyW5q61zZZIc0OglSQdl1J/HvRR4H/D+VjoS+F/7WW0P8J+qajVwFnBJktUMHiv/hapaBXyhzQOcC6xqr/XAx9u2jwcuBc4EzgAunQ2e1uadQ+utHWU8kqRDY9QjkV8C3gT8E0BVfQP4sflWqKqHq+rLbfo7wL0Mfpv9fJ67nrKJwfUVWv2aGrgZODbJScAbgK1VtbuqHgO2AmvbshdW1c1VVQyu28y+lyRpAYwaIk+3D+oCSPKCA9lIkpXAK4FbgBOr6uG26Js8d
5fXycBDQ6vtaLX56jvmqM+1/fVJtiXZtmvXrgPpuiRpHqOGyLVJ/oDB0cE7gc8z4g9UJflR4NPAe6vqieFlw8E0TlV1VVWtqao1y5cvH/fmJGnJ2O/dWe1i9SeBlwJPAD8F/Leq2jrCukcyCJA/qarPtPIjSU6qqofbKalHW30ncMrQ6itabSfw2r3qX2r1FXO0lyQtkP0eibSjhRuqamtV/Zeq+s8jBkiAq4F793q21hZg9g6rdcB1Q/W3t7u0zgK+3U573Qick+S4dkH9HODGtuyJJGe1bb196L0kSQtg1C8bfjnJq6vqtgN479cA/wG4K8mdrfYB4HIGp8cuBh7kuUfK3wC8EZgBvgu8A6Cqdif5LWB22x+qqt1t+l3AHwPHAJ9rL0nSAhk1RM4E3pZkO4M7tMLgIOXl+1qhqv62tZvL6+ZoX8Al+3ivjcDGOerbgJftr/OSpPGYN0SSnFpV/5fBbbaSJP2A/R2JfJbB03sfTPLpqvr3C9EpSdJ02N+F9eHTUS8eZ0ckSdNnfyFS+5iWJGm/p7N+JskTDI5IjmnT8NyF9ReOtXeSpEVt3hCpqmUL1RFJ0vQ5kEfBS5L0AwwRSVI3Q0SS1M0QkSR1M0QkSd0MEUlSN0NEktTNEJEkdTNEJEndDBFJUjdDRJLUzRCRJHUzRCRJ3QwRSVI3Q0SS1M0QkSR1M0QkSd0MEUlSN0NEktTNEJEkdTNEJEndjph0B6TFZuWG6yey3e2XnzeR7UoHwyMRSVI3Q0SS1M0QkSR1M0QkSd0MEUlSN0NEktTNEJEkdTNEJEndDBFJUrexhUiSjUkeTfK1odrxSbYmub/9Pa7Vk+TKJDNJvprk9KF11rX29ydZN1R/VZK72jpXJsm4xiJJmts4j0T+GFi7V20D8IWqWgV8oc0DnAusaq/1wMdhEDrApcCZwBnApbPB09q8c2i9vbclSRqzsYVIVf01sHuv8vnApja9CbhgqH5NDdwMHJvkJOANwNaq2l1VjwFbgbVt2Qur6uaqKuCaofeSJC2Qhb4mcmJVPdymvwmc2KZPBh4aarej1ear75ijPqck65NsS7Jt165dBzcCSdKzJnZhvR1B1AJt66qqWlNVa5YvX74Qm5SkJWGhQ+SRdiqK9vfRVt8JnDLUbkWrzVdfMUddkrSAFjpEtgCzd1itA64bqr+93aV1FvDtdtrrRuCcJMe1C+rnADe2ZU8kOavdlfX2ofeSJC2Qsf0oVZI/A14LnJBkB4O7rC4Hrk1yMfAg8NbW/AbgjcAM8F3gHQBVtTvJbwG3tXYfqqrZi/XvYnAH2DHA59pLkrSAxhYiVXXRPha9bo62BVyyj/fZCGyco74NeNnB9FGSdHD8xrokqZshIknqZohIkroZIpKkboaIJKmbISJJ6maISJK6GSKSpG6GiCSpmyEiSepmiEiSuhkikqRuhogkqZshIknqZohIkroZIpKkboaIJKmbISJJ6maISJK6GSKSpG6GiCSpmyEiSep2xKQ7IGlg5YbrJ7Ld7ZefN5Ht6vDgkYgkqZshIknqZohIkroZIpKkboaIJKmbISJJ6maISJK6GSKSpG6GiCSpmyEiSepmiEiSuhkikqRuhogkqZtP8ZWWuEk9PRh8gvDhYOqPRJKsTXJfkpkkGybdH0laSqY6RJIsAz4KnAusBi5KsnqyvZKkpWPaT2edAcxU1QMASTYD5wP3TLRXkkbiD3FNv2kPkZOBh4bmdwBn7t0oyXpgfZt9Msl9nds7AfjHznUXm8NlLIfLOMCxLJh8eOSmi3ocB+hgxvIT+1ow7SEykqq6CrjqYN8nybaqWnMIujRxh8tYDpdxgGNZjA6XccD4xjLV10SAncApQ/MrWk2StACmPURuA1YlOS3JUcCFwJYJ90mSloypPp1VVXuSvBu4EVgGbKyqu8e4yYM+JbaIHC5jOVzGAY5lMTpcxgFjGkuqahzvK0laAqb9dJYkaYIMEUlSN0NkBIfTo1WSbE9yV5I7k2ybdH8ORJKNSR5N8rWh2vFJtia5v/09bpJ9HNU+xvLBJDvbvrkzyRsn2cdRJDklyU1J7klyd5L3tPrU7Zd5xjKN++XoJLcm+Uoby2+2+mlJbmmfZZ9sNyQd3La8JjK/9miVvwdez+DLjLcBF1XVVH4rPsl2YE1VTd0XqJL8G+BJ4Jqqelmr/Xdgd1Vd3gL+uKp63yT7OYp9jOWDwJNV9T8m2bcDkeQk4KSq+nKSHwNuBy4A/iNTtl/mGctbmb79EuAFVfVkkiOBvwXeA/w68Jmq2pzkfwJfqaqPH8y2PBLZv2cfrVJVTwOzj1bRAquqvwZ271U+H9jUpjcx+E+/6O1jLFOnqh6uqi+36e8A9zJ4ksTU7Zd5xjJ1auDJNntkexVwNvCpVj8k+8UQ2b+5Hq0ylf+wmgL+T5Lb2+Ngpt2JVfVwm/4mcOIkO3MIvDvJV9vprkV/CmhYkpXAK4FbmPL9stdYYAr3S5JlSe4EHgW2Av8APF5Ve1qTQ/JZZogsPT9XVaczePLxJe20ymGhBudmp/n87MeBnwReATwM/M5kuzO6JD8KfBp4b1U9Mbxs2vbLHGOZyv1SVc9U1SsYPMnjDOCl49iOIbJ/h9WjVapqZ/v7KPAXDP5xTbNH2rns2XPaj064P92q6pH2H//7wB8yJfumnXP/NPAnVfWZVp7K/TLXWKZ1v8yqqseBm4CfBY5NMvsl80PyWWaI7N9h82iVJC9oFwxJ8gLgHOBr86+16G0B1rXpdcB1E+zLQZn90G1+iSnYN+0C7tXAvVX1u0OLpm6/7GssU7pflic5tk0fw+DGoHsZhMmbW7NDsl+8O2sE7Za+j/Dco1Uum3CXuiR5MYOjDxg88uZPp2ksSf4MeC2DR1o/AlwKfBa4FjgVeBB4a1Ut+gvW+xjLaxmcMilgO/BrQ9cVFqUkPwf8DXAX8P1W/gCDawlTtV/mGctFTN9+eTmDC+fLGBwsXFtVH2qfAZuB44E7gLdV1VMHtS1DRJLUy9NZkqRuhogkqZshIknqZohIkroZIpKkboaIJKmbISJJ6vb/AVwSphAAsBgmAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# One hot encoding\n",
        "train_line_numbers_one_hot = tf.one_hot(train_df[\"line_number\"].to_numpy(), depth=15)\n",
        "val_line_numbers_one_hot = tf.one_hot(val_df[\"line_number\"].to_numpy(), depth=15)\n",
        "test_line_numbers_one_hot = tf.one_hot(test_df[\"line_number\"].to_numpy(), depth=15)\n",
        "train_line_numbers_one_hot, train_line_numbers_one_hot[:25]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "cgvmqKdovAZ5",
        "outputId": "d6233fd0-6d76-4cab-c83d-8e98fc2df134"
      },
      "execution_count": 68,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(<tf.Tensor: shape=(180040, 15), dtype=float32, numpy=\n",
              " array([[1., 0., 0., ..., 0., 0., 0.],\n",
              "        [0., 1., 0., ..., 0., 0., 0.],\n",
              "        [0., 0., 1., ..., 0., 0., 0.],\n",
              "        ...,\n",
              "        [0., 0., 0., ..., 0., 0., 0.],\n",
              "        [0., 0., 0., ..., 0., 0., 0.],\n",
              "        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>,\n",
              " <tf.Tensor: shape=(25, 15), dtype=float32, numpy=\n",
              " array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],\n",
              "        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n",
              "        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
              "        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],\n",
              "       dtype=float32)>)"
            ]
          },
          "metadata": {},
          "execution_count": 68
        }
      ]
    },
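    {
      "cell_type": "markdown",
      "source": [
        "Note that `tf.one_hot` with `depth=15` silently maps any `line_number` of 15 or more to an all-zero vector, which effectively caps the positional feature. A tiny illustration (a sketch, not part of the original run) follows."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Sketch: indices outside [0, depth) become all-zero rows\n",
        "print(tf.one_hot([3, 20], depth=15))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },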
    {
      "cell_type": "code",
      "source": [
        "# Unique numbers of lines\n",
        "train_df[\"total_lines\"].value_counts()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JAZ0t8O5v2ra",
        "outputId": "d7a91e8c-0096-45ea-9f75-22e75221f7b0"
      },
      "execution_count": 69,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "11    24468\n",
              "10    23639\n",
              "12    22113\n",
              "9     19400\n",
              "13    18438\n",
              "14    14610\n",
              "8     12285\n",
              "15    10768\n",
              "7      7464\n",
              "16     7429\n",
              "17     5202\n",
              "6      3353\n",
              "18     3344\n",
              "19     2480\n",
              "20     1281\n",
              "5      1146\n",
              "21      770\n",
              "22      759\n",
              "23      264\n",
              "4       215\n",
              "24      200\n",
              "25      182\n",
              "26       81\n",
              "28       58\n",
              "3        32\n",
              "30       31\n",
              "27       28\n",
              "Name: total_lines, dtype: int64"
            ]
          },
          "metadata": {},
          "execution_count": 69
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_df.total_lines.plot.hist();"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 267
        },
        "id": "1g9b1HIGwDM1",
        "outputId": "a56eb1b7-f113-40b7-8211-65a0eee26185"
      },
      "execution_count": 70,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAD6CAYAAABgZXp6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAXpUlEQVR4nO3df7BfdX3n8efLRCpSkVDSLJNgg21Gl7r+gCvg1HatjCHg1tBdl4WtS5ZhiDNgV8f9QXQ6i8Uyk+5spdJatqlkTVwV8SfZEppGxHb7Bz+CIAjo5IqwJAJJDRDRFhZ97x/fz5Wv4ebyzbn53i/35vmY+c49530+55zPZ74TXpxzPt/vN1WFJEldvGjUHZAkzV6GiCSpM0NEktSZISJJ6swQkSR1ZohIkjobWogkeVWSO/tee5O8L8nRSbYm2d7+Lmjtk+TKJONJ7kpyYt+xVrX225Os6quflOTuts+VSTKs8UiSnisz8TmRJPOAncApwMXAnqpam2QNsKCqLklyJvC7wJmt3Uer6pQkRwPbgDGggNuBk6rqsSS3Av8BuAXYDFxZVTdM1Zdjjjmmli5dOpRxStJcdPvtt/99VS2cbNv8GerDacB3qurBJCuBt7T6BuBrwCXASmBj9VLt5iRHJTm2td1aVXsAkmwFViT5GnBkVd3c6huBs4ApQ2Tp0qVs27bt4I5OkuawJA/ub9tMPRM5B/hMW15UVQ+35UeARW15MfBQ3z47Wm2q+o5J6pKkGTL0EElyGPAO4HP7bmtXHUO/n5ZkdZJtSbbt3r172KeTpEPGTFyJnAF8vaoebeuPtttUtL+7Wn0ncFzffktabar6kknqz1FV66pqrKrGFi6c9LaeJKmDmQiRc3n2VhbAJmBihtUq4Lq++nltltapwBPtttcWYHmSBW0m13JgS9u2N8mpbVbWeX3HkiTNgKE+WE9yBPA24N195bXAtUkuAB4Ezm71zfRmZo0DPwLOB6iqPUk+DNzW2l028ZAduAj4BHA4vQfqUz5UlyQdXDMyxfeFZGxsrJydJUmDS3J7VY1Nts1PrEuSOjNEJEmdGSKSpM5m6hPrmqWWrrl+JOd9YO3bR3JeSQfGKxFJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSps6GGSJKjknw+ybeS3JfkTUmOTrI1yfb2d0FrmyRXJhlPcleSE/uOs6q1355kVV/9pCR3t32uTJJhjkeS9LOGfSXyUeCvqurVwOuA+4A1wI1VtQy4sa0DnAEsa6/VwFUASY4GLgVOAU4GLp0Intbmwr79Vgx5PJKkPkMLkSQvB34DuBqgqp6uqseBlcCG1mwDcFZbXglsrJ6bgaOSHAucDmytqj1V9RiwFVjRth1ZVTdXVQEb+44lSZoBw7wSOR7YDfzPJHck+XiSI4BFVfVwa/MIsKgtLwYe6tt/R6tNVd8xSV2SNEOGGSLzgROBq6rqDcAPefbWFQDtCqKG2AcAkqxOsi3Jtt27dw/7dJJ0yBhmiOwAdlTVLW398/RC5dF2K4r2d1fbvhM4rm//Ja02VX3JJPXnqKp1VTVWVWMLFy6c1qAkSc8aWohU1SPAQ0le1UqnAfcCm4CJGVargOva8ibgvDZL61TgiXbbawuwPMmC9kB9ObClbdub5NQ2K+u8vmNJkmbA/CEf/3eBTyU5DLgfOJ9ecF2b5ALgQeDs1nYzcCYwDvyotaWq9iT5MHBba3dZVe1pyxcBnwAOB25oL0nSDBlqiFTVncDYJJtOm6RtARfv5zjrgfWT1LcBr5lmNyVJHfmJdUlSZ4aIJKkzQ0SS1JkhIknqzBCRJHVmiEiSOjNEJEmdGSKSpM4MEUlSZ4aIJKkzQ0SS1JkhIknqzBCRJHVmiEiSOjNEJEmdGSKSpM4MEUlSZ4aIJKkzQ0SS1JkhIknqzBCRJHVmiEiSOhtqiCR5IMndSe5Msq3Vjk6yNcn29ndBqyfJlUnGk9yV5MS+46xq7bcnWdVXP6kdf7ztm2GOR5L0s2biSuQ3q+r1VTXW1tcAN1bVMuDGtg5wBrCsvVYDV0EvdIBLgVOAk4FLJ4Kntbmwb78Vwx+OJGnCKG5nrQQ2tOUNwFl99Y3VczNwVJJjgdOBrVW1p6oeA7YCK9q2I6vq5qoqYGPfsSRJM2DYIVLAXye5PcnqVltUVQ+35UeARW15MfBQ3747Wm2q+o5J6s+RZHWSbUm27d69ezrjkST1mT/k47+5qnYm+UVga5Jv9W+sqkpSQ+4DVbUOWAcwNjY29PNJ0qFiqFciVbWz/d0FfIneM41H260o2t9drflO4Li+3Ze02lT1JZPUJUkzZGghkuSIJC+bWAaWA98ENgETM6xWAde15U3AeW2W1qnAE+221xZgeZIF7YH6cmBL27Y3yaltVtZ5fceSJM2AYd7OWgR8qc26nQ98uqr+KsltwLVJLgAeBM5u7TcDZwLjwI+A8wGqak+SDwO3tXaXVdWetnwR8AngcOCG9pIkzZChhUhV3Q+8bpL694HTJqkXcPF+jrUeWD9JfRvwmml3VpLUiZ9YlyR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSpM0NEktTZQCGS5J8NuyOSpNln0CuRP0tya5KLkrx8qD2SJM0aA4VIVf068DvAccDtST6d5G1D7Zkk6QVv4GciVbUd+D3gEuCfA1cm+VaSfzmszkmSXtgGfSby2iRXAPcBbwV+q6r+aVu+Yoj9kyS9gM0fsN2fAB8HPlhV/zBRrKrvJfm9ofRMkvSCN+jtrLcDn54IkCQvSvJSgKr65FQ7JpmX5I4kf9nWj09yS5LxJJ9Nclir/1xbH2/bl/Yd4wOt/u0kp/fVV7TaeJI1BzJwSdL0DRoiXwEO71t/aasN4r30boNN+EPgiqr6FeAx4IJWvwB4rNWvaO1IcgJwDvCrwAp6M8XmJZkHfAw4AzgBOLe1lSTNkEFvZ72kqp6cWKmqJyeuRKaSZAm9q5jLgfcnCb3nKP+2NdkAfAi4CljZlgE+D/xpa78SuKaqngK+m2QcOLm1G6+q+9u5rmlt7x1wTHoBW7rm+pGd+4G1bx/ZuaXZZtArkR8mOXFiJclJwD9M0X7CHwP/BfhJW/8F4PGqeqat7wAWt+XFwEMAbfsTrf1P6/vss7+6JGmGDHol8j7gc0m+BwT4J8C/mWqHJP8C2FVVtyd5y7R6OU1JVgOrAV7xileMsiuSNKcMFCJVdVuSVwOvaqVvV9X/e57dfg14R5IzgZcARwIfBY5KMr9dbSwBdrb2O+l9mHFHkvnAy4Hv99Un9O+zv/q+/V8HrAMYGxur5+m3JGlAB/IFjG8EXgucSO8h9nlTNa6qD1TVkqpaSu/B+Fer6neAm4B3tmargOva8qa2Ttv+1
aqqVj+nzd46HlgG3ArcBixrs70Oa+fYdADjkSRN00BXIkk+CfwycCfw41YuYGOHc14CXJPkD4A7gKtb/Wrgk+3B+R56oUBV3ZPkWnoPzJ8BLq6qH7d+vQfYAswD1lfVPR36I0nqaNBnImPACe3K4IBV1deAr7Xl+3l2dlV/m38E/vV+9r+c3gyvfeubgc1d+iRJmr5Bb2d9k97DdEmSfmrQK5FjgHuT3Ao8NVGsqncMpVeSpFlh0BD50DA7IUmanQad4vs3SX4JWFZVX2mfVp833K5Jkl7oBv0q+AvpfRXJn7fSYuDLw+qUJGl2GPTB+sX0Pjy4F376A1W/OKxOSZJmh0FD5KmqenpipX2i3E9+S9IhbtAQ+ZskHwQOb7+t/jngfw+vW5Kk2WDQEFkD7AbuBt5N7wN+/qKhJB3iBp2d9RPgL9pLkiRg8O/O+i6TPAOpqlce9B5JkmaNA/nurAkvofcdV0cf/O5IkmaTgZ6JVNX3+147q+qP6f3srSTpEDbo7awT+1ZfRO/KZNCrGEnSHDVoEPxR3/IzwAPA2Qe9N5KkWWXQ2Vm/OeyOSJJmn0FvZ71/qu1V9ZGD0x1J0mxyILOz3sizv2H+W/R+53z7MDoljdLSNdeP5LwPrHWuimafQUNkCXBiVf0AIMmHgOur6l3D6pgk6YVv0K89WQQ83bf+dKtJkg5hg16JbARuTfKltn4WsGE4XZIkzRaDzs66PMkNwK+30vlVdcfwuiVJmg0GvZ0F8FJgb1V9FNiR5PipGid5SZJbk3wjyT1Jfr/Vj09yS5LxJJ9Nclir/1xbH2/bl/Yd6wOt/u0kp/fVV7TaeJI1BzAWSdJBMOjP414KXAJ8oJVeDPyv59ntKeCtVfU64PXAiiSnAn8IXFFVvwI8BlzQ2l8APNbqV7R2JDkBOAf4VWAF8GdJ5iWZB3wMOAM4ATi3tZUkzZBBr0R+G3gH8EOAqvoe8LKpdqieJ9vqi9urgLfS+7126D1XOastr+TZ5yyfB05Lkla/pqqeqqrvAuPAye01XlX3t19dvKa1lSTNkEFD5OmqKtrXwSc5YpCd2hXDncAuYCvwHeDxqnqmNdkBLG7Li4GHANr2J4Bf6K/vs8/+6pKkGTJoiFyb5M+Bo5JcCHyFAX6gqqp+XFWvp/c5k5OBV3fu6TQkWZ1kW5Jtu3fvHkUXJGlOet7ZWe2W0mfpBcBe4FXAf62qrYOepKoeT3IT8CZ6QTS/XW0sAXa2ZjuB4+g9tJ8PvBz4fl99Qv8++6vve/51wDqAsbGx5/y4liSpm+e9Emm3sTZX1daq+s9V9Z8GCZAkC5Mc1ZYPB94G3AfcBLyzNVsFXNeWN7V12vavtnNvAs5ps7eOB5bR+8qV24BlbbbXYfQevk98LYskaQYM+mHDryd5Y1XddgDHPhbY0GZRvQi4tqr+Msm9wDVJ/gC4A7i6tb8a+GSScWAPvVCgqu5Jci1wL72vob+4qn4MkOQ9wBZgHrC+qu45gP5JkqZp0BA5BXhXkgfozdAKvYuU1+5vh6q6C3jDJPX76T0f2bf+j/R+dneyY10OXD5JfTOwebAhSJIOtilDJMkrqur/AqdP1U6SdGh6viuRL9P79t4Hk3yhqv7VTHRKkjQ7PN+D9fQtv3KYHZEkzT7PFyK1n2VJkp73dtbrkuyld0VyeFuGZx+sHznU3kmSXtCmDJGqmjdTHZEkzT4H8lXwkiT9DENEktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSpM0NEktTZoD9KpRFauub6UXdBkibllYgkqTNDRJLUmSEiSerMEJEkdWaISJI6G1qIJDkuyU1J7k1yT5L3tvrRSbYm2d7+Lmj1JLkyyXiSu5Kc2HesVa399iSr+uonJbm77XNlkjy3J5KkYRnmlcgzwH+sqhOAU4GLk5wArAFurKplwI1tHeAMYFl7rQaugl7oAJcCpwAnA5dOBE9rc2HffiuGOB5J0j6GFiJV9XBVfb0t/wC4D1gMrAQ2tGYbgLPa8kpgY/XcDByV5FjgdGBrVe2pqseArcCKtu3Iqrq5qgrY2HcsSdIMmJFnIkmWAm8AbgEWVdXDbdMjwKK2vBh4qG+3Ha02VX3HJPXJzr86ybYk23bv3j2tsUiSnjX0EEny88AXgPdV1d7+be0Koobdh6paV1VjVTW2cOHCYZ9Okg4ZQw2RJC+mFyCfqqovtvKj7VYU7e+uVt8JHNe3+5JWm6q+ZJK6JGmGDHN2VoCrgfuq6iN9mzYBEzOsVgHX9dXPa7O0TgWeaLe9tgDLkyxoD9SXA1vatr1JTm3nOq/vWJKkGTDML2D8NeDfAXcnubPVPgisBa5NcgHwIHB227YZOBMYB34EnA9QVXuSfBi4rbW7rKr2tOWLgE8AhwM3tJckaYYMLUSq6u+A/X1u47RJ2hdw8X6OtR5YP0l9G/CaaXRTkjQNfmJdktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjozRCRJnQ0tRJKsT7IryTf7akcn2Zpke/u7oNWT5Mok40nuSnJi3z6rWvvtSVb11U9Kcnfb58okGdZYJEmTmz/EY38C+FNgY19tDXBjVa1NsqatXwKcASxrr1OAq4BTkhwNXAqMAQXcnmRTVT3W2lwI3AJsBlYANwxxPNJQLV1z/UjO+8Dat4/kvJobhnYlUlV/C+zZp7wS2NCWNwBn9dU3Vs/NwFFJjgVOB7ZW1Z4WHFuBFW3bkVV1c1UVvaA6C0nSjJrpZyKLqurhtvwIsKgtLwYe6mu3o9Wmqu+YpC5JmkEje7DeriBqJs6VZHWSbUm27d69eyZOKUmHhJkOkUfbrSja312tvhM4rq/dklabqr5kkvqkqmpdVY1V1djChQunPQhJUs9Mh8gmYGKG1Srgur76eW2W1qnAE+221xZgeZIFbSbXcmBL27Y3yaltVtZ5fceSJM2Qoc3OSvIZ4C3AMUl20JtltRa4NskFwIPA2a35ZuBMYBz4EXA+QFXtSfJh4LbW7rKqmnhYfxG9GWCH05uV5cwsSZphQwuRqjp3P5tOm6RtARfv5zjrgfWT1LcBr5lOHyVJ0+Mn1iVJnRkikqTODBFJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSps/mj7oCk0Vq65vqRnfuBtW8f2bl1cHglIknqbNZfiSRZAXwUmAd8vKrWDutco/w/NmkuGtW/Ka+ADp5ZfSWSZB7wMeAM4ATg3CQnjLZXknTomNUhApwMjFfV/VX1NHANsHLEfZKkQ8Zsv521GHiob30HcMqI+iJplnAywcEz20NkIElWA6vb6pNJvj3K/kziGODvR92JIZvrY3R8s9+MjDF/OOwz7Nd0xvdL+9sw20NkJ3Bc3/qSVvsZ
VbUOWDdTnTpQSbZV1dio+zFMc32Mjm/2m+tjHNb4ZvszkduAZUmOT3IYcA6wacR9kqRDxqy+EqmqZ5K8B9hCb4rv+qq6Z8TdkqRDxqwOEYCq2gxsHnU/pukFe6vtIJrrY3R8s99cH+NQxpeqGsZxJUmHgNn+TESSNEKGyIgleSDJ3UnuTLJt1P05GJKsT7IryTf7akcn2Zpke/u7YJR9nI79jO9DSXa29/HOJGeOso/TkeS4JDcluTfJPUne2+pz4j2cYnxz6T18SZJbk3yjjfH3W/34JLckGU/y2TYhaXrn8nbWaCV5ABirqjkzBz/JbwBPAhur6jWt9t+APVW1NskaYEFVXTLKfna1n/F9CHiyqv77KPt2MCQ5Fji2qr6e5GXA7cBZwL9nDryHU4zvbObOexjgiKp6MsmLgb8D3gu8H/hiVV2T5H8A36iqq6ZzLq9EdNBV1d8Ce/YprwQ2tOUN9P7Rzkr7Gd+cUVUPV9XX2/IPgPvofTvEnHgPpxjfnFE9T7bVF7dXAW8FPt/qB+U9NERGr4C/TnJ7+2T9XLWoqh5uy48Ai0bZmSF5T5K72u2uWXmrZ19JlgJvAG5hDr6H+4wP5tB7mGRekjuBXcBW4DvA41X1TGuyg4MQnobI6L25qk6k903EF7dbJXNa9e6hzrX7qFcBvwy8HngY+KPRdmf6kvw88AXgfVW1t3/bXHgPJxnfnHoPq+rHVfV6et/kcTLw6mGcxxAZsara2f7uAr5E782eix5t96In7knvGnF/DqqqerT9o/0J8BfM8vex3Uf/AvCpqvpiK8+Z93Cy8c2193BCVT0O3AS8CTgqycTnAyf9mqgDZYiMUJIj2oM9khwBLAe+OfVes9YmYFVbXgVcN8K+HHQT/3FtfptZ/D62h7JXA/dV1Uf6Ns2J93B/45tj7+HCJEe15cOBt9F79nMT8M7W7KC8h87OGqEkr6R39QG9bw/4dFVdPsIuHRRJPgO8hd63hj4KXAp8GbgWeAXwIHB2Vc3Kh9P7Gd9b6N0GKeAB4N19zw9mlSRvBv4PcDfwk1b+IL3nBrP+PZxifOcyd97D19J7cD6P3sXCtVV1WftvzjXA0cAdwLuq6qlpncsQkSR15e0sSVJnhogkqTNDRJLUmSEiSerMEJEkdWaISJI6M0QkSZ0ZIpKkzv4/2LyLCkd/AwYAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# total_lines value of 20 in %\n",
        "np.percentile(train_df.total_lines, 98)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "T_ncJOpYwJbN",
        "outputId": "715bfcfe-2785-408a-d8b7-df129d401f7a"
      },
      "execution_count": 71,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "20.0"
            ]
          },
          "metadata": {},
          "execution_count": 71
        }
      ]
    },
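    {
      "cell_type": "markdown",
      "source": [
        "As a follow-up check (a sketch, not part of the original run), the fraction of sentences whose `total_lines` fits inside the chosen one-hot depth of 20; values of 20 or more would map to all-zero vectors."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Sketch: share of rows covered by a one-hot depth of 20\n",
        "print((train_df[\"total_lines\"] < 20).mean())"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },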
    {
      "cell_type": "code",
      "source": [
        "# One hot encoding the total_lines\n",
        "train_total_lines_one_hot = tf.one_hot(train_df[\"total_lines\"].to_numpy(), depth=20)\n",
        "val_total_lines_one_hot = tf.one_hot(val_df[\"total_lines\"].to_numpy(), depth=20)\n",
        "test_total_lines_one_hot = tf.one_hot(test_df[\"total_lines\"].to_numpy(), depth=20)\n",
        "train_total_lines_one_hot.shape, train_total_lines_one_hot[:15]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0V9CFWkdwhTF",
        "outputId": "fa8d9de3-d112-4cf7-94d6-bf1bef146cda"
      },
      "execution_count": 72,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(TensorShape([180040, 20]), <tf.Tensor: shape=(15, 20), dtype=float32, numpy=\n",
              " array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
              "         0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
              "         0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
              "         0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
              "         0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
              "         0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
              "         0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
              "         0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
              "         0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
              "         0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
              "         0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
              "         0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
              "         0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,\n",
              "         0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,\n",
              "         0., 0., 0., 0.],\n",
              "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,\n",
              "         0., 0., 0., 0.]], dtype=float32)>)"
            ]
          },
          "metadata": {},
          "execution_count": 72
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Modelling model_6\n",
        "\n",
        "# token\n",
        "token_inputs = layers.Input(shape=[], dtype=\"string\", name=\"token_inputs\")\n",
        "token_embeddings = tf_hub_embedding_layer(token_inputs)\n",
        "token_outputs = layers.Dense(128, activation=\"relu\")(token_embeddings)\n",
        "token_model = tf.keras.Model(inputs=token_inputs,\n",
        "                             outputs=token_outputs)\n",
        "# character\n",
        "character_inputs = layers.Input(shape=(1,), dtype=\"string\", name=\"character_inputs\")\n",
        "character_vectors = character_vectorizer(character_inputs)\n",
        "character_embeddings = character_embed(character_vectors)\n",
        "character_biLSTM = layers.Bidirectional(layers.LSTM(32))(character_embeddings)\n",
        "character_model = tf.keras.Model(inputs=character_inputs,\n",
        "                                 outputs=character_biLSTM)\n",
        "# line numbers\n",
        "line_number_inputs = layers.Input(shape=(15,), dtype=tf.int32, name=\"line_number_input\")\n",
        "x = layers.Dense(32, activation=\"relu\")(line_number_inputs)\n",
        "line_number_model = tf.keras.Model(inputs=line_number_inputs,\n",
        "                                   outputs=x)\n",
        "# total lines\n",
        "total_lines_inputs = layers.Input(shape=(20,), dtype=tf.int32, name=\"total_lines_input\")\n",
        "y = layers.Dense(32, activation=\"relu\")(total_lines_inputs)\n",
        "total_line_model = tf.keras.Model(inputs=total_lines_inputs,\n",
        "                                   outputs=y)\n",
        "# Concatenate token & character embeddings \n",
        "combined_embeddings = layers.Concatenate(name=\"token_character_hybrid_embedding\")([token_model.output,\n",
        "                                                                                   character_model.output])\n",
        "z = layers.Dense(256, activation=\"relu\")(combined_embeddings)\n",
        "z = layers.Dropout(0.5)(z)\n",
        "# Concatenate positional embeddings with token & character embeddings\n",
        "z = layers.Concatenate(name=\"token_character_positional_embedding\")([line_number_model.output,\n",
        "                                                                     total_line_model.output,\n",
        "                                                                     z])\n",
        "# Output layer\n",
        "output_layer = layers.Dense(5, activation=\"softmax\", name=\"output_layer\")(z)\n",
        "\n",
        "# model_6\n",
        "model_6 = tf.keras.Model(inputs=[line_number_model.input,\n",
        "                                 total_line_model.input,\n",
        "                                 token_model.input,\n",
        "                                 character_model.input],\n",
        "                         outputs=output_layer)\n",
        "\n",
        "model_6.compile(loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.2),\n",
        "                optimizer=tf.keras.optimizers.Adam(),\n",
        "                metrics=[\"accuracy\"])\n",
        "\n",
        "# training & validation dataset for x, y, z\n",
        "train_pos_character_token_data = tf.data.Dataset.from_tensor_slices((train_line_numbers_one_hot,\n",
        "                                                                     train_total_lines_one_hot,\n",
        "                                                                     train_sentences,\n",
        "                                                                     train_char))\n",
        "train_pos_character_token_labels = tf.data.Dataset.from_tensor_slices(train_labels_one_hot)\n",
        "train_pos_character_token_dataset = tf.data.Dataset.zip((train_pos_character_token_data, train_pos_character_token_labels))\n",
        "train_pos_character_token_dataset = train_pos_character_token_dataset.batch(32).prefetch(tf.data.AUTOTUNE)\n",
        "\n",
        "# validation_dataset\n",
        "val_pos_character_token_data = tf.data.Dataset.from_tensor_slices((val_line_numbers_one_hot,\n",
        "                                                                   val_total_lines_one_hot,\n",
        "                                                                   val_sentences,\n",
        "                                                                   val_char))\n",
        "val_pos_character_token_labels = tf.data.Dataset.from_tensor_slices(val_labels_one_hot)\n",
        "val_pos_character_token_dataset = tf.data.Dataset.zip((val_pos_character_token_data, val_pos_character_token_labels))\n",
        "val_pos_character_token_dataset = val_pos_character_token_dataset.batch(32).prefetch(tf.data.AUTOTUNE)\n",
        "\n",
        "# Fit \n",
        "model_6_history = model_6.fit(train_pos_character_token_dataset,\n",
        "                              steps_per_epoch=int(0.1 * len(train_pos_character_token_dataset)),\n",
        "                              epochs=3,\n",
        "                              validation_data=val_pos_character_token_dataset,\n",
        "                              validation_steps=int(0.1 * len(val_pos_character_token_dataset)))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fQ2SKwC5xarE",
        "outputId": "cd317fa5-1d8f-4260-c135-f5e8d13f9cc5"
      },
      "execution_count": 80,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1/3\n",
            "562/562 [==============================] - 130s 222ms/step - loss: 1.1002 - accuracy: 0.7239 - val_loss: 0.9839 - val_accuracy: 0.8049\n",
            "Epoch 2/3\n",
            "562/562 [==============================] - 129s 229ms/step - loss: 0.9679 - accuracy: 0.8153 - val_loss: 0.9501 - val_accuracy: 0.8291\n",
            "Epoch 3/3\n",
            "562/562 [==============================] - 130s 232ms/step - loss: 0.9510 - accuracy: 0.8236 - val_loss: 0.9377 - val_accuracy: 0.8334\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "model_6 has 83.3% accuracy "
      ],
      "metadata": {
        "id": "mpojCiyv9NRb"
      }
    },
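    {
      "cell_type": "markdown",
      "source": [
        "model_6 is compiled with `CategoricalCrossentropy(label_smoothing=0.2)`, which softens the hard one-hot targets to `y_true * (1 - 0.2) + 0.2 / num_classes` before computing the loss. A minimal illustration (a sketch assuming 5 classes, not part of the original run) follows."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Sketch: what label_smoothing=0.2 does to a one-hot target (assuming 5 classes)\n",
        "import tensorflow as tf\n",
        "\n",
        "y_true = tf.constant([[0., 0., 1., 0., 0.]])\n",
        "smoothing = 0.2\n",
        "num_classes = 5\n",
        "y_smooth = y_true * (1 - smoothing) + smoothing / num_classes\n",
        "print(y_smooth.numpy())  # approx. [[0.04 0.04 0.84 0.04 0.04]]"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },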
    {
      "cell_type": "code",
      "source": [
        "model_6.summary()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "y26rC6FS_Ib6",
        "outputId": "cd6cf1df-8c26-419d-8de5-bf6a50f7c2fe"
      },
      "execution_count": 81,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model: \"model_30\"\n",
            "__________________________________________________________________________________________________\n",
            " Layer (type)                   Output Shape         Param #     Connected to                     \n",
            "==================================================================================================\n",
            " character_inputs (InputLayer)  [(None, 1)]          0           []                               \n",
            "                                                                                                  \n",
            " token_inputs (InputLayer)      [(None,)]            0           []                               \n",
            "                                                                                                  \n",
            " character_vectorizer (TextVect  (None, 290)         0           ['character_inputs[0][0]']       \n",
            " orization)                                                                                       \n",
            "                                                                                                  \n",
            " universal_sentence_encoder (Ke  (None, 512)         256797824   ['token_inputs[0][0]']           \n",
            " rasLayer)                                                                                        \n",
            "                                                                                                  \n",
            " character_embed (Embedding)    (None, 290, 25)      1750        ['character_vectorizer[7][0]']   \n",
            "                                                                                                  \n",
            " dense_26 (Dense)               (None, 128)          65664       ['universal_sentence_encoder[7][0\n",
            "                                                                 ]']                              \n",
            "                                                                                                  \n",
            " bidirectional_6 (Bidirectional  (None, 64)          14848       ['character_embed[7][0]']        \n",
            " )                                                                                                \n",
            "                                                                                                  \n",
            " token_character_hybrid_embeddi  (None, 192)         0           ['dense_26[0][0]',               \n",
            " ng (Concatenate)                                                 'bidirectional_6[0][0]']        \n",
            "                                                                                                  \n",
            " line_number_input (InputLayer)  [(None, 15)]        0           []                               \n",
            "                                                                                                  \n",
            " total_lines_input (InputLayer)  [(None, 20)]        0           []                               \n",
            "                                                                                                  \n",
            " dense_29 (Dense)               (None, 256)          49408       ['token_character_hybrid_embeddin\n",
            "                                                                 g[0][0]']                        \n",
            "                                                                                                  \n",
            " dense_27 (Dense)               (None, 32)           512         ['line_number_input[0][0]']      \n",
            "                                                                                                  \n",
            " dense_28 (Dense)               (None, 32)           672         ['total_lines_input[0][0]']      \n",
            "                                                                                                  \n",
            " dropout_6 (Dropout)            (None, 256)          0           ['dense_29[0][0]']               \n",
            "                                                                                                  \n",
            " token_character_positional_emb  (None, 320)         0           ['dense_27[0][0]',               \n",
            " edding (Concatenate)                                             'dense_28[0][0]',               \n",
            "                                                                  'dropout_6[0][0]']              \n",
            "                                                                                                  \n",
            " output_layer (Dense)           (None, 5)            1605        ['token_character_positional_embe\n",
            "                                                                 dding[0][0]']                    \n",
            "                                                                                                  \n",
            "==================================================================================================\n",
            "Total params: 256,932,283\n",
            "Trainable params: 134,459\n",
            "Non-trainable params: 256,797,824\n",
            "__________________________________________________________________________________________________\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# model_6 predictions\n",
        "model_6_pred_probs = model_6.predict(val_pos_character_token_dataset, verbose=1)\n",
        "model_6_preds = tf.argmax(model_6_pred_probs, axis=1)\n",
        "model_6_results = calculate_results(y_true=val_labels_encoded,\n",
        "                                    y_pred=model_6_preds)\n",
        "model_6_results"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hAaYMO9S_Lk5",
        "outputId": "ed843549-f6b8-4e56-ecdf-c12ce50efab9"
      },
      "execution_count": 82,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "945/945 [==============================] - 51s 53ms/step\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'accuracy': 83.36422613531047,\n",
              " 'f1': 0.8324610342106541,\n",
              " 'precision': 0.8325524332528096,\n",
              " 'recall': 0.8336422613531047}"
            ]
          },
          "metadata": {},
          "execution_count": 82
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Comparing results from all models (model_1 to model_6)\n",
        "all_model_results = pd.DataFrame({\"model_1\": model_1_results,\n",
        "                                  \"model_2\": model_2_results,\n",
        "                                  \"model_3\": model_3_results,\n",
        "                                  \"model_4\": model_4_results,\n",
        "                                  \"model_5\": model_5_results,\n",
        "                                  \"model_6\": model_6_results})\n",
        "all_model_results = all_model_results.transpose()\n",
        "all_model_results"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 238
        },
        "id": "Syk9LKMrAg-Z",
        "outputId": "433a1690-0f1e-46d5-9ca4-bb1bdbed3322"
      },
      "execution_count": 83,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "\n",
              "  <div id=\"df-65f5c8ac-0806-4e0f-977f-e9b2ac9ade21\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>accuracy</th>\n",
              "      <th>precision</th>\n",
              "      <th>recall</th>\n",
              "      <th>f1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>model_1</th>\n",
              "      <td>72.183238</td>\n",
              "      <td>0.718647</td>\n",
              "      <td>0.721832</td>\n",
              "      <td>0.698925</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>model_2</th>\n",
              "      <td>78.677347</td>\n",
              "      <td>0.783403</td>\n",
              "      <td>0.786773</td>\n",
              "      <td>0.784393</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>model_3</th>\n",
              "      <td>71.435191</td>\n",
              "      <td>0.714815</td>\n",
              "      <td>0.714352</td>\n",
              "      <td>0.711263</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>model_4</th>\n",
              "      <td>66.142592</td>\n",
              "      <td>0.656632</td>\n",
              "      <td>0.661426</td>\n",
              "      <td>0.652652</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>model_5</th>\n",
              "      <td>73.421157</td>\n",
              "      <td>0.736113</td>\n",
              "      <td>0.734212</td>\n",
              "      <td>0.731693</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>model_6</th>\n",
              "      <td>83.364226</td>\n",
              "      <td>0.832552</td>\n",
              "      <td>0.833642</td>\n",
              "      <td>0.832461</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-65f5c8ac-0806-4e0f-977f-e9b2ac9ade21')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-65f5c8ac-0806-4e0f-977f-e9b2ac9ade21 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-65f5c8ac-0806-4e0f-977f-e9b2ac9ade21');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ],
            "text/plain": [
              "          accuracy  precision    recall        f1\n",
              "model_1  72.183238   0.718647  0.721832  0.698925\n",
              "model_2  78.677347   0.783403  0.786773  0.784393\n",
              "model_3  71.435191   0.714815  0.714352  0.711263\n",
              "model_4  66.142592   0.656632  0.661426  0.652652\n",
              "model_5  73.421157   0.736113  0.734212  0.731693\n",
              "model_6  83.364226   0.832552  0.833642  0.832461"
            ]
          },
          "metadata": {},
          "execution_count": 83
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "all_model_results[\"accuracy\"] = all_model_results[\"accuracy\"]/100"
      ],
      "metadata": {
        "id": "1IbWuhxrB4JA"
      },
      "execution_count": 84,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "all_model_results.plot(kind=\"bar\", figsize=(12, 8)).legend(bbox_to_anchor=(1.0, 1.0));"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 515
        },
        "id": "bgXpQwdcCD2x",
        "outputId": "e06021ee-aea4-412a-c64a-d4212637034f"
      },
      "execution_count": 85,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAxkAAAHyCAYAAACH/wOXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3de7RfZXkv+u+ThBARpAKLCAQMSG6LO0baesMWVKgVvHAqaL2dthk6yq71zj52cyhVq63WblrPGeCleloo21IvUWjZtSqMvdVKQAMkITZCykXAqAgqpSTkPX9kxa6mwaws32Rlks9njAx+c853zvn88owkfNc7L9VaCwAAQC/TproAAADgsUXIAAAAuhIyAACAroQMAACgKyEDAADoSsgAAAC6mjFVJz7ggAPa3Llzp+r0AADsJq6//vrvttZGprqO3cmUhYy5c+dm2bJlU3V6AAB2E1X1L1Ndw+7G5VIAAEBXQgYAANCVkAEAAHQ1ZfdkAADAVLn++usPnDFjxoeTHB0/eN9eG5PcvGHDht986lOf+p2tDRAyAADY7cyYMePDT3rSkxaNjIzcN23atDbV9QzJxo0ba926daP33HPPh5OcsbUxUhsAALujo0dGRh4QMLbftGnT2sjIyP3ZNAu09TE7sR4AANhVTBMwJm/s9+5Rs4SQAQAAdOWeDAAAdntzz7vyqT2Pt/Y9L7i+5/Ema/369dljjz12+nnNZAAAwBQ49dRTn3LUUUctOvLII4963/ved0CSXHHFFU8YHR1dtGDBgtFf/MVfnJ8k999//7Szzjpr7vz580fnz58/+rGPfeznkmSvvfY6YfOx/uIv/uKJL33pS+cmyUtf+tK5L3/5yw879thjF77+9a+f88UvfnGv448/fuGiRYtGTzjhhIXLly/fM0k2bNiQJUuWzJk3b95R8+fPH33Xu9514NKlS/c59dRTn7L5uJ/61Kee8NznPvcp2U5mMgAAYApceumla2fPnv3Ij370ozrhhBNGX/ayl/3g3HPPnfulL33ploULFz587733Tk+S884776AnPOEJj3zzm99cmSTr1q2bvq1j33333TNvuOGGW2bMmJHvf//706677rpb9thjj3z605/e521ve9ucq6+++lvvf//7R26//faZK1euXLHHHnvk3nvvnT4yMvLIG97whsO+/e1vzzj44IM3fPSjH93/ta997Xe397sJGQAAMAXe+973zr7yyit/LknuueeePS666KKRk0466YcLFy58OElmz579SJJce+21T7j88stv3bzfyMjII9s69kte8pL7ZszY9L/63//+96e/7GUvO3zt2rWzqqqtX7++kuQLX/jCE173utet23w51ebz/dqv/dr3PvShD+3327/929+74YYb9v7kJz952/Z+NyEDAAB2ss997nP7XHPNNfssW7bsln322WfjSSedtOCEE054cPXq1bMmeoyq+snnf/3Xf63x2/bee++Nmz+//e1vP+Tkk0/+4T/8wz98a/Xq1TN/+Zd/ecFPO+7rX//6773gBS84ctasWe2FL3zhfZO5p8M9GQAAsJP94Ac/mL7vvvs+ss8++2z8+te/Pmv58uWPf+ihh6Z97Wtf2+eWW26ZmSSbL5c6+eSTH/jABz5w4OZ9N18utf/++6+/4YYbZj3yyCP5zGc+88RHO9cDDzwwfc6cOQ8nycUXX3zA5vWnnHLKAxdffPEB69evz/jzzZ07d/3s2bPXv//97z9oyZIl232pVCJkAADATvfSl770/g0bNtQRRxxx1Fvf+tZDjjvuuB8feOCBGy666KK1L37xi49csGDB6Itf/OIjkuQP//AP7/7BD34wfd68eUctWLBg9KqrrtonSX7/93//rjPPPPPIE088ceHs2bPXP9q53v72t99zwQUXzFm0aNHohg0bfrL+jW9847o5c+Y8vHDhwqMWLFgw+pGPfGS/zdvOPvvs7x100EEPn3jiiQ9N5vtVa1PzDpLFixe3ZcuWTcm5AQDYfVTV9a21xePXLV++fO1xxx03qZ/S7w5e9apXHXbCCSc8+MY3vvFRf4+WL19+wHHHHTd3a9vckwEAAPzEUUcdtehxj3vcxosvvviOyR5DyAAAAH5ixYoVq37WYwgZAAC7mLnnXTmp/da+5wWT2u+Yjx8zqf1uevVNk9qPxz4hAwDgseKCfSe33+GHTWq3VQsXTWq/Rbf8zD8oZxfn6VIAAEBXQgYAANCVkAEAAI8R11577V6vec1rDn207WvXrt3jtNNOO2JH1+GeDAAAuGDfp/Y93v3X9zjMhg0bMmPGxP+X/dnPfvaDz372sx98tO1z585d//d///e39qjtpzGTAQAAU2D16tUzDz/88KPOOOOMw4844oijTjvttCN++MMfTjvkkEOOef3rX3/I6Ojooo9+9KNP/OQnP/mE448/fuHo6Oii008//Yj7779/WpJcc801e51wwgkLFyxYMHrMMccsuu+++6Z97nOf2+eXfumXjkySK6+8cu+FCxeOLly4cHTRokWj991337TVq1fPnDdv3lFJ8uCDD9ZZZ501d/78+aOLFi0a/exnP7tPklx00UX7P+95z3vKs571rHlPfvKTj37d6143Z3u/m5ABAABTZO3atbPOPffc79x6660r9tlnn41//Md/PJIk+++//4aVK1eueuELX/jDd7/73Qdde+2131y5cuWqE0888cE/+IM/mP3QQw/VK17xiqf86Z/+6e2rV69eec0116zee++9N44/9vvf//4nXXTRRf9yyy23rPzqV796y5bb3/ve9x5YVfnmN7+58rLLLrt1yZIlcx988MFKkpUrV+716U9/+tZVq1atWLp06RPXrFmzx/Z8LyEDAACmyJOe9KSHn/e85/04SV75yld+78tf/vLeSfKqV73qviT50pe+9Phvfetbs0466aSFCxcuHL388sv3v/3222feeOONsw488MD1J5988oNJst9++23cY4//mAN+4Rd+4UdvectbDn3nO9954He/+93pW27/8pe/vPcrX/nK7yXJCSec8NDBBx/88E033TQrSZ75zGc+sP/++z+y1157tSOPPPKhb33rW3tuz/dyTwYAAEyRqtrq8j777LMxSVpreeYzn/nAZz/72dvGj/va1772uG0d+93vfvc9L3rRi+7/zGc+s++znvWshVdeeeU/77XXXhu3tV+SzJw5s23+PH369LZ+/fr6aeO3ZCYDAACmyN133z3z85///OOT5NJLL93v6U9/+o/Gb3/Oc57z42XLlu19880375kkDzzwwLQbb7xxz2OPPfah73znO3tcc801eyXJfffdN239+vX/4dgrVqzY86STTvrXd73rXfcce+yxP7755ptnjd/+jGc840d/9Vd/tV+S3HjjjXvefffdM4899tiHenwvIQMAAKbI3LlzH/qzP/uzA4844oijfvCDH8x4y1vesm789oMPPnjDxRdfvPbss88+Yv78+aOLFy9eeNNNN82aNWtWu
/TSS7/1O7/zO4ctWLBg9DnPec78Bx988D/8v/0f/dEfHThv3ryj5s+fP7rHHnu0s8466/7x29/2trd9Z+PGjTV//vzRl73sZU+5+OKL1z7ucY9r6aBa63Kc7bZ48eK2bNmyKTk3AMCubO55V05qv7WzXj6p/Y45/LBJ7feJP9wwqf0W3bJqUvtNVlVd31pbPH7d8uXL1x533HHf3amFbGH16tUzf/VXf3XeP//zP6+Yyjoma/ny5Qccd9xxc7e2zUwGAADQ1YRCRlWdVlWrq2pNVZ23le2HVdUXq+rrVXVjVf1K/1IBAOCxY8GCBQ8PdRZjW7YZMqpqepIPJjk9yWiSc6pqdIthv5fkE621E5KcneT/6V0oAAAwDBOZyTgpyZrW2q2ttYeTXJ7kzC3GtCRPGPu8b5Jv9ysRAAAYkomEjEOS3DFu+c6xdeNdkOTXq+rOJFcl+S9bO1BVLamqZVW1bN26dVsbAgAADFyvG7/PSfKx1tqcJL+S5C+r6j8du7V2SWttcWtt8cjISKdTAwAAu5KJhIy7khw6bnnO2LrxfiPJJ5KktfaVJLOSHNCjQAAAYGIuuuii/V/1qlcdliRvetObDj7//PNnT0UdMyYw5rok86rq8GwKF2cn2fIhzLcnOSXJx6pqUTaFDNdDAQAwCMd8/Jin9jzeTa++6frtGb9x48a01jJ9+vSeZUyZbc5ktNY2JDk3ydVJVmXTU6RWVNWFVXXG2LA3J/mtqlqe5K+TvKZN1Vv+AABgAFavXj1z7ty5R7/4xS+eO3/+/KPe9ra3HXT00Ucvmj9//ugb3/jGgzeP+/M///P958+fP7pgwYLRF73oRYcnyWWXXbbvscceu3DRokWjT3/60+ffcccdE5k82GkmVExr7apsuqF7/Lrzx31emeQZfUsDAIDHtttvv33Pj3zkI7fdf//93/+bv/mbJ954442rWms59dRTj/y7v/u7vUdGRja8733vO+grX/nKLQcddNCGe++9d3qSPPe5z/3R2Weffcu0adPyJ3/yJwdceOGFT/rQhz5051R/n812qcQDAAC7k4MOOujhU0455cdLliyZc+211z5hdHR0NEkefPDBabfccsusG264YdoLX/jC+w466KANSTJ79uxHkuS2226b+aIXvWjOunXr9nj44YenHXroof82ld9jS0IGgzD3vCsntd/a97xgu/c55uPHTOpcN736pkntBwDsvvbaa6+NSdJay+/+7u/e/da3vvW747e/613vOnBr+5177rmHveENb7jnFa94xf2f+9zn9rnwwgsP3tq4qSJk8Nh2wb7bv8/hh03qVKsWLprUfotuWTWp/QCAx47TTz/9gQsuuODgJUuWfH/ffffdeNttt+0xc+bM9vznP/+Bs84668h3vOMd9zzpSU965N57750+e/bsR374wx9OP+yww9Ynycc+9rH9p7r+LQkZAAAwxV7ykpc8sGLFillPe9rTFiabZjguvfTS2xYvXvzQm9/85ruf9axnLZw2bVo7+uijH/zbv/3bte94xzu+fc455zxl33333fDMZz7zh7fffvueU/0dxqupegjU4sWL27Jly6bk3AzPpC+XmrXl05a37ZhJzmR84g83TGo/MxkAbGln/ruXPPb/7auq61tri8evW758+drjjjvuu4+2D9u2fPnyA4477ri5W9vW643fAAAASYQMAACgMyEDAADoSsgAAAC6EjIAAICuhAwAAKArIQMAAKbAO9/5zgOPOOKIo57//Oc/5fjjj184c+bME88///zZU11XD17GBwDAbm/VwkVP7Xm8Rbesun5bYz7ykY+MfP7zn//mrFmz2po1a2ZeccUVT+xZw1QykwEAADvZy1/+8sPuvPPOPU8//fR5H/7wh/c7+eSTH9xjjz2m5i3ZO4CZDAAA2Mkuu+yy26+55pp9r7nmmm8edNBBk3t1+i7MTAYAANCVkAEAAHQlZAAAAF25JwMAAKbQ7bffPuNpT3va6I9//OPpVdUuvvji2atWrbp5v/322zjVtU2WkAEAwG5vIo+c7e2uu+66afPne++998adff4dyeVSAABAV0IGAADQlZABAAB0JWQAALA72rhx48aa6iKGauz37lFvTBcyAADYHd28bt26fQWN7bdx48Zat27dvklufrQxni4FAMBuZ8OGDb95zz33fPiee+45On7wvr02Jrl5w4YNv/loA4QMAODfXbDvJPe7v28dsIM99alP/U6SM6a6jscqIQMAHoPmnnflpPZbO2ty5zvm48dMar+bXn3TtgcBgyNkAABTZtXCRZPab9EtqzpXAvTk+jMAAKArIQMAAOhKyAAAALoSMgAAgK52mxu/J/2Ujfe8YFL7ecoGAAC7q90mZEzaZJ8XfvhhfesAAICBcLkUAADQlZkMYNfkrcMAMFhCBrBDeeswAOx+hAyAeOswAPTkngwAAKCrCYWMqjqtqlZX1ZqqOm8r2z9QVd8Y+/XNqvpB/1IBAIAh2OblUlU1PckHkzw3yZ1Jrquqpa21lZvHtNbeOG78f0lywg6oFQAAGICJzGSclGRNa+3W1trDSS5PcuZPGX9Okr/uURwAADA8EwkZhyS5Y9zynWPr/pOqenKSw5N84WcvDQAAGKLeT5c6O8kVrbVHtraxqpYkWZIkhx3mjdhb4wk3AAAM3URmMu5Kcui45Tlj67bm7PyUS6Vaa5e01ha31haPjIxMvEoAAGAwJhIyrksyr6oOr6qZ2RQklm45qKoWJnlikq/0LREAABiSbYaM1tqGJOcmuTrJqiSfaK2tqKoLq+qMcUPPTnJ5a63tmFIBAIAhmNA9Ga21q5JctcW687dYvqBfWQAAwFB54zcAANCVkAEAAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAHQ1Y6oLAGDXNfe8Kye139r3vGBS+x3z8WMmtd9Nr75pUvsBsGOYyQAAALoykwFAfxfsO7n9Dj+sbx0ATAkzGQAAQFdmMgAYvFULF01qv0W3rOpcCQCJmQwAAKAzIQMAAOhKyAAAALoSMgAAgK6EDAAAoCshAwAA6ErIAAAAuhIyAACAroQMAACgKyEDAADoSsgAAAC6EjIAAICuhAwAAKArIQMAAOhKyAAAALoSMgAAgK6EDAAAoCshAwAA6ErIAAAAuhIyAACAroQMAACgKyEDAADoakIho6pOq6rVVbWmqs57lDG/VlUrq2pFVV3Wt0wAAGAoZmxrQFVNT/LBJM9NcmeS66pqaWtt5bgx85L81yTPaK3dV1UH7qiCAQCAXdtEZjJOSrKmtXZra+3hJJcnOXOLMb+V5IOttfuSpLX2nb5lAgAAQzGRkHFIkjvGLd85tm68+UnmV9X/rqqvVtVpWztQVS2pqmVVtWzdunWTqxgAANil9brxe0aSeUmek+ScJB+qqp/bclBr7ZLW2uLW2uKRkZFOpwYAAHYlEwkZdyU5dNzynLF1492ZZGlrbX1r7bYk38ym0AEAAOxmJhIy
rksyr6oOr6qZSc5OsnSLMZ/OplmMVNUB2XT51K0d6wQAAAZimyGjtbYhyblJrk6yKsknWmsrqurCqjpjbNjVSb5XVSuTfDHJW1tr39tRRQMAALuubT7CNklaa1cluWqLdeeP+9ySvGnsFwAAsBvzxm8AAKArIQMAAOhKyAAAALoSMgAAgK6EDAAAoCshAwAA6ErIAAAAuhIyAACAroQMAACgKyEDAADoSsgAAAC6EjIAAICuhAwAAKArIQMAAOhKyAAAALoSMgAAgK6EDAAAoCshAwAA6ErIAAAAuhIyAACAroQMAACgKyEDAADoSsgAAAC6EjIAAICuhAwAAKArIQMAAOhKyAAAALoSMgAAgK6EDAAAoCshAwAA6ErIAAAAuhIyAACAroQMAACgKyEDAADoSsgAAAC6EjIAAICuhAwAAKArIQMAAOhKyAAAALoSMgAAgK6EDAAAoKsJhYyqOq2qVlfVmqo6byvbX1NV66rqG2O/frN/qQAAwBDM2NaAqpqe5INJnpvkziTXVdXS1trKLYb+j9bauTugRgAAYEAmMpNxUpI1rbVbW2sPJ7k8yZk7tiwAAGCoJhIyDklyx7jlO8fWbemlVXVjVV1RVYdu7UBVtaSqllXVsnXr1k2iXAAAYFfX68bvzyaZ21o7Nsk/JPn41ga11i5prS1urS0eGRnpdGoAAGBXMpGQcVeS8TMTc8bW/URr7XuttX8bW/xwkqf2KQ8AABiaiYSM65LMq6rDq2pmkrOTLB0/oKoOGrd4RpJV/UoEAACGZJtPl2qtbaiqc5NcnWR6ko+21lZU1YVJlrXWlib5nao6I8mGJN9P8podWDMAALAL22bISJLW2lVJrtpi3fnjPv/XJP+1b2kAAMAQeeM3AADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXEwoZVXVaVa2uqjVVdd5PGffSqmpVtbhfiQAAwJBsM2RU1fQkH0xyepLRJOdU1ehWxu2T5A1J/ql3kQAAwHBMZCbjpCRrWmu3ttYeTnJ5kjO3Mu4Pkrw3yUMd6wMAAAZmIiHjkCR3jFu+c2zdT1TViUkOba1d2bE2AABggH7mG7+ralqSP0ny5gmMXVJVy6pq2bp1637WUwMAALugiYSMu5IcOm55zti6zfZJcnSSL1XV2iS/kGTp1m7+bq1d0lpb3FpbPDIyMvmqAQCAXdZEQsZ1SeZV1eFVNTPJ2UmWbt7YWru/tXZAa21ua21ukq8mOaO1tmyHVAwAAOzSthkyWmsbkpyb5Ookq5J8orW2oqourKozdnSBAADAsMyYyKDW2lVJrtpi3fmPMvY5P3tZAADAUHnjNwAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQ1YRCRlWdVlWrq2pNVZ23le2vq6qbquobVfW/qmq0f6kAAMAQbDNkVNX0JB9McnqS0STnbCVEXNZaO6a1dnySP0ryJ90rBQAABmEiMxknJVnTWru1tfZwksuTnDl+QGvtgXGLj0/S+pUIAAAMyYwJjDkkyR3jlu9M8vNbDqqq307ypiQzk/zy1g5UVUuSLEmSww47bHtrBQAABqDbjd+ttQ+21p6S5O1Jfu9RxlzSWlvcWls8MjLS69QAAMAuZCIh464kh45bnjO27tFcnuRFP0tRAADAcE0kZFyXZF5VHV5VM5OcnWTp+AFVNW/c4guS/HO/EgEAgCHZ5j0ZrbUNVXVukquTTE/y0dbaiqq6MMmy1trSJOdW1alJ1ie5L8mrd2TRAADArmsiN36ntXZVkqu2WHf+uM9v6FwXAAAwUN74DQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0NaGQUVWnVdXqqlpTVedtZfubqmplVd1YVf9YVU/uXyoAADAE2wwZVTU9yQeTnJ5kNMk5VTW6xbCvJ1ncWjs2yRVJ/qh3oQAAwDBMZCbjpCRrWmu3ttYeTnJ5kjPHD2itfbG19uDY4leTzOlbJgAAMBQTCRmHJLlj3PKdY+sezW8k+butbaiqJVW1rKqWrVu3buJVAgAAg9H1xu+q+vUki5P88da2t9Yuaa0tbq0tHhkZ6XlqAABgFzFjAmPuSnLouOU5Y+v+g6o6Nck7kpzcWvu3PuUBAABDM5GZjOuSzKuqw6tqZpKzkywdP6CqTkhycZIzWmvf6V8mAAAwFNsMGa21DUnOTXJ1klVJPtFaW1FVF1bVGWPD/jjJ3kn+pqq+UVVLH+VwAADAY9xELpdKa+2qJFdtse78cZ9P7VwXAAAwUN74DQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQ1YRCRlWdVlWrq2pNVZ23le3PrqobqmpDVZ3Vv0wAAGAothkyqmp6kg8mOT3JaJJzqmp0i2G3J3lNkst6FwgAAAzLjAmMOSnJmtbarUlSVZcnOTPJys0DWmtrx7Zt3AE1AgAAAzKRy6UOSXLHuOU7x9YBAAD8Jzv1xu+
qWlJVy6pq2bp163bmqQEAgJ1kIiHjriSHjlueM7Zuu7XWLmmtLW6tLR4ZGZnMIQAAgF3cRELGdUnmVdXhVTUzydlJlu7YsgAAgKHaZshorW1Icm6Sq5OsSvKJ1tqKqrqwqs5Ikqp6WlXdmeT/SHJxVa3YkUUDAAC7rok8XSqttauSXLXFuvPHfb4umy6jAgAAdnPe+A0AAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdCVkAAAAXQkZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBXQgYAANCVkAEAAPMxiqcAAAkdSURBVHQlZAAAAF0JGQAAQFdCBgAA0JWQAQAAdDWhkFFVp1XV6qpaU1XnbWX7nlX1P8a2/1NVze1dKAAAMAzbDBlVNT3JB5OcnmQ0yTlVNbrFsN9Icl9r7cgkH0jy3t6FAgAAwzCRmYyTkqxprd3aWns4yeVJztxizJlJPj72+Yokp1RV9SsTAAAYihkTGHNIkjvGLd+Z5OcfbUxrbUNV3Z9k/yTfHT+oqpYkWTK2+KOqWj2ZonemySelmw/IFt9/IracIpowmW6rJve7one7An/2hk3/hkvvhk3/HtWTd/YJd3cTCRndtNYuSXLJzjznVKmqZa21xVNdB9tP74ZN/4ZN/4ZL74ZN/+htIpdL3ZXk0HHLc8bWbXVMVc1Ism+S7/UoEAAAGJaJhIzrksyrqsOramaSs5Ms3WLM0iSvHvt8VpIvtNZavzIBAICh2OblUmP3WJyb5Ook05N8tLW2oqouTLKstbY0yUeS/GVVrUny/WwKIru73eKysMcovRs2/Rs2/RsuvRs2/aOrMuEAAAD05I3fAABAV0IGAADQlZABAAB0JWQAAABdCRk7WFW9dqprYNuqamFVnVJVe2+x/rSpqomJqaqTquppY59Hq+pNVfUrU10Xk1NV/99U18DkVNUzx/78PW+qa+Gnq6qfr6onjH1+XFX9flV9tqreW1X7TnV9PDZ4utQOVlW3t9YOm+o6eHRV9TtJfjvJqiTHJ3lDa+0zY9tuaK2dOJX18eiq6v9Ocno2PY77H5L8fJIvJnlukqtba++awvLYhqra8p1LleSXknwhSVprZ+z0opiwqvpaa+2ksc+/lU1/j34qyfOSfLa19p6prI9HV1Urkhw39pqCS5I8mOSKJKeMrX/JlBbIY4KQ0UFV3fhom5LMb63tuTPrYftU1U1JfrG19qOqmptNf9H+ZWvtv1fV11trJ0xpgTyqsd4dn2TPJPckmdNae6CqHpfkn1prx05pgfxUVXVDkpVJPpykZdPfmX+dsXcttdaumbrq2Jbxfz9W1XVJfqW1tq6qHp/kq621Y6a2Qh5NVa1qrS0a+/wffphWVd9orR0/ddXxWLHNl/ExIbOTPD/JfVusryRf3vnlsJ2mtdZ+lCSttbVV9ZwkV1TVk7Oph+y6NrTWHknyYFV9q7X2QJK01v61qjZOcW1s2+Ikb0jyjiRvba19o6r+VbgYjGlV9cRsuvS6WmvrkqS19uOq2jC1pbENN1fVa1trf5FkeVUtbq0tq6r5SdZPdXE8NggZfXwuyd6ttW9suaGqvrTzy2E73VtVx2/u39iMxq8m+WgSP4nbtT1cVXu11h5M8tTNK8euKRYydnGttY1JPlBVfzP233vj36Uh2TfJ9dn0w5hWVQe11u4eu7fND2h2bb+Z5L9X1e8l+W6Sr1TVHUnuGNsGPzOXS+1EVfXE1tqWsx1Msaqak00/Eb9nK9ue0Vr732Of9W8XU1V7ttb+bSvrD0hyUGvtprFlvRuAqnpBkme01v6vLdbr34BU1V5JZrfWbhtb1r9d1NjN34dnU7i/s7V27xbb9Y5JEzJ2IjcRD5v+DZfeDZv+DZv+DZfe8bPwCNudy/TxsOnfcOndsOnfsOnfcOkdkyZk7FymjYZN/4ZL74ZN/4ZN/4ZL75g0IQMAAOhKyNi5TDsOm/4Nl94Nm/4Nm/4Nl94xaW787qCq9vtp21tr3988bvNndh36N1x6N2z6N2z6N1x6x84gZHRQVbfl399Wu6XWWjtiJ5fEdtC/4dK7YdO/YdO/4dI7dgYhAwAA6Mo9GR3VJr9eVf9tbPmwqjppqutiYvRvuPRu2PRv2PRvuPSOHclMRkdV9f8m2Zjkl1tri6rqiUn+Z2vtaVNcGhOgf8Old8Omf8Omf8Old+xIM6a6gMeYn2+tnVhVX0+S1tp9VTVzqotiwvRvuPRu2PRv2PRvuPSOHcblUn2tr6rpGXt5TVWNZNNPCBgG/RsuvRs2/Rs2/RsuvWOHETL6uijJp5IcWFXvSvK/krx7aktiO+jfcOndsOnfsOnfcOkdO4x7MjqrqoVJTsmmx8L9Y2tt1RSXxHbQv+HSu2HTv2HTv+HSO3YUIaODib7Uhl2T/g2X3g2b/g2b/g2X3rEzCBkdbPFSm8OS3Df2+eeS3N5aO3wKy2Mb9G+49G7Y9G/Y9G+49I6dwT0ZHbTWDh97O+bnk7ywtXZAa23/JL+a5H9ObXVsi/4Nl94Nm/4Nm/4Nl96xM5jJ6KiqbmqtHbOtdeya9G+49G7Y9G/Y9G+49I4dyXsy+vp2Vf1ekr8aW35Fkm9PYT1sH/0bLr0bNv0bNv0bLr1jh3G5VF/nJBnJpsfBfSrJgWPrGAb9Gy69Gzb9Gzb9Gy69Y4dxudQOUFX7JGmttR9NdS1sP/0bLr0bNv0bNv0bLr1jRzCT0VFVHVNVX09yc5IVVXV9VR091XUxMfo3XHo3bPo3bPo3XHrHjiRk9HVxkje11p7cWntykjcnuWSKa2Li9G+49G7Y9G/Y9G+49I4dRsjo6/GttS9uXmitfSnJ46euHLaT/g2X3g2b/g2b/g2X3rHDeLpUX7dW1X9L8pdjy7+e5NYprIfto3/DpXfDpn/Dpn/DpXfsMGYy+vo/s+kpDX879uuAJK+d0orYHvo3XHo3bPo3bPo3XHrHDiNk9PWUJIdm0+/rzCSnJLl2Sitie+jfcOndsOnfsOnfcOkdO4xH2HZUVauTvCWbntKwcfP61tq/TFlRTJj+DZfeDZv+DZv+DZfesSO5J6Ovda21z051EUya/g2X3g2b/g2b/g2X3rHDmMnoqKpOyaY3Zf5jkn/bvL619skpK4oJ07/h0rth079h07/h0jt2JDMZfb02ycIke+Tfpx1bEn9Yh0H/hkvvhk3/hk3/hkvv2GHMZHRUVatbawumug4mR/+GS++GTf+GTf+GS+/YkTxdqq8vV9XoVBfBpOnfcOndsOnfsOnfcOkdO4yZjI6qalU2PQ7utmy6trGStNbasVNaGBOif8Old8Omf8Omf8Old+xIQkZHVfXkra33KLhh0L/h0rth079h07/h0jt2JCEDAA
Doyj0ZAABAV0IGAADQlZABAAB0JWQAAABdCRkAAEBX/z+dboWWC7l3SgAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 864x576 with 1 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "all_model_results.sort_values(\"f1\", ascending=False)[\"f1\"].plot(kind=\"bar\", figsize=(12, 8));"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 515
        },
        "id": "19IAPIT7CWYf",
        "outputId": "800109e7-1c6a-4175-e089-9ec3ed4e4348"
      },
      "execution_count": 86,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAsIAAAHyCAYAAADhgKjpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAa50lEQVR4nO3df7Bnd33X8debXYMtpRTNpcNkk+xaQ2GnINAl6OAoNlATUhPHopOMOAWBjDNNywh2XCyNNU4daGfKtDPRYa1oxYE0RdFF1kkrP3TaAu4CKbDJpG5DSjbY9kKjTKUStrz9437TXG5vcr+bnLvfPfk8HjM7fM/5nrn3DW8Iz5x8f1R3BwAARvOkVQ8AAACrIIQBABiSEAYAYEhCGACAIQlhAACGJIQBABjS3lX94gsvvLD379+/ql8PAMAgPvGJT3yxu9e2nl9ZCO/fvz8nTpxY1a8HAGAQVfVb25330ggAAIYkhAEAGJIQBgBgSEIYAIAhCWEAAIYkhAEAGJIQBgBgSEIYAIAhCWEAAIYkhAEAGJIQBgBgSEIYAIAhCWEAAIYkhAEAGJIQBgBgSEIYAIAhCWEAAIYkhAEAGNLeVQ9wruw//IFVj7Cr7n3r1aseAQBgVtwRBgBgSEIYAIAhCWEAAIYkhAEAGJIQBgBgSEIYAIAhCWEAAIYkhAEAGJIQBgBgSEIYAIAhCWEAAIYkhAEAGNJSIVxVV1bV3VV1qqoOb/P8JVX14ar6VFV9uqpeMf2oAAAwnR1DuKr2JLklyVVJDia5vqoObrnsLUlu6+4XJLkuyT+felAAAJjSMneEL09yqrvv6e4Hk9ya5Not13SSb108flqSL0w3IgAATG+ZEL4oyX2bjk8vzm3240leVVWnkxxL8kPb/aCquqGqTlTVifX19ccwLgAATGOqN8tdn+TfdPe+JK9I8q6q+mM/u7uPdPeh7j60trY20a8GAICzt0wI35/k4k3H+xbnNnttktuSpLs/muRPJrlwigEBAGA3LBPCx5NcVlUHquqCbLwZ7uiWaz6f5IokqarnZCOEvfYBAIDz1o4h3N1nktyY5PYkd2Xj0yFOVtXNVXXN4rI3JXl9Vf16kvckeXV3924NDQAAj9feZS7q7mPZeBPc5nM3bXp8Z5KXTDsaAADsHt8sBwDAkIQwAABDWuqlEbBq+w9/YNUj7Jp733r1qkcAgCG5IwwAwJCEMAAAQxLCAAAMSQgDADAkIQwAwJCEMAAAQxLCAAAMSQgDADAkIQwAwJCEMAAAQxLCAAAMSQgDADAkIQwAwJCEMAAAQxLCAAAMSQgDADAkIQwAwJCEMAAAQxLCAAAMSQgDADAkIQwAwJCEMAAAQxLCAAAMSQgDADAkIQwAwJD2rnoA4Ilt/+EPrHqEXXXvW69e9QgAPEbuCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkHzFMgCPyFdkA09k7ggDADCkpUK4qq6sqrur6lRVHd7m+bdX1R2LP79RVf97+lEBAGA6O740oqr2JLklycuTnE5yvKqOdvedD13T3X9/0/U/lOQFuzArAABMZpk7wpcnOdXd93T3g0luTXLto1x/fZL3TDEcAADslmVC+KIk9206Pr0498dU1aVJDiT50CM8f0NVnaiqE+vr62c7KwAATGbqT424Lsl7u/sPt3uyu48kOZIkhw4d6ol/NwCw4BM/YGfL3BG+P8nFm473Lc5t57p4WQQAADOwTAgfT3JZVR2oqguyEbtHt15UVc9O8vQkH512RAAAmN6OIdzdZ5LcmOT2JHclua27T1bVzVV1zaZLr0tya3d7yQMAAOe9pV4j3N3Hkhzbcu6mLcc/Pt1YAACwu3yzHAAAQxLCAAAMSQgDADAkIQwAwJCEMAAAQxLCAAAMSQgDADAkIQwAwJCEMAAAQxLCAAAMSQgDADAkIQwAwJCEMAAAQxLCAAAMSQgDADAkIQwAwJCEMAAAQxLCAAAMae+qBwAA4BvtP/yBVY+wq+5969WrHiGJO8IAAAxKCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMKSlQriqrqyqu6vqVFUdfoRr/lZV3VlVJ6vq3dOOCQAA09q70wVVtSfJLUlenuR0kuNVdbS779x0zWVJ3pzkJd39QFU9Y7cGBgCAKSxzR/jyJKe6+57ufjDJrUmu3XLN65Pc0t0PJEl3/+60YwIAwLSWCeGLkty36fj04txmz0ryrKr61ar6WFVdud0PqqobqupEVZ1YX19/bBMDAMAEpnqz3N4klyV5aZLrk/zLqvq2rRd195HuPtTdh9bW1ib61QAAcPaWCeH7k1y86Xjf4txmp5Mc7e6vdffnkvxGNsIYAADOS8uE8PEkl1XVgaq6IMl1SY5uueY/ZuNucKrqwmy8VOKeCecEAIBJ7RjC3X0myY1Jbk9yV5LbuvtkVd1cVdcsLrs9yZeq6s4kH07yI939pd0aGgAAHq8dPz4tSbr7WJJjW87dtOlxJ3nj4g8AAJz3fLMcAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkJYK4aq6sqrurqpTVXV4m+dfXVXrVXXH4s/rph8VAACms3enC6pqT5Jbkrw8yekkx6vqaHffueXSX+juG3dhRgAAmNwyd4QvT3Kqu+/p7geT3Jrk2t0dCwAAdtcyIXxRkvs2HZ9enNvq+6vq01X13qq6eJLpAABgl0z1Zrn3J9nf3c9L8stJfn67i6rqhqo6UVUn1tfXJ/rVAABw9pYJ4fuTbL7Du29x7o9095e6+6uLw59L8t3b/aDuPtLdh7r70Nra2mOZFwAAJrFMCB9PcllVHaiqC5Jcl+To5guq6pmbDq9Jctd0IwIAwPR2/NSI7j5TVTcmuT3JniTv7O6TVXVzkhPdfTTJD1fVNUnOJPm9JK/exZkBAOBx2zGEk6S7jyU5tuXcTZsevznJm6cdDQAAdo9vlgMAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBA
BiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhLhXBVXVlVd1fVqao6/CjXfX9VdVUdmm5EAACY3o4hXFV7ktyS5KokB5NcX1UHt7nuqUnekOTjUw8JAABTW+aO8OVJTnX3Pd39YJJbk1y7zXX/NMnbkvy/CecDAIBdsUwIX5Tkvk3Hpxfn/khVvTDJxd39gUf7QVV1Q1WdqKoT6+vrZz0sAABM5XG/Wa6qnpTkp5O8aadru/tIdx/q7kNra2uP91cDAMBjtkwI35/k4k3H+xbnHvLUJN+V5CNVdW+SP5/kqDfMAQBwPlsmhI8nuayqDlTVBUmuS3L0oSe7+/9094Xdvb+79yf5WJJruvvErkwMAAAT2DGEu/tMkhuT3J7kriS3dffJqrq5qq7Z7QEBAGA37F3mou4+luTYlnM3PcK1L338YwEAwO7yzXIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxpqRCuqiur6u6qOlVVh7d5/u9V1Weq6o6q+pWqOjj9qAAAMJ0dQ7iq9iS5JclVSQ4muX6b0H13dz+3u5+f5CeT/PTkkwIAwISWuSN8eZJT3X1Pdz+Y5NYk126+oLu/vOnwKUl6uhEBAGB6e5e45qIk9206Pp3kxVsvqqofTPLGJBck+Z7tflBV3ZDkhiS55JJLznZWAACYzGRvluvuW7r7O5L8wyRveYRrjnT3oe4+tLa2NtWvBgCAs7ZMCN+f5OJNx/sW5x7JrUn++uMZCgAAdtsyIXw8yWVVdaCqLkhyXZKjmy+oqss2HV6d5H9ONyIAAExvx9cId/eZqroxye1J9iR5Z3efrKqbk5zo7qNJbqyqlyX5WpIHkvzAbg4NAACP1zJvlkt3H0tybMu5mzY9fsPEcwEAwK7yzXIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxJCAMAMCQhDADAkIQwAABDEsIAAAxpqRCuqiur6u6qOlVVh7d5/o1VdWdVfbqqPlhVl04/KgAATGfHEK6qPUluSXJVkoNJrq+qg1su+1SSQ939vCTvTfKTUw8KAABTWuaO8OVJTnX3Pd39YJJbk1y7+YLu/nB3f2Vx+LEk+6YdEwAAprVMCF+U5L5Nx6cX5x7Ja5P8l+2eqKobqupEVZ1YX19ffkoAAJjYpG+Wq6pXJTmU5Ke2e767j3T3oe4+tLa2NuWvBgCAs7J3iWvuT3LxpuN9i3PfoKpeluRHk/zl7v7qNOMBAMDuWOaO8PEkl1XVgaq6IMl1SY5uvqCqXpDkHUmu6e7fnX5MAACY1o4h3N1nktyY5PYkdyW5rbtPVtXNVXXN4rKfSvItSX6xqu6oqqOP8OMAAOC8sMxLI9Ldx5Ic23Lupk2PXzbxXAAAsKt8sxwAAEMSwgAADEkIAwAwJCEMAMCQhDAAAEMSwgAADEkIAwAwJCEMAMCQhDAAAEMSwgAADEkIAwAwJCEMAMCQhDAAAEMSwgAADEkIAwAwJCEMAMCQhDAAAEMSwgAADEkIAwAwJCEMAMCQhDAAAEMSwgAADEkIAwAwJCEMAMCQhDAAAEMSwgAADEkIAwAwJCEMAMCQhDAAAEMSwgAADEkIAwAwJCEMAMCQhDAAAEMSwgAADEkIAwAwJCEMAMCQhDAAAEMSwgAADEkIAwAwJCEMAMCQlgrhqrqyqu6uqlNVdXib5/9SVX2yqs5U1SunHxMAAKa1YwhX1Z4ktyS5KsnBJNdX1cEtl30+yauTvHvqAQEAYDfsXeKay5Oc6u57kqSqbk1ybZI7H7qgu+9dPPf1XZgRAAAmt8xLIy5Kct+m49OLcwAAMFvn9M1yVXVDVZ2oqhPr6+vn8lcDAMA3WCaE709y8abjfYtzZ627j3T3oe4+tLa29lh+BAAATGKZED6e5LKqOlBVFyS5LsnR3R0LAAB2144h3N1nktyY5PYkdyW5rbtPVtXNVXVNklTVi6rqdJK/meQdVXVyN4cGAIDHa5lPjUh3H0tybMu5mzY9Pp6Nl0wAAMAs+GY5AACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGJIQBABiSEAYAYEhCGACAIQlhAACGtFQIV9WVVXV3VZ2qqsPbPP/kqvqFxfMfr6r9Uw8KAABT2jGEq2pPkluSXJXkYJLrq+rglstem+SB7v6zSd6e5G1TDwoAAFNa5o7w5UlOdfc93f1gkluTXLvlmmuT/Pzi8XuTXFFVNd2YAAAwreruR7+g6pVJruzu1y2O/06SF3f3jZuu+ezimtOL499cXPPFLT/rhiQ3LA6/M8ndU/0bOQ9dmOSLO17F+cju5s3+5s3+5svu5u2Jvr9Lu3tt68m953KC7j6S5Mi5/J2rUlUnuvvQqufg7NndvNnfvNnffNndvI26v2VeGnF/kos3He9bnNv2mqram+RpSb40xYAAALAblgnh40kuq6oDVXVBkuuSHN1yzdEkP7B4/MokH+qdXnMBAAArtONLI7r7TFXdmOT2JHuSvLO7T1bVzUlOdPfRJP8qybuq6lSS38tGLI9uiJeAPEHZ3bzZ37zZ33zZ3bwNub8d3ywHAABPRL5ZDgCAIQlhAACGJIQBABiSEAYAYEhCeAJV9eKq+tbF42+qqn9SVe+vqrdV1dNWPR87q6pnV9UVVfUtW85fuaqZOHtV9Rer6o1V9b2rnoXlVNXlVfWixeODi/29YtVz8fhU1WtWPQNnp6r+7apnWAWfGjGBqjqZ5M8tPmruSJKvJHlvkisW5//GSgfkUVXVDyf5wSR3JXl+kjd0939aPPfJ7n7h
KufjkVXV/+juyxePX5+NPb4vyfcmeX93v3WV8/HoquofJ7kqGx/l+ctJXpzkw0lenuT27v6JFY7H41BVn+/uS1Y9B9urqq3fB1FJ/kqSDyVJd19zzodaESE8gaq6q7ufs3j8DeFUVXd09/NXNx07qarPJPkL3f37VbU/G38T867u/pmq+lR3v2ClA/KINu+nqo4neUV3r1fVU5J8rLufu9oJeTSL/+09P8mTk/x2kn3d/eWq+qYkH+/u5610QB5VVX36kZ5K8qzufvK5nIflVdUnk9yZ5OeSdDZ29p4svgeiu//b6qY7t3b8Qg2W8tmqek13/+skv15Vh7r7RFU9K8nXVj0cO3pSd/9+knT3vVX10iTvrapLs/EXB85fT6qqp2fjZV7V3etJ0t3/t6rOrHY0lnCmu/8wyVeq6je7+8tJ0t1/UFVfX/Fs7Ozbk/zVJA9sOV9Jfu3cj8NZOJTkDUl+NMmPdPcdVfUHIwXwQ4TwNF6X5Geq6i1Jvpjko1V1X5L7Fs9xfvudqnp+d9+RJIs7w9+X5J1J3FE8vz0tySey8X+8XVXP7O7/tXitt7+JOf89WFXf3N1fSfLdD51cvLdCCJ///nOSb3nor52bVdVHzv04LKu7v57k7VX1i4t//Z0M2oReGjGhxRvmDmTjv0ynu/t3tjz/9O7e+nfOrFhV7cvGnanf3ua5l3T3ry4e299MVNU3J/n27v7c4tjuzkNV9eTu/uo25y9M8szu/szi2P5mzP7Of1V1dZKXdPc/2nL+Cb87IXwOeePVvNnffNndvNnfvNnffI2wOx+fdm75R7XzZn/zZXfzZn/zZn/z9YTfnRA+t9x+nzf7my+7mzf7mzf7m68n/O6EMAAAQxLC59YT/h8xPMHZ33zZ3bzZ37zZ33w94XfnzXITqKo/9WjPd/fvPXTdQ485f9jffNndvNnfvNnffNndw4TwBKrqc3n4m1m26u7+M+d4JM6C/c2X3c2b/c2b/c2X3T1MCAMAMCSvEZ5QbXhVVf3Y4viSqrp81XOxHPubL7ubN/ubN/ubL7tzR3hSVfUvsvG1oN/T3c+pqqcn+aXuftGKR2MJ9jdfdjdv9jdv9jdfdjfo90rvohd39wur6lNJ0t0PVNUFqx6KpdnffNndvNnfvNnffA2/Oy+NmNbXqmpPFh9AXVVr2fg7LebB/ubL7ubN/ubN/uZr+N0J4Wn9bJL3JXlGVf1Ekl9J8s9WOxJnwf7my+7mzf7mzf7ma/jdeY3wxKrq2UmuyMZHknywu+9a8UicBfubL7ubN/ubN/ubr9F3J4QnsOwHU3N+sr/5srt5s795s7/5sruHCeEJbPlg6kuSPLB4/G1JPt/dB1Y4Hjuwv/myu3mzv3mzv/myu4d5jfAEuvvA4ltY/muSv9bdF3b3n07yfUl+abXTsRP7my+7mzf7mzf7my+7e5g7whOqqs9093N3Osf5yf7my+7mzf7mzf7my+58jvDUvlBVb0ny7xbHfzvJF1Y4D2fH/ubL7ubN/ubN/uZr+N15acS0rk+ylo2PInlfkmcszjEP9jdfdjdv9jdv9jdfw+/OSyN2QVU9NUl39++vehbOnv3Nl93Nm/3Nm/3N18i7c0d4QlX13MXXFH42ycmq+kRVfdeq52I59jdfdjdv9jdv9jdfdieEp/aOJG/s7ku7+9Ikb0pyZMUzsTz7my+7mzf7mzf7m6/hdyeEp/WU7v7wQwfd/ZEkT1ndOJwl+5svu5s3+5s3+5uv4XfnUyOmdU9V/ViSdy2OX5XknhXOw9mxv/myu3mzv3mzv/kafnfuCE/r72bj3Zf/fvHnwiSvWelEnA37my+7mzf7mzf7m6/hdyeEp/UdSS7Oxn+uFyS5Isl/X+lEnA37my+7mzf7mzf7m6/hd+fj0yZUVXcn+QfZePfl1x86392/tbKhWJr9zZfdzZv9zZv9zZfdeY3w1Na7+/2rHoLHzP7my+7mzf7mzf7ma/jduSM8oaq6IhvfyPLBJF996Hx3/4eVDcXS7G++7G7e7G/e7G++7M4d4am9Jsmzk/yJPPyPGDrJMP+Fmjn7my+7mzf7mzf7m6/hd+eO8ISq6u7u/s5Vz8FjY3/zZXfzZn/zZn/zZXc+NWJqv1ZVB1c9BI+Z/c2X3c2b/c2b/c3X8LtzR3hCVXVXNj6K5HPZeK1NJenuft5KB2Mp9jdfdjdv9jdv9jdfdieEJ1VVl253fqSPIZkz+5svu5s3+5s3+5svuxPCAAAMymuEAQAYkhAGAGBIQhgAgCEJYQAAhiSEAQAY0v8H5xxEyOKX9yQAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 864x576 with 1 Axes>"
            ]
          },
          "metadata": {
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Saving model_6\n",
        "model_6.save(\"nlp_pubmed_model_6\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BXwJDHsbCvzR",
        "outputId": "4bddef98-00bf-4b65-fe86-6f6a88e384da"
      },
      "execution_count": 87,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "WARNING:absl:Found untraced functions such as lstm_cell_19_layer_call_fn, lstm_cell_19_layer_call_and_return_conditional_losses, lstm_cell_20_layer_call_fn, lstm_cell_20_layer_call_and_return_conditional_losses, lstm_cell_19_layer_call_fn while saving (showing 5 of 10). These functions will not be directly callable after loading.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "INFO:tensorflow:Assets written to: nlp_pubmed_model_6/assets\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "INFO:tensorflow:Assets written to: nlp_pubmed_model_6/assets\n",
            "WARNING:absl:<keras.layers.recurrent.LSTMCell object at 0x7f1d83d60e50> has the same name 'LSTMCell' as a built-in Keras object. Consider renaming <class 'keras.layers.recurrent.LSTMCell'> to avoid naming conflicts when loading with `tf.keras.models.load_model`. If renaming is not possible, pass the object in the `custom_objects` parameter of the load function.\n",
            "WARNING:absl:<keras.layers.recurrent.LSTMCell object at 0x7f1d83d57bd0> has the same name 'LSTMCell' as a built-in Keras object. Consider renaming <class 'keras.layers.recurrent.LSTMCell'> to avoid naming conflicts when loading with `tf.keras.models.load_model`. If renaming is not possible, pass the object in the `custom_objects` parameter of the load function.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!cp nlp_pubmed_model_6 -r /content/drive/MyDrive/NLP-projects/nlp_pubmed_model_6"
      ],
      "metadata": {
        "id": "vZDM_8TQC_rL"
      },
      "execution_count": 88,
      "outputs": []
    },
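    {
      "cell_type": "markdown",
      "source": [
        "The `absl` warnings emitted while saving suggest passing `custom_objects` when reloading. Below is a minimal reload sketch (not executed in this run); it assumes `tensorflow_hub` is available as `hub` and that the SavedModel directory `nlp_pubmed_model_6` is in the working directory."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Reload sketch (assumption: run in the same environment where model_6 was saved)\n",
        "# Passing hub.KerasLayer via custom_objects addresses the naming warning shown while saving\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub\n",
        "\n",
        "loaded_model_6 = tf.keras.models.load_model(\"nlp_pubmed_model_6\",\n",
        "                                            custom_objects={\"KerasLayer\": hub.KerasLayer})\n",
        "loaded_model_6.summary()"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },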
    {
      "cell_type": "markdown",
      "source": [
        "# Evaluating on test data"
      ],
      "metadata": {
        "id": "5NrfgQJeDwAl"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Preprocessing test dataset, evaluation & predictions\n",
        "test_pos_character_token_data = tf.data.Dataset.from_tensor_slices((test_line_numbers_one_hot,\n",
        "                                                              test_total_lines_one_hot,\n",
        "                                                              test_sentences,\n",
        "                                                              test_char))\n",
        "test_pos_character_token_labels = tf.data.Dataset.from_tensor_slices(test_labels_one_hot)\n",
        "test_pos_character_token_dataset = tf.data.Dataset.zip((test_pos_character_token_data, test_pos_character_token_labels))\n",
        "test_pos_character_token_dataset = test_pos_character_token_dataset.batch(32).prefetch(tf.data.AUTOTUNE)\n",
        "\n",
        "# Predictions\n",
        "test_pred_probs = model_6.predict(test_pos_character_token_dataset,\n",
        "                                  verbose=1)\n",
        "test_preds = tf.argmax(test_pred_probs, axis=1)\n",
        "\n",
        "# Evaluation results\n",
        "model_6_test_results = calculate_results(y_true=test_labels_encoded,\n",
        "                                         y_pred=test_preds)\n",
        "model_6_test_results\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "N5c_1WXSFS7g",
        "outputId": "2c2b334e-1471-4b90-8913-bae75d8c002f"
      },
      "execution_count": 89,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "942/942 [==============================] - 45s 48ms/step\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'accuracy': 82.87041645926664,\n",
              " 'f1': 0.8274737505797032,\n",
              " 'precision': 0.8272332364828473,\n",
              " 'recall': 0.8287041645926664}"
            ]
          },
          "metadata": {},
          "execution_count": 89
        }
      ]
    },
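    {
      "cell_type": "markdown",
      "source": [
        "As a quick check on generalisation, the validation and test metrics for `model_6` can be placed side by side; this is a small sketch that only reuses the `model_6_results` and `model_6_test_results` dictionaries computed above."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Validation vs test metrics for model_6 (both dictionaries were computed above)\n",
        "pd.DataFrame({\"validation\": model_6_results,\n",
        "              \"test\": model_6_test_results}).transpose()"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },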
    {
      "cell_type": "code",
      "source": [
        "# Evaluating the most wrong predictions\n",
        "test_pred_classes = [label_encoder.classes_[pred] for pred in test_preds]\n",
        "\n",
        "# Integrating prediction in test_df \n",
        "test_df[\"prediction\"] = test_pred_classes\n",
        "test_df[\"pred_prob\"] = tf.reduce_max(test_pred_probs, axis=1).numpy()\n",
        "test_df[\"correct\"] = test_df[\"prediction\"] == test_df[\"target\"]\n",
        "\n",
        "# 200 most wrong predictions\n",
        "most_wrong_200 = test_df[test_df[\"correct\"] == False].sort_values(\"pred_prob\", ascending=False)[:200]\n",
        "most_wrong_200"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 424
        },
        "id": "g1Qo6kn8HJgg",
        "outputId": "604f73d7-4c6d-426e-c0ab-68b00c2087b8"
      },
      "execution_count": 94,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "\n",
              "  <div id=\"df-a2518e9a-29b2-4b7c-8873-b5f189ae26eb\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>target</th>\n",
              "      <th>text</th>\n",
              "      <th>line_number</th>\n",
              "      <th>total_lines</th>\n",
              "      <th>prediction</th>\n",
              "      <th>pred_prob</th>\n",
              "      <th>correct</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>8545</th>\n",
              "      <td>METHODS</td>\n",
              "      <td>pretest-posttest .</td>\n",
              "      <td>1</td>\n",
              "      <td>11</td>\n",
              "      <td>BACKGROUND</td>\n",
              "      <td>0.947419</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>13874</th>\n",
              "      <td>CONCLUSIONS</td>\n",
              "      <td>symptom outcomes will be assessed and estimate...</td>\n",
              "      <td>4</td>\n",
              "      <td>6</td>\n",
              "      <td>METHODS</td>\n",
              "      <td>0.943379</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>16633</th>\n",
              "      <td>CONCLUSIONS</td>\n",
              "      <td>clinicaltrials.gov identifier : nct@ .</td>\n",
              "      <td>19</td>\n",
              "      <td>19</td>\n",
              "      <td>BACKGROUND</td>\n",
              "      <td>0.931800</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>16347</th>\n",
              "      <td>BACKGROUND</td>\n",
              "      <td>to evaluate the effects of the lactic acid bac...</td>\n",
              "      <td>0</td>\n",
              "      <td>12</td>\n",
              "      <td>OBJECTIVE</td>\n",
              "      <td>0.929503</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2388</th>\n",
              "      <td>RESULTS</td>\n",
              "      <td>the primary endpoint is the cumulative three-y...</td>\n",
              "      <td>4</td>\n",
              "      <td>13</td>\n",
              "      <td>METHODS</td>\n",
              "      <td>0.926239</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>...</th>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>24794</th>\n",
              "      <td>RESULTS</td>\n",
              "      <td>we judged that informed consent would undermin...</td>\n",
              "      <td>11</td>\n",
              "      <td>13</td>\n",
              "      <td>CONCLUSIONS</td>\n",
              "      <td>0.791870</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>13921</th>\n",
              "      <td>RESULTS</td>\n",
              "      <td>primary outcome was in-hospital all-cause mort...</td>\n",
              "      <td>7</td>\n",
              "      <td>16</td>\n",
              "      <td>METHODS</td>\n",
              "      <td>0.791018</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>18003</th>\n",
              "      <td>RESULTS</td>\n",
              "      <td>this formulation produced highly significant a...</td>\n",
              "      <td>12</td>\n",
              "      <td>20</td>\n",
              "      <td>CONCLUSIONS</td>\n",
              "      <td>0.790558</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>29638</th>\n",
              "      <td>METHODS</td>\n",
              "      <td>significance for all tests was set at p &lt; @ .</td>\n",
              "      <td>8</td>\n",
              "      <td>12</td>\n",
              "      <td>RESULTS</td>\n",
              "      <td>0.790246</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>19047</th>\n",
              "      <td>RESULTS</td>\n",
              "      <td>a randomized controlled trial and resting-stat...</td>\n",
              "      <td>2</td>\n",
              "      <td>10</td>\n",
              "      <td>METHODS</td>\n",
              "      <td>0.789534</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "<p>200 rows × 7 columns</p>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-a2518e9a-29b2-4b7c-8873-b5f189ae26eb')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-a2518e9a-29b2-4b7c-8873-b5f189ae26eb button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-a2518e9a-29b2-4b7c-8873-b5f189ae26eb');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ],
            "text/plain": [
              "            target  ... correct\n",
              "8545       METHODS  ...   False\n",
              "13874  CONCLUSIONS  ...   False\n",
              "16633  CONCLUSIONS  ...   False\n",
              "16347   BACKGROUND  ...   False\n",
              "2388       RESULTS  ...   False\n",
              "...            ...  ...     ...\n",
              "24794      RESULTS  ...   False\n",
              "13921      RESULTS  ...   False\n",
              "18003      RESULTS  ...   False\n",
              "29638      METHODS  ...   False\n",
              "19047      RESULTS  ...   False\n",
              "\n",
              "[200 rows x 7 columns]"
            ]
          },
          "metadata": {},
          "execution_count": 94
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# most commonly wrong predictions\n",
        "for row in most_wrong_200[0:20].itertuples():\n",
        "  _, target, text, line_number, total_lines, prediction, pred_prob, _ = row\n",
        "  print(f\"Target: {target}, Prediction: {prediction}, Probability: {pred_prob}, Line Number: {line_number}, Total Lines: {total_lines}\\n\")\n",
        "  print(f\"Text:\\n{text}\\n\")\n",
        "  print(\"- - - - -\\n\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "cA8GkWujL_e7",
        "outputId": "0630e122-f255-40eb-ac1f-76328962b0dd"
      },
      "execution_count": 96,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Target: METHODS, Prediction: BACKGROUND, Probability: 0.9474185705184937, Line Number: 1, Total Lines: 11\n",
            "\n",
            "Text:\n",
            "pretest-posttest .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: CONCLUSIONS, Prediction: METHODS, Probability: 0.9433786869049072, Line Number: 4, Total Lines: 6\n",
            "\n",
            "Text:\n",
            "symptom outcomes will be assessed and estimates of cost-effectiveness made .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: CONCLUSIONS, Prediction: BACKGROUND, Probability: 0.9317997097969055, Line Number: 19, Total Lines: 19\n",
            "\n",
            "Text:\n",
            "clinicaltrials.gov identifier : nct@ .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: BACKGROUND, Prediction: OBJECTIVE, Probability: 0.9295032620429993, Line Number: 0, Total Lines: 12\n",
            "\n",
            "Text:\n",
            "to evaluate the effects of the lactic acid bacterium lactobacillus salivarius on caries risk factors .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: RESULTS, Prediction: METHODS, Probability: 0.9262389540672302, Line Number: 4, Total Lines: 13\n",
            "\n",
            "Text:\n",
            "the primary endpoint is the cumulative three-year hiv incidence .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: CONCLUSIONS, Prediction: BACKGROUND, Probability: 0.9260937571525574, Line Number: 18, Total Lines: 18\n",
            "\n",
            "Text:\n",
            "nct@ ( clinicaltrials.gov ) .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: RESULTS, Prediction: BACKGROUND, Probability: 0.9218314290046692, Line Number: 8, Total Lines: 15\n",
            "\n",
            "Text:\n",
            "non-diffuse-trickling '' ) .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: CONCLUSIONS, Prediction: BACKGROUND, Probability: 0.9202677607536316, Line Number: 15, Total Lines: 15\n",
            "\n",
            "Text:\n",
            "-lsb- netherlands trial register ( http://www.trialregister.nl/trialreg/index.asp ) , nr @ , date of registration @ december @ . -rsb-\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: RESULTS, Prediction: METHODS, Probability: 0.9188950657844543, Line Number: 3, Total Lines: 16\n",
            "\n",
            "Text:\n",
            "a cluster randomised trial was implemented with @,@ children in @ government primary schools on the south coast of kenya in @-@ .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: CONCLUSIONS, Prediction: BACKGROUND, Probability: 0.9168384671211243, Line Number: 13, Total Lines: 13\n",
            "\n",
            "Text:\n",
            "( clinicaltrials.gov : nct@ ) .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: RESULTS, Prediction: METHODS, Probability: 0.9142534136772156, Line Number: 4, Total Lines: 14\n",
            "\n",
            "Text:\n",
            "a screening questionnaire for moh was sent to all @-@ year old patients on these gps ` list .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: RESULTS, Prediction: METHODS, Probability: 0.9101817011833191, Line Number: 6, Total Lines: 14\n",
            "\n",
            "Text:\n",
            "the primary outcome was to evaluate changes in abdominal and shoulder-tip pain via a @-mm visual analog scale at @ , @ , and @hours postoperatively .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: METHODS, Prediction: RESULTS, Probability: 0.9094158411026001, Line Number: 6, Total Lines: 9\n",
            "\n",
            "Text:\n",
            "-@ % vs. fish : -@ % vs. fish + s : -@ % ; p < @ ) but there were no significant differences between groups .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: METHODS, Prediction: BACKGROUND, Probability: 0.908202588558197, Line Number: 4, Total Lines: 9\n",
            "\n",
            "Text:\n",
            "clinicaltrials.gov identifier : nct@ .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: BACKGROUND, Prediction: OBJECTIVE, Probability: 0.9057267904281616, Line Number: 0, Total Lines: 9\n",
            "\n",
            "Text:\n",
            "to compare the efficacy of the newcastle infant dialysis and ultrafiltration system ( nidus ) with peritoneal dialysis ( pd ) and conventional haemodialysis ( hd ) in infants weighing < @ kg .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: BACKGROUND, Prediction: OBJECTIVE, Probability: 0.9013079404830933, Line Number: 0, Total Lines: 11\n",
            "\n",
            "Text:\n",
            "to compare the safety and efficacy of dexmedetomidine/propofol ( dp ) - total i.v. anaesthesia ( tiva ) vs remifentanil/propofol ( rp ) - tiva , both with spontaneous breathing , during airway foreign body ( fb ) removal in children .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: METHODS, Prediction: OBJECTIVE, Probability: 0.8997651934623718, Line Number: 0, Total Lines: 7\n",
            "\n",
            "Text:\n",
            "to determine whether the insulin resistance that exists in metabolic syndrome ( mets ) patients is modulated by dietary fat composition .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: RESULTS, Prediction: CONCLUSIONS, Probability: 0.8978807330131531, Line Number: 13, Total Lines: 15\n",
            "\n",
            "Text:\n",
            "additionally , intervention effects were observed for information gathering in women with high genetic literacy , but not in women with low genetic literacy .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: CONCLUSIONS, Prediction: BACKGROUND, Probability: 0.8975597023963928, Line Number: 10, Total Lines: 10\n",
            "\n",
            "Text:\n",
            "clinicaltrials.gov : nct@ .\n",
            "\n",
            "- - - - -\n",
            "\n",
            "Target: RESULTS, Prediction: METHODS, Probability: 0.8973494172096252, Line Number: 3, Total Lines: 11\n",
            "\n",
            "Text:\n",
            "family practices were randomly assigned to receive the educational toolkit in june @ ( intervention group ) or may @ ( control group ) .\n",
            "\n",
            "- - - - -\n",
            "\n"
          ]
        }
      ]
    },
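    {
      "cell_type": "markdown",
      "source": [
        "Many of the highest-confidence errors above are trial-registration lines (e.g. `clinicaltrials.gov : nct@ .`) labelled `CONCLUSIONS` but predicted as `BACKGROUND`, and opening `BACKGROUND` sentences predicted as `OBJECTIVE`. A quick way to see which label pairs dominate is a crosstab of target vs. prediction over these 200 examples. This is a minimal sketch assuming `most_wrong_200` has the `target` and `prediction` columns shown in the preview above; `confusion_counts` is a hypothetical name."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Count which (target, prediction) pairs occur most often among the\n",
        "# 200 most confidently wrong predictions\n",
        "confusion_counts = pd.crosstab(most_wrong_200[\"target\"],\n",
        "                               most_wrong_200[\"prediction\"])\n",
        "confusion_counts"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },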
    {
      "cell_type": "markdown",
      "source": [
        "# Predicting on PubMed NCBI research paper"
      ],
      "metadata": {
        "id": "aNXIEQDdN6xp"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Source** - `https://pubmed.ncbi.nlm.nih.gov/20232240/`\n",
        "\n",
        "Using the research paper from `PubMed NCBI` by `Christopher Lopata, Marcus L Thomeer, etc`.\n",
        "\n",
        "**Name** - `Randomized Controlled Trial: RCT of a manualized social treatment for high-functioning autism spectrum disorders`\n",
        "\n",
        "\n",
        "\n",
        "**Abstract**:\n",
        "\"This RCT examined the efficacy of a manualized social intervention for children with HFASDs. Participants were randomly assigned to treatment or wait-list conditions. Treatment included instruction and therapeutic activities targeting social skills, face-emotion recognition, interest expansion, and interpretation of non-literal language. A response-cost program was applied to reduce problem behaviors and foster skills acquisition. Significant treatment effects were found for five of seven primary outcome measures (parent ratings and direct child measures). Secondary measures based on staff ratings (treatment group only) corroborated gains reported by parents. High levels of parent, child and staff satisfaction were reported, along with high levels of treatment fidelity. Standardized effect size estimates were primarily in the medium and large ranges and favored the treatment group.\"\n",
        "\n",
        "**File** - Using the `Abstract` of the paper in `.json` format for readability.\n",
        "\n",
        "**Link**: `https://raw.githubusercontent.com/hecshzye/nlp-medical-abstract-pubmed-rct/main/pubmed_ncbi_autism_disorder.json`"
      ],
      "metadata": {
        "id": "RZfQ3OCURW2V"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import json\n",
        "\n",
        "# Loading the NCBI paper\n",
        "!wget https://raw.githubusercontent.com/hecshzye/nlp-medical-abstract-pubmed-rct/main/pubmed_ncbi_autism_disorder.json\n",
        "\n",
        "with open(\"pubmed_ncbi_autism_disorder.json\", \"r\") as f:\n",
        "  ncbi_abstract = json.load(f)\n",
        "\n",
        "ncbi_abstract"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VqCJ0IoYTvv9",
        "outputId": "405f4b6c-c8bd-4dfd-8c91-a39f86450f4d"
      },
      "execution_count": 106,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--2022-01-17 04:46:50--  https://raw.githubusercontent.com/hecshzye/nlp-medical-abstract-pubmed-rct/main/pubmed_ncbi_autism_disorder.json\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 6737 (6.6K) [text/plain]\n",
            "Saving to: ‘pubmed_ncbi_autism_disorder.json’\n",
            "\n",
            "\r          pubmed_nc   0%[                    ]       0  --.-KB/s               \rpubmed_ncbi_autism_ 100%[===================>]   6.58K  --.-KB/s    in 0s      \n",
            "\n",
            "2022-01-17 04:46:50 (76.7 MB/s) - ‘pubmed_ncbi_autism_disorder.json’ saved [6737/6737]\n",
            "\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[{'abstract': 'This RCT examined the efficacy of a manualized social intervention for children with HFASDs. Participants were randomly assigned to treatment or wait-list conditions. Treatment included instruction and therapeutic activities targeting social skills, face-emotion recognition, interest expansion, and interpretation of non-literal language. A response-cost program was applied to reduce problem behaviors and foster skills acquisition. Significant treatment effects were found for five of seven primary outcome measures (parent ratings and direct child measures). Secondary measures based on staff ratings (treatment group only) corroborated gains reported by parents. High levels of parent, child and staff satisfaction were reported, along with high levels of treatment fidelity. Standardized effect size estimates were primarily in the medium and large ranges and favored the treatment group.',\n",
              "  'details': 'RCT of a manualized social treatment for high-functioning autism spectrum disorders',\n",
              "  'source': 'https://pubmed.ncbi.nlm.nih.gov/20232240/'},\n",
              " {'abstract': \"Postpartum depression (PPD) is the most prevalent mood disorder associated with childbirth. No single cause of PPD has been identified, however the increased risk of nutritional deficiencies incurred through the high nutritional requirements of pregnancy may play a role in the pathology of depressive symptoms. Three nutritional interventions have drawn particular interest as possible non-invasive and cost-effective prevention and/or treatment strategies for PPD; omega-3 (n-3) long chain polyunsaturated fatty acids (LCPUFA), vitamin D and overall diet. We searched for meta-analyses of randomised controlled trials (RCT's) of nutritional interventions during the perinatal period with PPD as an outcome, and checked for any trials published subsequently to the meta-analyses. Fish oil: Eleven RCT's of prenatal fish oil supplementation RCT's show null and positive effects on PPD symptoms. Vitamin D: no relevant RCT's were identified, however seven observational studies of maternal vitamin D levels with PPD outcomes showed inconsistent associations. Diet: Two Australian RCT's with dietary advice interventions in pregnancy had a positive and null result on PPD. With the exception of fish oil, few RCT's with nutritional interventions during pregnancy assess PPD. Further research is needed to determine whether nutritional intervention strategies during pregnancy can protect against symptoms of PPD. Given the prevalence of PPD and ease of administering PPD measures, we recommend future prenatal nutritional RCT's include PPD as an outcome.\",\n",
              "  'details': 'Formatting removed (can be used to compare model to actual example)',\n",
              "  'source': 'https://pubmed.ncbi.nlm.nih.gov/28012571/'},\n",
              " {'abstract': 'Mental illness, including depression, anxiety and bipolar disorder, accounts for a significant proportion of global disability and poses a substantial social, economic and heath burden. Treatment is presently dominated by pharmacotherapy, such as antidepressants, and psychotherapy, such as cognitive behavioural therapy; however, such treatments avert less than half of the disease burden, suggesting that additional strategies are needed to prevent and treat mental disorders. There are now consistent mechanistic, observational and interventional data to suggest diet quality may be a modifiable risk factor for mental illness. This review provides an overview of the nutritional psychiatry field. It includes a discussion of the neurobiological mechanisms likely modulated by diet, the use of dietary and nutraceutical interventions in mental disorders, and recommendations for further research. Potential biological pathways related to mental disorders include inflammation, oxidative stress, the gut microbiome, epigenetic modifications and neuroplasticity. Consistent epidemiological evidence, particularly for depression, suggests an association between measures of diet quality and mental health, across multiple populations and age groups; these do not appear to be explained by other demographic, lifestyle factors or reverse causality. Our recently published intervention trial provides preliminary clinical evidence that dietary interventions in clinically diagnosed populations are feasible and can provide significant clinical benefit. Furthermore, nutraceuticals including n-3 fatty acids, folate, S-adenosylmethionine, N-acetyl cysteine and probiotics, among others, are promising avenues for future research. Continued research is now required to investigate the efficacy of intervention studies in large cohorts and within clinically relevant populations, particularly in patients with schizophrenia, bipolar and anxiety disorders.',\n",
              "  'details': 'Effect of nutrition on mental health',\n",
              "  'source': 'https://pubmed.ncbi.nlm.nih.gov/28942748/'},\n",
              " {'abstract': \"Hepatitis C virus (HCV) and alcoholic liver disease (ALD), either alone or in combination, count for more than two thirds of all liver diseases in the Western world. There is no safe level of drinking in HCV-infected patients and the most effective goal for these patients is total abstinence. Baclofen, a GABA(B) receptor agonist, represents a promising pharmacotherapy for alcohol dependence (AD). Previously, we performed a randomized clinical trial (RCT), which demonstrated the safety and efficacy of baclofen in patients affected by AD and cirrhosis. The goal of this post-hoc analysis was to explore baclofen's effect in a subgroup of alcohol-dependent HCV-infected cirrhotic patients. Any patient with HCV infection was selected for this analysis. Among the 84 subjects randomized in the main trial, 24 alcohol-dependent cirrhotic patients had a HCV infection; 12 received baclofen 10mg t.i.d. and 12 received placebo for 12-weeks. With respect to the placebo group (3/12, 25.0%), a significantly higher number of patients who achieved and maintained total alcohol abstinence was found in the baclofen group (10/12, 83.3%; p=0.0123). Furthermore, in the baclofen group, compared to placebo, there was a significantly higher increase in albumin values from baseline (p=0.0132) and a trend toward a significant reduction in INR levels from baseline (p=0.0716). In conclusion, baclofen was safe and significantly more effective than placebo in promoting alcohol abstinence, and improving some Liver Function Tests (LFTs) (i.e. albumin, INR) in alcohol-dependent HCV-infected cirrhotic patients. Baclofen may represent a clinically relevant alcohol pharmacotherapy for these patients.\",\n",
              "  'details': 'Baclofen promotes alcohol abstinence in alcohol dependent cirrhotic patients with hepatitis C virus (HCV) infection',\n",
              "  'source': 'https://pubmed.ncbi.nlm.nih.gov/22244707/'}]"
            ]
          },
          "metadata": {},
          "execution_count": 106
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "abstracts = pd.DataFrame(ncbi_abstract)\n",
        "# Using spacy sentencizer\n",
        "from spacy.lang.en import English\n",
        "nlp = English()\n",
        "sentencizer = nlp.create_pipe(\"sentencizer\")\n",
        "nlp.add_pipe(sentencizer)\n",
        "doc = nlp(ncbi_abstract[0][\"abstract\"])\n",
        "abstract_lines = [str(sent) for sent in list(doc.sents)]\n",
        "abstract_lines"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fX_OB4o-UcU9",
        "outputId": "fe0ee7bf-fb17-4359-88b8-b6137c7bc46a"
      },
      "execution_count": 107,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['This RCT examined the efficacy of a manualized social intervention for children with HFASDs.',\n",
              " 'Participants were randomly assigned to treatment or wait-list conditions.',\n",
              " 'Treatment included instruction and therapeutic activities targeting social skills, face-emotion recognition, interest expansion, and interpretation of non-literal language.',\n",
              " 'A response-cost program was applied to reduce problem behaviors and foster skills acquisition.',\n",
              " 'Significant treatment effects were found for five of seven primary outcome measures (parent ratings and direct child measures).',\n",
              " 'Secondary measures based on staff ratings (treatment group only) corroborated gains reported by parents.',\n",
              " 'High levels of parent, child and staff satisfaction were reported, along with high levels of treatment fidelity.',\n",
              " 'Standardized effect size estimates were primarily in the medium and large ranges and favored the treatment group.']"
            ]
          },
          "metadata": {},
          "execution_count": 107
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Preprocessing the ncbi_abstract\n",
        "total_lines_in_ncbi_abstract = len(abstract_lines)\n",
        "sample_lines = []\n",
        "for i, line in enumerate(abstract_lines):\n",
        "  sample_dict = {}\n",
        "  sample_dict[\"text\"] = str(line)\n",
        "  sample_dict[\"line_number\"] = i\n",
        "  sample_dict[\"total_lines\"] = total_lines_in_ncbi_abstract - 1\n",
        "  sample_lines.append(sample_dict)\n",
        "sample_lines  "
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "byX76JhsU9gB",
        "outputId": "5c25f915-d13c-41bd-d0d4-fd7f5f372494"
      },
      "execution_count": 108,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[{'line_number': 0,\n",
              "  'text': 'This RCT examined the efficacy of a manualized social intervention for children with HFASDs.',\n",
              "  'total_lines': 7},\n",
              " {'line_number': 1,\n",
              "  'text': 'Participants were randomly assigned to treatment or wait-list conditions.',\n",
              "  'total_lines': 7},\n",
              " {'line_number': 2,\n",
              "  'text': 'Treatment included instruction and therapeutic activities targeting social skills, face-emotion recognition, interest expansion, and interpretation of non-literal language.',\n",
              "  'total_lines': 7},\n",
              " {'line_number': 3,\n",
              "  'text': 'A response-cost program was applied to reduce problem behaviors and foster skills acquisition.',\n",
              "  'total_lines': 7},\n",
              " {'line_number': 4,\n",
              "  'text': 'Significant treatment effects were found for five of seven primary outcome measures (parent ratings and direct child measures).',\n",
              "  'total_lines': 7},\n",
              " {'line_number': 5,\n",
              "  'text': 'Secondary measures based on staff ratings (treatment group only) corroborated gains reported by parents.',\n",
              "  'total_lines': 7},\n",
              " {'line_number': 6,\n",
              "  'text': 'High levels of parent, child and staff satisfaction were reported, along with high levels of treatment fidelity.',\n",
              "  'total_lines': 7},\n",
              " {'line_number': 7,\n",
              "  'text': 'Standardized effect size estimates were primarily in the medium and large ranges and favored the treatment group.',\n",
              "  'total_lines': 7}]"
            ]
          },
          "metadata": {},
          "execution_count": 108
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Encoding\n",
        "test_abstract_line_numbers = [line[\"line_number\"] for line in sample_lines]\n",
        "test_abstract_line_numbers_one_hot = tf.one_hot(test_abstract_line_numbers, depth=15)\n",
        "test_abstract_total_lines = [line[\"total_lines\"] for line in sample_lines]\n",
        "test_abstract_total_lines_one_hot = tf.one_hot(test_abstract_total_lines, depth=20)\n",
        "# Spliting into characters\n",
        "abstract_characters = [split_character(sentence) for sentence in abstract_lines]\n",
        "\n",
        "# Predictions \n",
        "test_abstract_pred_probs = model_6.predict(x=(test_abstract_line_numbers_one_hot,\n",
        "                                              test_abstract_total_lines_one_hot,\n",
        "                                              tf.constant(abstract_lines),\n",
        "                                              tf.constant(abstract_characters)))\n",
        "test_abstract_preds = tf.argmax(test_abstract_pred_probs, axis=1)\n",
        "test_abstract_pred_classes = [label_encoder.classes_[i] for i in test_abstract_preds]\n",
        "\n",
        "for i, line in enumerate(abstract_lines):\n",
        "  print(f\"{test_abstract_pred_classes[i]}: {line}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8IwaeIwWYDgb",
        "outputId": "db7abe9f-c18b-4d27-c12a-3e377957d176"
      },
      "execution_count": 115,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "OBJECTIVE: This RCT examined the efficacy of a manualized social intervention for children with HFASDs.\n",
            "METHODS: Participants were randomly assigned to treatment or wait-list conditions.\n",
            "METHODS: Treatment included instruction and therapeutic activities targeting social skills, face-emotion recognition, interest expansion, and interpretation of non-literal language.\n",
            "METHODS: A response-cost program was applied to reduce problem behaviors and foster skills acquisition.\n",
            "RESULTS: Significant treatment effects were found for five of seven primary outcome measures (parent ratings and direct child measures).\n",
            "METHODS: Secondary measures based on staff ratings (treatment group only) corroborated gains reported by parents.\n",
            "RESULTS: High levels of parent, child and staff satisfaction were reported, along with high levels of treatment fidelity.\n",
            "RESULTS: Standardized effect size estimates were primarily in the medium and large ranges and favored the treatment group.\n"
          ]
        }
      ]
    },
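    {
      "cell_type": "markdown",
      "source": [
        "The cell above chains several steps: sentencize, one-hot encode the positional features, split into characters, predict, then decode the labels. As a sketch (not part of the original pipeline), the hypothetical helper `predict_abstract` below wraps those steps so the other abstracts in `ncbi_abstract` can be labelled the same way; it assumes `nlp`, `split_character`, `model_6` and `label_encoder` are defined as in the cells above and reuses the same one-hot depths (15 and 20)."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "def predict_abstract(abstract_text):\n",
        "  # Split the raw abstract into sentences with the spaCy sentencizer defined above\n",
        "  lines = [str(sent) for sent in nlp(abstract_text).sents]\n",
        "  line_numbers_one_hot = tf.one_hot(list(range(len(lines))), depth=15)\n",
        "  total_lines_one_hot = tf.one_hot([len(lines) - 1] * len(lines), depth=20)\n",
        "  characters = [split_character(line) for line in lines]\n",
        "  # Predict with the same four inputs used for model_6 above\n",
        "  pred_probs = model_6.predict(x=(line_numbers_one_hot,\n",
        "                                  total_lines_one_hot,\n",
        "                                  tf.constant(lines),\n",
        "                                  tf.constant(characters)))\n",
        "  pred_classes = [label_encoder.classes_[i] for i in tf.argmax(pred_probs, axis=1)]\n",
        "  return list(zip(pred_classes, lines))\n",
        "\n",
        "# Example: label the second abstract in the downloaded JSON\n",
        "for label, line in predict_abstract(ncbi_abstract[1][\"abstract\"]):\n",
        "  print(f\"{label}: {line}\")"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },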
    {
      "cell_type": "markdown",
      "source": [
        "# Model Predictions"
      ],
      "metadata": {
        "id": "-19LUH0waM9C"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Original Abstract**:\n",
        "\n",
        "\"This RCT examined the efficacy of a manualized social intervention for children with HFASDs. Participants were randomly assigned to treatment or wait-list conditions. Treatment included instruction and therapeutic activities targeting social skills, face-emotion recognition, interest expansion, and interpretation of non-literal language. A response-cost program was applied to reduce problem behaviors and foster skills acquisition. Significant treatment effects were found for five of seven primary outcome measures (parent ratings and direct child measures). Secondary measures based on staff ratings (treatment group only) corroborated gains reported by parents. High levels of parent, child and staff satisfaction were reported, along with high levels of treatment fidelity. Standardized effect size estimates were primarily in the medium and large ranges and favored the treatment group.\""
      ],
      "metadata": {
        "id": "wet8KQcAbqqY"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Model's `Predicted` Abstract which makes Abstract easier to read**\n",
        "\n",
        "`Abstract` after `Natural Language Processing` (`model_6`):\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "**OBJECTIVE**: This RCT examined the efficacy of a manualized social intervention for children with HFASDs.\n",
        "\n",
        "**METHODS**: Participants were randomly assigned to treatment or wait-list conditions.\n",
        "\n",
        "**METHODS**: Treatment included instruction and therapeutic activities targeting social skills, face-emotion recognition, interest expansion, and interpretation of non-literal language.\n",
        "\n",
        "**METHODS**: A response-cost program was applied to reduce problem behaviors and foster skills acquisition.\n",
        "\n",
        "**RESULTS**: Significant treatment effects were found for five of seven primary outcome measures (parent ratings and direct child measures).\n",
        "\n",
        "**METHODS**: Secondary measures based on staff ratings (treatment group only) corroborated gains reported by parents.\n",
        "\n",
        "**RESULTS**: High levels of parent, child and staff satisfaction were reported, along with high levels of treatment fidelity.\n",
        "\n",
        "**RESULTS**: Standardized effect size estimates were primarily in the medium and large ranges and favored the treatment group.\n",
        "\n"
      ],
      "metadata": {
        "id": "7OmmHuG0b5xq"
      }
    },
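    {
      "cell_type": "markdown",
      "source": [
        "For an even quicker read, the predicted labels can be grouped so each section heading appears once with its sentences joined together. A minimal sketch assuming `test_abstract_pred_classes` and `abstract_lines` from the prediction cell above; `structured_abstract` is a hypothetical name."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Group the predicted sentences under their predicted section label\n",
        "structured_abstract = {}\n",
        "for label, line in zip(test_abstract_pred_classes, abstract_lines):\n",
        "  structured_abstract.setdefault(label, []).append(line)\n",
        "\n",
        "for label, lines in structured_abstract.items():\n",
        "  print(f\"{label}:\\n{' '.join(lines)}\\n\")"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },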
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "-fyMruj0cwPS"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}