[357738]: / Roberta+LLM / evaluate_roberta_chia.ipynb

Download this file

3217 lines (3217 with data), 125.5 kB

{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/jlopetegui98/NER-ClinicalTrials-Elegibility-Criteria/blob/main/Roberta%2BLLM/evaluate_roberta_chia.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "OmG4urkedeiv",
        "outputId": "87a33dc4-f118-45a7-c8aa-aab80e9c76ca",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
          ]
        }
      ],
      "source": [
        "# Mount Google Drive when running in Colab; the cell is a no-op elsewhere.\n",
        "try:\n",
        "    from google.colab import drive\n",
        "except ImportError:\n",
        "    drive = None\n",
        "\n",
        "if drive is not None:\n",
        "    drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "pG8-9pLtdeiv",
        "outputId": "d3c710c2-a443-4b3b-db20-2fafe6746f0d",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for transformers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m510.5/510.5 kB\u001b[0m \u001b[31m10.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m15.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m17.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m17.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for accelerate (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "Collecting seqeval\n",
            "  Downloading seqeval-1.2.2.tar.gz (43 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.6/43.6 kB\u001b[0m \u001b[31m1.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Requirement already satisfied: numpy>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from seqeval) (1.25.2)\n",
            "Requirement already satisfied: scikit-learn>=0.21.3 in /usr/local/lib/python3.10/dist-packages (from seqeval) (1.2.2)\n",
            "Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.21.3->seqeval) (1.11.4)\n",
            "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.21.3->seqeval) (1.4.0)\n",
            "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.21.3->seqeval) (3.4.0)\n",
            "Building wheels for collected packages: seqeval\n",
            "  Building wheel for seqeval (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for seqeval: filename=seqeval-1.2.2-py3-none-any.whl size=16161 sha256=b571c9c4836027705ada02b905790e6d0b00dff2c033b522f11aeb9c3d0d66ed\n",
            "  Stored in directory: /root/.cache/pip/wheels/1a/67/4a/ad4082dd7dfc30f2abfe4d80a2ed5926a506eb8a972b4767fa\n",
            "Successfully built seqeval\n",
            "Installing collected packages: seqeval\n",
            "Successfully installed seqeval-1.2.2\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.1/84.1 kB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ],
      "source": [
        "# Colab only: install/upgrade dependencies (skip when the local environment already has them).\n",
        "# datasets is kept on 2.x because load_metric (imported in the next cell) was removed in datasets 3.0.\n",
        "# NOTE(review): transformers and accelerate come from their git main branches, so runs are not\n",
        "# reproducible -- pin known-good versions once identified.\n",
        "%pip install -q -U git+https://github.com/huggingface/transformers.git\n",
        "%pip install -q -U datasets~=2.18\n",
        "%pip install -q -U git+https://github.com/huggingface/accelerate.git\n",
        "%pip install -q seqeval\n",
        "%pip install -q -U evaluate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "cD_ebmTNdeiv"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from transformers import AutoTokenizer, AutoModelForTokenClassification, DataCollatorForTokenClassification,  Trainer, TrainingArguments\n",
        "from datasets import load_dataset, load_metric\n",
        "from seqeval.metrics import classification_report\n",
        "from seqeval.scheme import IOB2\n",
        "import evaluate\n",
        "import torch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "RTiSFJvzdeiw",
        "outputId": "ed96941b-ab02-41bb-da96-c668cb8aac43",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 145,
          "referenced_widgets": [
            "aae03eda481e488eaf1ef5b0610cdc0b",
            "034ac69beaf14927bb2800c7e848ecbb",
            "0c7e1ad3261b4ff4876a9a8654fc9464",
            "306bc801a94541999e50632fbc3f4c0f",
            "36dbf510f6fc446383291658d394066a",
            "77affba1207543b483febb7acb086298",
            "d2605acf97fa45ad8b157868a831e4f7",
            "c116ed771b584a1398d9200fd3a8eba0",
            "f3a413d8c726471295e252f3ea122f31",
            "2e4c44c84e464804b1c757f88e46d506",
            "719dfc3140194a839f84e6efe01e01b0",
            "7e0e3e6216db4347a7de8f4553dd6282",
            "7c10261a70ee4e7184524cc2590a28c1",
            "5b93a9aedb2f49d88e5e1d9cdb13c751",
            "5ad14d3ac4614dcc903d60dcb4edbb88",
            "c218b57c46fc4f5eb9260fd08c5a94c8",
            "e377b67ff2a948e3be138334ed3de886",
            "d9a4367195b94edc9d1ba97d7b0a2eb8",
            "aa9ffe65a42643c69f553f283cdf6772",
            "8a95f6bbd81148519ec4639d8b0f4db6",
            "13acbfe9fc7a4ff4a6ec0a8be56c4f62",
            "e7bcd126a8384f94a027e130d09f2346",
            "ee0ac8c2fac44c02b28844bb5bc6822a",
            "d1444855a9cc4726a37cf9f184be8409",
            "baa7f4149cb446078d05574550f33cda",
            "002b9e99a05c46749a052c5770f280a5",
            "33a233aa855f4356960aa336d361bfdb",
            "b111ac45f0d648a8b89e1bf3d26dc354",
            "7c0010ffc2b84b7cb1c2e943e849a67b",
            "4a0ecc50b9e94408b714c578a0786f03",
            "f67c37bb203c4775b5997e50af97fae7",
            "8fe86ac5511c400d81c52bd335c96859"
          ]
        }
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "aae03eda481e488eaf1ef5b0610cdc0b"
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Log in to the Hugging Face Hub (renders an interactive token prompt).\n",
        "from huggingface_hub import notebook_login\n",
        "\n",
        "notebook_login()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "46CEQFSvdeiw"
      },
      "outputs": [],
      "source": [
        "# Entity types annotated in Chia (without B-/I- prefixes)\n",
        "simple_ent = {\"Condition\", \"Value\", \"Drug\", \"Procedure\", \"Measurement\", \"Temporal\", \"Observation\", \"Person\", \"Device\"}\n",
        "# IOB2 tag -> label id. Layout matters: 'O' is 0 and every I-ENT id is its B-ENT id + 1,\n",
        "# which the sub-word label alignment relies on to turn B-ENT into I-ENT.\n",
        "sel_ent = {\n",
        "    \"O\": 0,\n",
        "    \"B-Condition\": 1,\n",
        "    \"I-Condition\": 2,\n",
        "    \"B-Value\": 3,\n",
        "    \"I-Value\": 4,\n",
        "    \"B-Drug\": 5,\n",
        "    \"I-Drug\": 6,\n",
        "    \"B-Procedure\": 7,\n",
        "    \"I-Procedure\": 8,\n",
        "    \"B-Measurement\": 9,\n",
        "    \"I-Measurement\": 10,\n",
        "    \"B-Temporal\": 11,\n",
        "    \"I-Temporal\": 12,\n",
        "    \"B-Observation\": 13,\n",
        "    \"I-Observation\": 14,\n",
        "    \"B-Person\": 15,\n",
        "    \"I-Person\": 16,\n",
        "    \"B-Device\": 17,\n",
        "    \"I-Device\": 18\n",
        "}\n",
        "\n",
        "# label id -> tag; valid because the ids are 0..N-1 in insertion order (checked below)\n",
        "entities_list = list(sel_ent.keys())\n",
        "sel_ent_inv = {v: k for k, v in sel_ent.items()}\n",
        "\n",
        "# Guard the layout assumptions documented above.\n",
        "assert list(sel_ent.values()) == list(range(len(sel_ent))), 'label ids must be 0..N-1 in insertion order'\n",
        "assert entities_list[0] == 'O'\n",
        "assert all(sel_ent['I-' + ent] == sel_ent['B-' + ent] + 1 for ent in simple_ent), 'I-ENT id must be B-ENT id + 1'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "8DChGuXtdeiw"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# Use the project folder on Google Drive when it is mounted (Colab);\n",
        "# otherwise fall back to the repository root relative to this notebook.\n",
        "colab_root = './drive/MyDrive/TER-LISN-2024'\n",
        "root = colab_root if os.path.isdir(colab_root) else '..'\n",
        "data_path = f'{root}/data'\n",
        "models_path = f'{root}/models'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "azvAU6endeiw"
      },
      "outputs": [],
      "source": [
        "# Hugging Face checkpoint name; used to load the matching tokenizer\n",
        "model_name = 'roberta-base'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "9Lj6yRsIdeiw",
        "outputId": "c140456c-c362-4eed-c946-67e626da7f52",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n",
            "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
            "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
            "You will be able to reuse this secret in all of your notebooks.\n",
            "Please note that authentication is recommended but still optional to access public models or datasets.\n",
            "  warnings.warn(\n"
          ]
        }
      ],
      "source": [
        "# RoBERTa's tokenizer must be created with add_prefix_space=True to accept\n",
        "# pre-split words (is_split_into_words=True), which is how the dataset is tokenized.\n",
        "tokenizer = AutoTokenizer.from_pretrained(\n",
        "    model_name,\n",
        "    add_prefix_space=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "jGIUj0Wwdeiw"
      },
      "outputs": [],
      "source": [
        "# tokenize and align the labels in the dataset\n",
        "def tokenize_and_align_labels(sentence, tokenizer, flag = 'I'):\n",
        "    \"\"\"\n",
        "    Tokenize pre-split sentences and align the word-level NER labels to sub-word tokens.\n",
        "    inputs:\n",
        "        sentence: dict, a batch from the dataset (used with batched=True) with 'tokens'\n",
        "            (lists of words) and 'ner_tags' (lists of label ids, one per word)\n",
        "        tokenizer: HF tokenizer supporting is_split_into_words=True and word_ids()\n",
        "        flag: str, how to label the continuation subwords of a word\n",
        "            - 'I': a B-ENT word's continuation subwords get I-ENT; I-ENT and O words keep their label\n",
        "            - 'B': repeat the label of the first subword\n",
        "            - None: use -100 (ignored by the loss)\n",
        "    outputs:\n",
        "        tokenized_sentence: the tokenizer output with extra 'labels' and 'word_ids' fields\n",
        "    raises:\n",
        "        ValueError: if flag is not 'I', 'B' or None\n",
        "    \"\"\"\n",
        "    if flag not in ('I', 'B', None):\n",
        "        # an unknown flag would append nothing, leaving labels shorter than input_ids\n",
        "        raise ValueError(f\"flag must be 'I', 'B' or None, got {flag!r}\")\n",
        "    tokenized_sentence = tokenizer(sentence['tokens'], is_split_into_words=True, truncation=True, padding='max_length', max_length=512)\n",
        "\n",
        "    labels = []\n",
        "    all_word_ids = []\n",
        "    for i, labels_s in enumerate(sentence['ner_tags']):\n",
        "        word_ids = tokenized_sentence.word_ids(batch_index=i)\n",
        "        previous_word_idx = None\n",
        "        label_ids = []\n",
        "        for word_idx in word_ids:\n",
        "            if word_idx is None:\n",
        "                # special and padding tokens\n",
        "                label_ids.append(-100)\n",
        "            elif word_idx != previous_word_idx:\n",
        "                # the first subword of a word carries the word's label\n",
        "                label_ids.append(labels_s[word_idx])\n",
        "            elif flag == 'I':\n",
        "                word_label = labels_s[word_idx]\n",
        "                # Only B-ENT becomes I-ENT (id + 1, per the sel_ent layout). 'O' must stay 'O':\n",
        "                # adding 1 to it would mislabel the subword as B-Condition.\n",
        "                if entities_list[word_label].startswith('B-'):\n",
        "                    label_ids.append(word_label + 1)\n",
        "                else:\n",
        "                    label_ids.append(word_label)\n",
        "            elif flag == 'B':\n",
        "                label_ids.append(labels_s[word_idx])\n",
        "            else:  # flag is None\n",
        "                label_ids.append(-100)\n",
        "            previous_word_idx = word_idx\n",
        "        labels.append(label_ids)\n",
        "        all_word_ids.append(word_ids)\n",
        "    tokenized_sentence['labels'] = labels\n",
        "    tokenized_sentence['word_ids'] = all_word_ids\n",
        "    return tokenized_sentence"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "Vn6kc7ZJdeiw"
      },
      "outputs": [],
      "source": [
        "# Chia eligibility-criteria NER corpus (train / test / val splits) from the Hugging Face Hub\n",
        "dataset = load_dataset(\n",
        "    'JavierLopetegui/chia_v1',\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "ia9F2gixdeiw",
        "outputId": "295daaec-e013-4539-c939-6996765bd8a8",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 49,
          "referenced_widgets": [
            "c134904ded5f47aabca0dfa5b67049db",
            "3c1d79df90544bdd93873579d38116a4",
            "3806b75490704196afde5977db7dca89",
            "7665c9c809964b6eb6b4a9720c5d2dac",
            "1785e46465ec42ccbb933604963d5dd6",
            "7b7dd023dbc7472998c6aac7eee1490f",
            "62a0e81ff19a4162a67e5884ce678f2c",
            "24a85f5bd83441b4943299be2d145aff",
            "18eaa2134e8e4edeb41340cebbeb8572",
            "68e1a741d3ec4fc3a26d759ecc357ac6",
            "de3e98c73f9547bebdad9b4bcb21d9a1"
          ]
        }
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Map:   0%|          | 0/1307 [00:00<?, ? examples/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "c134904ded5f47aabca0dfa5b67049db"
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Tokenize every split and attach sub-word labels and word ids (flag='I': continuation\n",
        "# sub-words of a B-ENT word are labelled I-ENT).\n",
        "dataset = dataset.map(\n",
        "    lambda batch: tokenize_and_align_labels(batch, tokenizer, flag='I'),\n",
        "    batched=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Inspect the splits and the columns added by tokenization\n",
        "dataset"
      ],
      "metadata": {
        "id": "fGpxVHNyfRJw",
        "outputId": "d1268582-6fc8-4531-b5af-9ea4406f4ba9",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "DatasetDict({\n",
              "    train: Dataset({\n",
              "        features: ['tokens', 'ner_tags', 'file', 'index', 'input_ids', 'attention_mask', 'labels', 'word_ids'],\n",
              "        num_rows: 8881\n",
              "    })\n",
              "    test: Dataset({\n",
              "        features: ['tokens', 'ner_tags', 'file', 'index', 'input_ids', 'attention_mask', 'labels', 'word_ids'],\n",
              "        num_rows: 1307\n",
              "    })\n",
              "    val: Dataset({\n",
              "        features: ['tokens', 'ner_tags', 'file', 'index', 'input_ids', 'attention_mask', 'labels', 'word_ids'],\n",
              "        num_rows: 2221\n",
              "    })\n",
              "})"
            ]
          },
          "metadata": {},
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "smLH3_Andeiw"
      },
      "outputs": [],
      "source": [
        "# Load the fully pickled fine-tuned model. torch.load unpickles arbitrary objects, so only\n",
        "# use it on this trusted, self-produced checkpoint.\n",
        "# map_location='cpu': a checkpoint saved from a GPU also loads on CPU-only runtimes\n",
        "# (the model is moved to `device` further down).\n",
        "# weights_only=False: needed to unpickle a whole nn.Module on PyTorch >= 2.6,\n",
        "# where torch.load defaults to weights_only=True.\n",
        "model = torch.load(f'{models_path}/roberta-ner-chia.pt', map_location='cpu', weights_only=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "GiKIdJXedeiw"
      },
      "outputs": [],
      "source": [
        "# Prefer the GPU when one is available\n",
        "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Keep only the inputs the model consumes; drop text, gold labels and bookkeeping columns.\n",
        "model_input_columns = ['input_ids', 'attention_mask']\n",
        "data_for_model = dataset['test'].remove_columns(\n",
        "    [col for col in dataset['test'].column_names if col not in model_input_columns]\n",
        ")"
      ],
      "metadata": {
        "id": "k-wSN1KUfNix"
      },
      "execution_count": 14,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Should list only input_ids and attention_mask\n",
        "data_for_model"
      ],
      "metadata": {
        "id": "OAeZ75r-hPCH",
        "outputId": "0d8c7676-3471-4f74-d3ad-942b26595c62",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Dataset({\n",
              "    features: ['input_ids', 'attention_mask'],\n",
              "    num_rows: 1307\n",
              "})"
            ]
          },
          "metadata": {},
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "b99O2uACdeiw"
      },
      "outputs": [],
      "source": [
        "# No shuffling: prediction i must line up with row i of dataset['test'].\n",
        "data_loader = torch.utils.data.DataLoader(data_for_model, batch_size=8, shuffle=False)"
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "MtnU6v5Je8X9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "enBj0Mn5deiw",
        "outputId": "2c99fd25-35ad-4cde-d2cd-05b558f6eaf4",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "RobertaForTokenClassification(\n",
              "  (roberta): RobertaModel(\n",
              "    (embeddings): RobertaEmbeddings(\n",
              "      (word_embeddings): Embedding(50265, 768, padding_idx=1)\n",
              "      (position_embeddings): Embedding(514, 768, padding_idx=1)\n",
              "      (token_type_embeddings): Embedding(1, 768)\n",
              "      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
              "      (dropout): Dropout(p=0.1, inplace=False)\n",
              "    )\n",
              "    (encoder): RobertaEncoder(\n",
              "      (layer): ModuleList(\n",
              "        (0-11): 12 x RobertaLayer(\n",
              "          (attention): RobertaAttention(\n",
              "            (self): RobertaSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "            (output): RobertaSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): RobertaIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): RobertaOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.1, inplace=False)\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "    )\n",
              "  )\n",
              "  (dropout): Dropout(p=0.1, inplace=False)\n",
              "  (classifier): Linear(in_features=768, out_features=19, bias=True)\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 17
        }
      ],
      "source": [
        "# Move the model to the selected device (the cell output shows the architecture)\n",
        "model.to(device)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from tqdm import tqdm"
      ],
      "metadata": {
        "id": "F6dReMqdeva6"
      },
      "execution_count": 18,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Sanity check: every row is padded to the same length (expected {512}, the tokenizer max_length)\n",
        "{len(mask) for mask in data_for_model['attention_mask']}"
      ],
      "metadata": {
        "id": "vPBZwZGnfzG1",
        "outputId": "60979f60-e8f8-4fe9-8959-9a8ca30027cc",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "512"
            ]
          },
          "metadata": {},
          "execution_count": 19
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "UtZKgTQcgORE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "id": "QzmBnKxjdeiw",
        "outputId": "7d40eef4-7f54-4faf-c2e0-29f48313a3f9",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 164/164 [00:50<00:00,  3.22it/s]\n"
          ]
        }
      ],
      "source": [
        "def _to_batch_tensor(column):\n",
        "    \"\"\"Return a (batch, seq_len) tensor on `device` for one collated input column.\"\"\"\n",
        "    if isinstance(column, torch.Tensor):\n",
        "        # already a batched tensor (e.g. dataset formatted as torch)\n",
        "        return column.to(device)\n",
        "    # The default collate transposes list columns of an unformatted dataset into\n",
        "    # seq_len tensors of shape (batch,); stack them back along dim 1.\n",
        "    return torch.stack(column, dim=1).to(device)\n",
        "\n",
        "\n",
        "# torch.load restores the module's saved train/eval mode; force eval mode so dropout\n",
        "# is disabled and predictions are deterministic.\n",
        "model.eval()\n",
        "\n",
        "labels = []  # one list of predicted label ids per test row, aligned with its word_ids\n",
        "with torch.no_grad():\n",
        "    for batch in tqdm(data_loader):\n",
        "        inputs = {name: _to_batch_tensor(batch[name]) for name in ('input_ids', 'attention_mask')}\n",
        "        outputs = model(**inputs)\n",
        "        labels_batch = torch.argmax(outputs.logits, dim=2).cpu().numpy()\n",
        "        labels.extend(list(row) for row in labels_batch)\n",
        "\n",
        "        del batch, inputs, outputs\n",
        "        torch.cuda.empty_cache()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def annotate_sentences(dataset, labels, entities_list,criteria = 'first_label'):\n",
        "    \"\"\"\n",
        "    Collapse sub-word predictions into one predicted label id per original word.\n",
        "    inputs:\n",
        "        dataset: dataset whose rows have 'tokens' and 'word_ids' (as added by tokenize_and_align_labels)\n",
        "        labels: list, per-row sequences of predicted label ids aligned with the row's 'word_ids'\n",
        "        entities_list: list, label id -> tag string (e.g. 'B-Drug'); id 0 is assumed to be 'O'\n",
        "        criteria: str, criteria to use to select the label when the word pieces have different labels\n",
        "            - first_label: select the label of the first piece\n",
        "            - majority: select the most frequent label; if the first piece was a B- tag,\n",
        "              the winning entity is re-tagged as B-\n",
        "    outputs:\n",
        "        annotated_sentences: list, one list of label ids per row (one id per word)\n",
        "    raises:\n",
        "        ValueError: if criteria is not 'first_label' or 'majority'\n",
        "    \"\"\"\n",
        "    if criteria not in ('first_label', 'majority'):\n",
        "        # an unknown criteria would silently yield empty predictions\n",
        "        raise ValueError(f\"criteria must be 'first_label' or 'majority', got {criteria!r}\")\n",
        "    annotated_sentences = []\n",
        "    for i in range(len(dataset)):\n",
        "        sentence = dataset[i]\n",
        "        sentence_labels = labels[i]\n",
        "        # group the predictions of all pieces belonging to each word (None = special/padding token)\n",
        "        word_pieces = [[] for _ in range(len(sentence['tokens']))]\n",
        "        for word_id, label in zip(sentence['word_ids'], sentence_labels):\n",
        "            if word_id is not None:\n",
        "                word_pieces[word_id].append(label)\n",
        "        annotated_sentence_filtered = []\n",
        "        for pieces in word_pieces:\n",
        "            if not pieces:\n",
        "                # the word has no piece in the model input (e.g. cut off by truncation or\n",
        "                # tokenized to nothing): predict 'O' (id 0) instead of raising IndexError\n",
        "                annotated_sentence_filtered.append(0)\n",
        "            elif criteria == 'first_label':\n",
        "                annotated_sentence_filtered.append(pieces[0])\n",
        "            else:\n",
        "                starts_flag = entities_list[pieces[0]].startswith('B')\n",
        "                ent = max(set(pieces), key=pieces.count)\n",
        "                if starts_flag and ent != 0:\n",
        "                    annotated_sentence_filtered.append(entities_list.index('B-' + entities_list[ent][2:]))\n",
        "                else:\n",
        "                    annotated_sentence_filtered.append(ent)\n",
        "        annotated_sentences.append(annotated_sentence_filtered)\n",
        "    return annotated_sentences"
      ],
      "metadata": {
        "id": "w6MABN-NsXVR"
      },
      "execution_count": 21,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# One label per word, aggregating sub-word predictions with each criterion\n",
        "test_split = dataset['test']\n",
        "annotated_sentences_first = annotate_sentences(test_split, labels, entities_list, criteria='first_label')\n",
        "annotated_sentences_max = annotate_sentences(test_split, labels, entities_list, criteria='majority')"
      ],
      "metadata": {
        "id": "UZVjibUbsasA"
      },
      "execution_count": 22,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h06Yl1uTdeix",
        "outputId": "f23119cd-e70b-44e3-b069-b6d2781a06f7",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 176,
          "referenced_widgets": [
            "99188e0dd50f4d5a874eaec3101ceab6",
            "b289d32f0876481b97a080552549f7d1",
            "499e568745f743aea5c0b2ac64ef160f",
            "12eafe87e2b84edfa7d72b93361d3e6a",
            "61a45966b4654a7db975880b0a193139",
            "95909a3c6f6245d1acd056f28bfe5ba8",
            "feebbbdfbc3740bba039e95ddf31e4c1",
            "1eed478ce7694c4cb7e2c998963cafcc",
            "677a609737064c3c86de6503530b6268",
            "3736022fd7b5476dbfe03fee20d19f09",
            "0846105c8d484656848ff41c571eb71f"
          ]
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "<ipython-input-38-653dc96d1cff>:2: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n",
            "  metric = load_metric(\"seqeval\")\n",
            "/usr/local/lib/python3.10/dist-packages/datasets/load.py:756: FutureWarning: The repository for seqeval contains custom code which must be executed to correctly load the metric. You can inspect the repository content at https://raw.githubusercontent.com/huggingface/datasets/2.18.0/metrics/seqeval/seqeval.py\n",
            "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
            "Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading builder script:   0%|          | 0.00/2.47k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "99188e0dd50f4d5a874eaec3101ceab6"
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Load the seqeval metric for evaluation. datasets.load_metric is deprecated (see the\n",
        "# FutureWarning) and removed in datasets 3.0, so use the evaluate library instead.\n",
        "metric = evaluate.load(\"seqeval\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0jHgq7Dldeix"
      },
      "outputs": [],
      "source": [
        "def compute_metrics_tr(p):\n",
        "    \"\"\"\n",
        "    Score label-id predictions with seqeval in default and strict IOB2 mode.\n",
        "    inputs:\n",
        "        p: tuple (predictions, labels) of per-sentence label-id sequences;\n",
        "           positions whose reference label is -100 are ignored\n",
        "    outputs:\n",
        "        tuple: (default metrics, strict metrics,\n",
        "                default classification report, strict classification report)\n",
        "    \"\"\"\n",
        "    predictions, references = p\n",
        "\n",
        "    # convert ids to tag strings, dropping ignored (-100) positions from both sides\n",
        "    pred_tags, ref_tags = [], []\n",
        "    for pred_row, ref_row in zip(predictions, references):\n",
        "        kept = [(pr, rf) for pr, rf in zip(pred_row, ref_row) if rf != -100]\n",
        "        pred_tags.append([entities_list[pr] for pr, _ in kept])\n",
        "        ref_tags.append([entities_list[rf] for _, rf in kept])\n",
        "\n",
        "    default_scores = metric.compute(predictions=pred_tags, references=ref_tags)\n",
        "    strict_scores = metric.compute(predictions=pred_tags, references=ref_tags, mode='strict', scheme='IOB2')\n",
        "\n",
        "    default_report = classification_report(ref_tags, pred_tags)\n",
        "    strict_report = classification_report(ref_tags, pred_tags, mode='strict', scheme=IOB2)\n",
        "\n",
        "    return default_scores, strict_scores, default_report, strict_report"
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "wYo2jYRfp-qK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def get_labels(p):\n",
        "    \"\"\"\n",
        "    Convert aligned label-id sequences into seqeval-style tag strings.\n",
        "    inputs:\n",
        "        p: tuple (predictions, labels) of per-sentence label-id sequences;\n",
        "           positions whose reference label is -100 are dropped from both\n",
        "    outputs:\n",
        "        (predictions, labels): lists of lists of tag strings, aligned position by position\n",
        "    \"\"\"\n",
        "    predictions, labels = p\n",
        "    pred_tags, true_tags = [], []\n",
        "    # Filter both sides from the original id sequences in a single pass, so the kept\n",
        "    # prediction and reference positions always stay aligned even when -100 is present.\n",
        "    for prediction, label in zip(predictions, labels):\n",
        "        kept = [(pred_id, label_id) for pred_id, label_id in zip(prediction, label) if label_id != -100]\n",
        "        pred_tags.append([entities_list[pred_id] for pred_id, _ in kept])\n",
        "        true_tags.append([entities_list[label_id] for _, label_id in kept])\n",
        "    return pred_tags, true_tags"
      ],
      "metadata": {
        "id": "kW2qWh4XppbT"
      },
      "execution_count": 23,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "pred_labels, true_labels = get_labels((annotated_sentences_first, dataset['test']['ner_tags']))\n",
        "\n",
        "# BioEval compares tags position by position, so every sentence needs one\n",
        "# predicted tag per gold tag.\n",
        "assert len(pred_labels) == len(true_labels)\n",
        "assert all(len(pred) == len(gold) for pred, gold in zip(pred_labels, true_labels)), (\n",
        "    \"predicted and gold tag sequences are misaligned\")"
      ],
      "metadata": {
        "id": "T8EIf4d83TXK"
      },
      "execution_count": 24,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Peek at the predicted tags for the first test sentence\n",
        "pred_labels[0]"
      ],
      "metadata": {
        "id": "MvgnYYJh608X",
        "outputId": "8e666e13-0733-4af0-cf6e-516da4407bb2",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "execution_count": 40,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['O',\n",
              " 'O',\n",
              " 'O',\n",
              " 'B-Condition',\n",
              " 'B-Person',\n",
              " 'B-Value',\n",
              " 'O',\n",
              " 'I-Value',\n",
              " 'I-Value',\n",
              " 'I-Value',\n",
              " 'I-Value',\n",
              " 'I-Value',\n",
              " 'O',\n",
              " 'O',\n",
              " 'O',\n",
              " 'I-Observation',\n",
              " 'O',\n",
              " 'O']"
            ]
          },
          "metadata": {},
          "execution_count": 40
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# from eval_file import *\n",
        "\n",
        "import argparse\n",
        "from collections import defaultdict\n",
        "from itertools import chain\n",
        "from math import pow\n",
        "from pathlib import Path\n",
        "\n",
        "# from common_utils.common_io import load_bio_file_into_sents\n",
        "# from common_utils.common_log import create_logger\n",
        "# -*- coding: utf-8 -*-\n",
        "\n",
        "# -*- coding: utf-8 -*-\n",
        "\n",
        "import json\n",
        "import pickle as pkl\n",
        "\n",
        "\n",
        "def read_from_file(ifn, encoding=\"utf-8\"):\n",
        "    \"\"\"Return the whole text content of file `ifn`.\n",
        "\n",
        "    The encoding is explicit (UTF-8 by default) so the result does not depend\n",
        "    on the platform's locale default.\n",
        "    \"\"\"\n",
        "    with open(ifn, \"r\", encoding=encoding) as f:\n",
        "        return f.read()\n",
        "\n",
        "\n",
        "def write_to_file(text, ofn, encoding=\"utf-8\"):\n",
        "    \"\"\"Write `text` to file `ofn`, replacing any existing content; returns True.\n",
        "\n",
        "    The encoding is explicit (UTF-8 by default) so the bytes written do not\n",
        "    depend on the platform's locale default.\n",
        "    \"\"\"\n",
        "    with open(ofn, \"w\", encoding=encoding) as f:\n",
        "        f.write(text)\n",
        "    return True\n",
        "\n",
        "\n",
        "def pkl_load(ifn):\n",
        "    \"\"\"Return the object unpickled from file `ifn`.\n",
        "\n",
        "    Unpickling can execute arbitrary code: only load trusted files.\n",
        "    \"\"\"\n",
        "    handle = open(ifn, \"rb\")\n",
        "    try:\n",
        "        return pkl.load(handle)\n",
        "    finally:\n",
        "        handle.close()\n",
        "\n",
        "\n",
        "def pkl_dump(pdata, ofn):\n",
        "    \"\"\"Pickle `pdata` into file `ofn` (replacing it) and return True.\"\"\"\n",
        "    handle = open(ofn, \"wb\")\n",
        "    try:\n",
        "        pkl.dump(pdata, handle)\n",
        "    finally:\n",
        "        handle.close()\n",
        "    return True\n",
        "\n",
        "\n",
        "def json_load(ifn, encoding=\"utf-8\"):\n",
        "    \"\"\"Parse file `ifn` as JSON and return the resulting object.\n",
        "\n",
        "    Read as UTF-8 by default (the standard JSON encoding) rather than the\n",
        "    platform's locale default.\n",
        "    \"\"\"\n",
        "    with open(ifn, \"r\", encoding=encoding) as f:\n",
        "        return json.load(f)\n",
        "\n",
        "\n",
        "def json_dump(jdata, ofn):\n",
        "    \"\"\"Serialize `jdata` as JSON into file `ofn` (replacing it) and return True.\"\"\"\n",
        "    out = open(ofn, \"w\")\n",
        "    try:\n",
        "        json.dump(jdata, out)\n",
        "    finally:\n",
        "        out.close()\n",
        "    return True\n",
        "\n",
        "\n",
        "def load_bio_file_into_sents(bio_file, word_sep=\" \", do_lower=False):\n",
        "    \"\"\"\n",
        "    Read a BIO file into sentences of token rows.\n",
        "\n",
        "    Each non-blank line is one token, split on `word_sep` into its columns;\n",
        "    one or more blank (or whitespace-only) lines separate sentences.\n",
        "    With do_lower=True the whole text (tokens and tags) is lowercased.\n",
        "\n",
        "    Returns a list of sentences, each a list of column lists.\n",
        "    \"\"\"\n",
        "    bio_text = read_from_file(bio_file).strip()\n",
        "    if do_lower:\n",
        "        bio_text = bio_text.lower()\n",
        "\n",
        "    new_sents = []\n",
        "    new_sent = []\n",
        "    # splitlines() also drops \"\\r\" from Windows line endings; treating any\n",
        "    # blank line as a separator avoids empty pseudo-sentences when sentences\n",
        "    # are separated by more than one blank line.\n",
        "    for line in bio_text.splitlines():\n",
        "        if line.strip():\n",
        "            new_sent.append(line.split(word_sep))\n",
        "        elif new_sent:\n",
        "            new_sents.append(new_sent)\n",
        "            new_sent = []\n",
        "    if new_sent:\n",
        "        new_sents.append(new_sent)\n",
        "\n",
        "    return new_sents\n",
        "\n",
        "\n",
        "def output_bio(bio_data, output_file, sep=\" \"):\n",
        "    \"\"\"Write sentences of token rows as a BIO file.\n",
        "\n",
        "    Each token's columns are joined with `sep` on one line; every sentence,\n",
        "    including the last, is followed by a blank line.\n",
        "    \"\"\"\n",
        "    with open(output_file, \"w\") as out:\n",
        "        for sentence in bio_data:\n",
        "            out.writelines(sep.join(columns) + \"\\n\" for columns in sentence)\n",
        "            out.write(\"\\n\")\n",
        "\n",
        "\n",
        "class PRF:\n",
        "    \"\"\"Counter of true and false cases (e.g. true and false positives).\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        self.true = 0\n",
        "        self.false = 0\n",
        "\n",
        "    def add_true_case(self):\n",
        "        self.true += 1\n",
        "\n",
        "    def add_false_case(self):\n",
        "        self.false += 1\n",
        "\n",
        "    def get_true_false_counts(self):\n",
        "        return self.true, self.false\n",
        "\n",
        "    def __str__(self):\n",
        "        return str(self.__dict__)\n",
        "\n",
        "\n",
        "class BioEval:\n",
        "    \"\"\"\n",
        "    Entity-level evaluation of BIO tag sequences with strict and relaxed matching.\n",
        "\n",
        "    Tags are compared against the lowercase prefixes \"b-\"/\"i-\" and the excluded\n",
        "    tag \"o\", so they must be lowercase: eval_file and eval_mem lowercase them,\n",
        "    evaluate_annotations does so only with do_lower=True.\n",
        "    Counts accumulate across calls; use reset() to start over.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        # token-level accuracy\n",
        "        self.acc = PRF()\n",
        "        # prediction outcomes (true/false cases), overall and per category\n",
        "        self.all_strict = PRF()\n",
        "        self.all_relax = PRF()\n",
        "        self.cat_strict = defaultdict(PRF)\n",
        "        self.cat_relax = defaultdict(PRF)\n",
        "        # gold standard entity counts, overall and per category\n",
        "        self.gs_all = 0\n",
        "        self.gs_cat = defaultdict(int)\n",
        "        self.performance = dict()\n",
        "        self.counts = dict()\n",
        "        self.beta = 1\n",
        "        self.label_not_for_eval = {'o'}\n",
        "\n",
        "    def reset(self):\n",
        "        \"\"\"Clear accumulated counts and results; beta and excluded labels are kept.\"\"\"\n",
        "        self.acc = PRF()\n",
        "        self.all_strict = PRF()\n",
        "        self.all_relax = PRF()\n",
        "        self.cat_strict = defaultdict(PRF)\n",
        "        self.cat_relax = defaultdict(PRF)\n",
        "        self.gs_all = 0\n",
        "        self.gs_cat = defaultdict(int)\n",
        "        self.performance = dict()\n",
        "        self.counts = dict()\n",
        "\n",
        "    def set_beta_for_f_score(self, beta):\n",
        "        \"\"\"Set the beta used for the F-score (1 = F1).\"\"\"\n",
        "        print(\"Using beta={} for calculating F-score\".format(beta))\n",
        "        self.beta = beta\n",
        "\n",
        "    def add_labels_not_for_eval(self, *labels):\n",
        "        \"\"\"Exclude whole tags (stored lowercased) from entity extraction, like 'o'.\"\"\"\n",
        "        for each in labels:\n",
        "            self.label_not_for_eval.add(each.lower())\n",
        "\n",
        "    def __calc_prf(self, tp, fp, gold_total):\n",
        "        \"\"\"\n",
        "        Return (precision, recall, F-beta) from true positives, false positives\n",
        "        and the number of gold entities (tp + fn).\n",
        "\n",
        "        beta=1 gives F1, beta=2 favours recall, beta=0.5 favours precision;\n",
        "        change it with set_beta_for_f_score().\n",
        "        \"\"\"\n",
        "        tp_fp = tp + fp\n",
        "        pre = 1.0 * tp / tp_fp if tp_fp > 0 else 0.0\n",
        "        rec = 1.0 * tp / gold_total if gold_total > 0 else 0.0\n",
        "        beta2 = pow(self.beta, 2)\n",
        "        denom = beta2 * pre + rec\n",
        "        # guard the actual denominator: with beta=0 and rec=0 it is zero even if pre > 0\n",
        "        f_beta = (1 + beta2) * pre * rec / denom if denom > 0 else 0.0\n",
        "        return pre, rec, f_beta\n",
        "\n",
        "    def __prf_dict(self, tp, fp, gold_total):\n",
        "        # __calc_prf output in the layout stored in self.performance\n",
        "        pre, rec, f_beta = self.__calc_prf(tp, fp, gold_total)\n",
        "        return {'precision': pre, 'recall': rec, 'f_score': f_beta}\n",
        "\n",
        "    def __measure_performance(self):\n",
        "        self.performance['overall'] = dict()\n",
        "\n",
        "        acc_true_num, acc_false_num = self.acc.get_true_false_counts()\n",
        "        total_acc_num = acc_true_num + acc_false_num\n",
        "        overall_acc = round(1.0 * acc_true_num / total_acc_num, 4) if total_acc_num > 0 else 0.0\n",
        "        self.performance['overall']['acc'] = overall_acc\n",
        "\n",
        "        for mode, counter in (('strict', self.all_strict), ('relax', self.all_relax)):\n",
        "            true_counts, false_counts = counter.get_true_false_counts()\n",
        "            self.performance['overall'][mode] = self.__prf_dict(true_counts, false_counts, self.gs_all)\n",
        "\n",
        "        # include gold-only categories (never predicted) so their recall of 0 is reported\n",
        "        categories = list(self.cat_strict) + [k for k in self.gs_cat if k not in self.cat_strict]\n",
        "        self.performance['category'] = dict()\n",
        "        for mode, cat_counters in (('strict', self.cat_strict), ('relax', self.cat_relax)):\n",
        "            self.performance['category'][mode] = dict()\n",
        "            for k in categories:\n",
        "                counter = cat_counters.get(k)\n",
        "                true_counts, false_counts = counter.get_true_false_counts() if counter is not None else (0, 0)\n",
        "                self.performance['category'][mode][k] = self.__prf_dict(true_counts, false_counts, self.gs_cat[k])\n",
        "\n",
        "    @staticmethod\n",
        "    def __count_dict(counter):\n",
        "        # total/true/false layout stored in self.counts\n",
        "        true_counts, false_counts = counter.get_true_false_counts()\n",
        "        return {'total': true_counts + false_counts, 'true': true_counts, 'false': false_counts}\n",
        "\n",
        "    def __measure_counts(self):\n",
        "        # gold standard\n",
        "        self.counts['expect'] = dict()\n",
        "        self.counts['expect']['overall'] = self.gs_all\n",
        "        for k, v in self.gs_cat.items():\n",
        "            self.counts['expect'][k] = v\n",
        "        # prediction\n",
        "        self.counts['prediction'] = {'strict': dict(), 'relax': dict()}\n",
        "        for mode, counter, cat_counters in (('strict', self.all_strict, self.cat_strict),\n",
        "                                            ('relax', self.all_relax, self.cat_relax)):\n",
        "            self.counts['prediction'][mode]['overall'] = self.__count_dict(counter)\n",
        "            for k, v in cat_counters.items():\n",
        "                self.counts['prediction'][mode][k] = self.__count_dict(v)\n",
        "\n",
        "    @staticmethod\n",
        "    def __strict_match(gs, pred, s_idx, e_idx, en_type):\n",
        "        # exact match: gold and prediction both start with b-<type> at s_idx, agree on\n",
        "        # every tag of the span, and the gold entity does not continue past e_idx\n",
        "        if e_idx < len(gs) and gs[e_idx] == f\"i-{en_type}\":\n",
        "            # check token after end in GS is not continued entity token\n",
        "            return False\n",
        "        elif gs[s_idx] != f\"b-{en_type}\" or pred[s_idx] != f\"b-{en_type}\":\n",
        "            # force first token to be B-\n",
        "            return False\n",
        "        # check every token in span is the same\n",
        "        for idx in range(s_idx, e_idx):\n",
        "            if gs[idx] != pred[idx]:\n",
        "                return False\n",
        "        return True\n",
        "\n",
        "    @staticmethod\n",
        "    def __relax_match(gs, pred, s_idx, e_idx, en_type):\n",
        "        # partial match: true if any token of the span has category en_type in both\n",
        "        # gold and prediction (looser than left/right-boundary or approximate matching).\n",
        "        # Split on the first \"-\" only so categories containing \"-\" stay whole.\n",
        "        for idx in range(s_idx, e_idx):\n",
        "            gs_cate = gs[idx].split(\"-\", 1)[-1]\n",
        "            pred_cate = pred[idx].split(\"-\", 1)[-1]\n",
        "            if gs_cate == pred_cate == en_type:\n",
        "                return True\n",
        "        return False\n",
        "\n",
        "    @staticmethod\n",
        "    def __check_evaluated_already(gs_dict, cate, start_idx, end_idx):\n",
        "        \"\"\"\n",
        "        Claim an unclaimed gold entity of `cate` overlapping [start_idx, end_idx).\n",
        "\n",
        "        gs_dict maps (category, start, end) -> remaining claims. Returns True when\n",
        "        overlapping gold entities exist but all are already claimed (the prediction\n",
        "        is then skipped), False when one was claimed or none overlaps.\n",
        "        \"\"\"\n",
        "        overlapped = False\n",
        "        for k, v in gs_dict.items():\n",
        "            c, s, e = k\n",
        "            # spans are half-open, so adjacent spans (e == start_idx) do not overlap\n",
        "            if c == cate and s < end_idx and start_idx < e:\n",
        "                overlapped = True\n",
        "                if v > 0:\n",
        "                    gs_dict[k] -= 1\n",
        "                    return False\n",
        "        return overlapped\n",
        "\n",
        "    def __iter_spans(self, bio):\n",
        "        \"\"\"\n",
        "        Yield (category, start, end) for each entity span of `bio`, end exclusive.\n",
        "\n",
        "        A span starts at any tag not in label_not_for_eval (a dangling \"i-\" tag\n",
        "        also starts one) and extends over the following \"i-<category>\" tags.\n",
        "        Raises ValueError for a tag without a \"-\" separator (e.g. uppercase \"O\").\n",
        "        \"\"\"\n",
        "        llen = len(bio)\n",
        "        cur_idx = 0\n",
        "        while cur_idx < llen:\n",
        "            if bio[cur_idx] in self.label_not_for_eval:\n",
        "                cur_idx += 1\n",
        "                continue\n",
        "            start_idx = cur_idx\n",
        "            end_idx = start_idx + 1\n",
        "            _, sep, cate = bio[start_idx].partition('-')\n",
        "            if not sep:\n",
        "                raise ValueError(\"unexpected tag {!r}: expected an excluded tag such as 'o' \"\n",
        "                                 \"or '<prefix>-<category>' (tags must be lowercase)\".format(bio[start_idx]))\n",
        "            while end_idx < llen and bio[end_idx] == f\"i-{cate}\":\n",
        "                end_idx += 1\n",
        "            yield cate, start_idx, end_idx\n",
        "            cur_idx = end_idx\n",
        "\n",
        "    @staticmethod\n",
        "    def __record(overall_counter, category_counter, is_true):\n",
        "        # add one true or false case to the overall and the per-category counter\n",
        "        for counter in (overall_counter, category_counter):\n",
        "            if is_true:\n",
        "                counter.add_true_case()\n",
        "            else:\n",
        "                counter.add_false_case()\n",
        "\n",
        "    def __process_bio(self, gs_bio, pred_bio):\n",
        "        \"\"\"\n",
        "        Accumulate token accuracy, gold entity counts and strict/relaxed\n",
        "        prediction outcomes for one gold/predicted tag sequence pair.\n",
        "\n",
        "        Raises ValueError if the two sequences have different lengths.\n",
        "        \"\"\"\n",
        "        if len(gs_bio) != len(pred_bio):\n",
        "            raise ValueError(\"gold standard and prediction have different lengths: gs: {}; pred: {}\".format(\n",
        "                len(gs_bio), len(pred_bio)))\n",
        "        # strip once so accuracy, span extraction and matching all see the same tags\n",
        "        gs_bio = [tag.strip() for tag in gs_bio]\n",
        "        pred_bio = [tag.strip() for tag in pred_bio]\n",
        "\n",
        "        # token-level accuracy\n",
        "        for gs_word, pred_word in zip(gs_bio, pred_bio):\n",
        "            if gs_word == pred_word:\n",
        "                self.acc.add_true_case()\n",
        "            else:\n",
        "                self.acc.add_false_case()\n",
        "\n",
        "        # gold standard entities\n",
        "        gs_dict = defaultdict(int)\n",
        "        for cate, start_idx, end_idx in self.__iter_spans(gs_bio):\n",
        "            self.gs_all += 1\n",
        "            self.gs_cat[cate] += 1\n",
        "            gs_dict[(cate, start_idx, end_idx)] += 1\n",
        "\n",
        "        # predicted entities\n",
        "        for cate, start_idx, end_idx in self.__iter_spans(pred_bio):\n",
        "            if self.__strict_match(gs_bio, pred_bio, start_idx, end_idx, cate):\n",
        "                strict_true, relax_true = True, True\n",
        "            elif self.__relax_match(gs_bio, pred_bio, start_idx, end_idx, cate):\n",
        "                if self.__check_evaluated_already(gs_dict, cate, start_idx, end_idx):\n",
        "                    # every overlapping gold entity was already credited: skip\n",
        "                    continue\n",
        "                strict_true, relax_true = False, True\n",
        "            else:\n",
        "                strict_true, relax_true = False, False\n",
        "            self.__record(self.all_strict, self.cat_strict[cate], strict_true)\n",
        "            self.__record(self.all_relax, self.cat_relax[cate], relax_true)\n",
        "\n",
        "    def eval_file(self, gs_file, pred_file):\n",
        "        \"\"\"Evaluate two BIO files (one token per line, tag in the last column, blank line between sentences).\"\"\"\n",
        "        print(\"processing gold standard file: {} and prediciton file: {}\".format(gs_file, pred_file))\n",
        "        pred_bio_sents = load_bio_file_into_sents(pred_file, do_lower=True)\n",
        "        gs_bio_sents = load_bio_file_into_sents(gs_file, do_lower=True)\n",
        "        # check two data have same amount of sents\n",
        "        assert len(gs_bio_sents) == len(pred_bio_sents), (\n",
        "            \"gold standard and prediction have different dimension: gs: {}; pred: {}\".format(len(gs_bio_sents), len(pred_bio_sents)))\n",
        "        for s_idx, (gs_sent, pred_sent) in enumerate(zip(gs_bio_sents, pred_bio_sents)):\n",
        "            # check two sents have same No. of words\n",
        "            assert len(gs_sent) == len(pred_sent), (\n",
        "                \"In {}th sentence, the words counts are different; gs: {}; pred: {}\".format(s_idx, gs_sent, pred_sent))\n",
        "            # keep only the tag (last column) of each token row\n",
        "            self.__process_bio([word[-1] for word in gs_sent], [word[-1] for word in pred_sent])\n",
        "        # get the evaluation matrix\n",
        "        self.__measure_performance()\n",
        "        self.__measure_counts()\n",
        "\n",
        "    def eval_mem(self, gs, pred, do_flat=False):\n",
        "        \"\"\"\n",
        "        Evaluate in-memory tag sequences (lists of per-sentence tag lists); tags are lowercased.\n",
        "\n",
        "        With do_flat=True all sentences are concatenated and processed as one sequence.\n",
        "        \"\"\"\n",
        "        if do_flat:\n",
        "            print('Sentences have been flatten to 1 dim.')\n",
        "            gs = [tag.lower() for tag in chain(*gs)]\n",
        "            pred = [tag.lower() for tag in chain(*pred)]\n",
        "            self.__process_bio(gs, pred)\n",
        "        else:\n",
        "            for gs_s, pred_s in zip(gs, pred):\n",
        "                self.__process_bio([tag.lower() for tag in gs_s], [tag.lower() for tag in pred_s])\n",
        "\n",
        "        self.__measure_performance()\n",
        "        self.__measure_counts()\n",
        "\n",
        "    def evaluate_annotations(self, gs, pred, do_lower=False):\n",
        "        \"\"\"\n",
        "        Evaluate in-memory gold and predicted tag sequences (lists of per-sentence tag lists).\n",
        "\n",
        "        Pass do_lower=True unless the tags are already lowercase (see class docstring).\n",
        "        \"\"\"\n",
        "        for gs_sent, pred_sent in zip(gs, pred):\n",
        "            if do_lower:\n",
        "                gs_sent = [tag.lower() for tag in gs_sent]\n",
        "                pred_sent = [tag.lower() for tag in pred_sent]\n",
        "            self.__process_bio(gs_sent, pred_sent)\n",
        "\n",
        "        self.__measure_performance()\n",
        "        self.__measure_counts()\n",
        "\n",
        "    def get_performance(self):\n",
        "        return self.performance\n",
        "\n",
        "    def get_counts(self):\n",
        "        return self.counts\n",
        "\n",
        "    def save_evaluation(self, file):\n",
        "        \"\"\"Write self.performance as JSON to `file`.\"\"\"\n",
        "        with open(file, \"w\") as f:\n",
        "            json.dump(self.performance, f)\n",
        "\n",
        "    def show_evaluation(self, digits=4):\n",
        "        \"\"\"Return a text report of per-category and overall strict/relaxed precision, recall and F-score.\"\"\"\n",
        "        if len(self.performance) == 0:\n",
        "            raise RuntimeError('call eval_mem() first to get the performance attribute')\n",
        "\n",
        "        cate = self.performance['category']['strict'].keys()\n",
        "\n",
        "        headers = ['precision', 'recall', 'f1']\n",
        "        # default=0 keeps the report working when no entity category was seen\n",
        "        width = max(max((len(c) for c in cate), default=0), len('overall'), digits)\n",
        "        head_fmt = '{:>{width}s} ' + ' {:>9}' * len(headers)\n",
        "        row_fmt = '{:>{width}s} ' + ' {:>9.{digits}f}' * 3 + '\\n'\n",
        "\n",
        "        def fmt_row(name, scores):\n",
        "            # one right-aligned row: name, precision, recall, f-score\n",
        "            return row_fmt.format(name, scores['precision'], scores['recall'], scores['f_score'],\n",
        "                                  width=width, digits=digits)\n",
        "\n",
        "        report = head_fmt.format(u'', *headers, width=width)\n",
        "        report += '\\n\\nstrict\\n'\n",
        "        for c in cate:\n",
        "            report += fmt_row(c, self.performance['category']['strict'][c])\n",
        "\n",
        "        report += '\\nrelax\\n'\n",
        "        for c in cate:\n",
        "            report += fmt_row(c, self.performance['category']['relax'][c])\n",
        "\n",
        "        report += '\\n\\noverall\\n'\n",
        "        report += 'acc: ' + str(self.performance['overall']['acc'])\n",
        "        report += '\\nstrict\\n'\n",
        "        report += fmt_row('', self.performance['overall']['strict'])\n",
        "\n",
        "        report += '\\nrelax\\n'\n",
        "        report += fmt_row('', self.performance['overall']['relax'])\n",
        "        return report\n"
      ],
      "metadata": {
        "id": "c0uiL0XA3dnz"
      },
      "execution_count": 32,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# NOTE(review): scratch cell - `s` is not used by any cell visible here; TODO remove\n",
        "s = \"i-\""
      ],
      "metadata": {
        "id": "WFBRwTMN8v2d"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Fresh evaluator: BioEval accumulates counts across evaluate_* calls\n",
        "evaluator = BioEval()"
      ],
      "metadata": {
        "id": "6l6fW5Bd6MMK"
      },
      "execution_count": 33,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# do_lower=True: BioEval matches lowercase 'b-'/'i-'/'o' tags, while the tags here look like 'B-Condition'\n",
        "evaluator.evaluate_annotations(true_labels, pred_labels, do_lower=True)"
      ],
      "metadata": {
        "id": "obqUWCw-6T90"
      },
      "execution_count": 34,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# same dict as evaluator.performance, read through the public accessor\n",
        "evaluator.get_performance()"
      ],
      "metadata": {
        "id": "XAdRmNx39MYa",
        "outputId": "f2b1b26b-9ae1-4ce1-f513-8002bbc4d809",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "execution_count": 38,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'overall': {'acc': 0.8351,\n",
              "  'strict': {'precision': 0.6225968648328897,\n",
              "   'recall': 0.6740313800832533,\n",
              "   'f_score': 0.6472939729397292},\n",
              "  'relax': {'precision': 0.7580597456373854,\n",
              "   'recall': 0.8206852385526737,\n",
              "   'f_score': 0.7881303813038131}},\n",
              " 'category': {'strict': {'condition': {'precision': 0.6648394675019577,\n",
              "    'recall': 0.7683257918552037,\n",
              "    'f_score': 0.7128463476070528},\n",
              "   'person': {'precision': 0.7133757961783439,\n",
              "    'recall': 0.8296296296296296,\n",
              "    'f_score': 0.7671232876712328},\n",
              "   'value': {'precision': 0.7067039106145251,\n",
              "    'recall': 0.7207977207977208,\n",
              "    'f_score': 0.7136812411847672},\n",
              "   'drug': {'precision': 0.7180043383947939,\n",
              "    'recall': 0.7471783295711061,\n",
              "    'f_score': 0.7323008849557522},\n",
              "   'temporal': {'precision': 0.49279538904899134,\n",
              "    'recall': 0.5757575757575758,\n",
              "    'f_score': 0.5310559006211181},\n",
              "   'measurement': {'precision': 0.5473372781065089,\n",
              "    'recall': 0.6379310344827587,\n",
              "    'f_score': 0.5891719745222931},\n",
              "   'procedure': {'precision': 0.5241157556270096,\n",
              "    'recall': 0.5207667731629393,\n",
              "    'f_score': 0.5224358974358974},\n",
              "   'observation': {'precision': 0.31683168316831684,\n",
              "    'recall': 0.1927710843373494,\n",
              "    'f_score': 0.2397003745318352},\n",
              "   'device': {'precision': 0.2903225806451613,\n",
              "    'recall': 0.391304347826087,\n",
              "    'f_score': 0.33333333333333337}},\n",
              "  'relax': {'condition': {'precision': 0.7956147220046985,\n",
              "    'recall': 0.9194570135746606,\n",
              "    'f_score': 0.8530646515533165},\n",
              "   'person': {'precision': 0.7261146496815286,\n",
              "    'recall': 0.8444444444444444,\n",
              "    'f_score': 0.7808219178082192},\n",
              "   'value': {'precision': 0.8547486033519553,\n",
              "    'recall': 0.8717948717948718,\n",
              "    'f_score': 0.8631875881523272},\n",
              "   'drug': {'precision': 0.841648590021692,\n",
              "    'recall': 0.8758465011286681,\n",
              "    'f_score': 0.8584070796460177},\n",
              "   'temporal': {'precision': 0.6368876080691642,\n",
              "    'recall': 0.7441077441077442,\n",
              "    'f_score': 0.686335403726708},\n",
              "   'measurement': {'precision': 0.7189349112426036,\n",
              "    'recall': 0.8379310344827586,\n",
              "    'f_score': 0.7738853503184714},\n",
              "   'procedure': {'precision': 0.6752411575562701,\n",
              "    'recall': 0.670926517571885,\n",
              "    'f_score': 0.6730769230769232},\n",
              "   'observation': {'precision': 0.5148514851485149,\n",
              "    'recall': 0.3132530120481928,\n",
              "    'f_score': 0.3895131086142322},\n",
              "   'device': {'precision': 0.41935483870967744,\n",
              "    'recall': 0.5652173913043478,\n",
              "    'f_score': 0.4814814814814815}}}}"
            ]
          },
          "metadata": {},
          "execution_count": 38
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Write evaluator.performance as JSON to eval.json in the current working directory\n",
        "evaluator.save_evaluation('eval.json')"
      ],
      "metadata": {
        "id": "kbc_sKVo901C"
      },
      "execution_count": 40,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EY9TwGjjdeix"
      },
      "outputs": [],
      "source": [
        "# Cross-check with the metric/classification_report scores; returns (results, results_strict, cr1, cr2)\n",
        "results, results_strict,cr1,cr2 = compute_metrics_tr((annotated_sentences_first, dataset['test']['ner_tags']))"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Classification report in default mode; print() renders the text table\n",
        "print(cr1)"
      ],
      "metadata": {
        "id": "t8cITFjd2947",
        "outputId": "363f869d-a417-49f3-85a2-d14fc4296b59",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "              precision    recall  f1-score   support\n",
            "\n",
            "   Condition       0.64      0.77      0.70      1105\n",
            "      Device       0.24      0.30      0.27        23\n",
            "        Drug       0.68      0.73      0.70       443\n",
            " Measurement       0.53      0.62      0.57       290\n",
            " Observation       0.30      0.18      0.23       166\n",
            "      Person       0.76      0.84      0.80       135\n",
            "   Procedure       0.46      0.49      0.48       313\n",
            "    Temporal       0.48      0.58      0.52       297\n",
            "       Value       0.65      0.70      0.68       351\n",
            "\n",
            "   micro avg       0.60      0.67      0.63      3123\n",
            "   macro avg       0.53      0.58      0.55      3123\n",
            "weighted avg       0.59      0.67      0.62      3123\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Classification report in strict IOB2 mode\n",
        "print(cr2)"
      ],
      "metadata": {
        "id": "FOYL6Awc3Cp5",
        "outputId": "07c1fd7c-c56f-4adc-f3d5-5c4bf7f0ed7d",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "              precision    recall  f1-score   support\n",
            "\n",
            "   Condition       0.69      0.76      0.72      1104\n",
            "      Device       0.29      0.30      0.30        23\n",
            "        Drug       0.73      0.73      0.73       443\n",
            " Measurement       0.59      0.61      0.60       288\n",
            " Observation       0.40      0.17      0.24       166\n",
            "      Person       0.76      0.84      0.80       135\n",
            "   Procedure       0.53      0.49      0.51       311\n",
            "    Temporal       0.58      0.57      0.58       295\n",
            "       Value       0.70      0.72      0.71       345\n",
            "\n",
            "   micro avg       0.66      0.66      0.66      3110\n",
            "   macro avg       0.59      0.58      0.58      3110\n",
            "weighted avg       0.65      0.66      0.65      3110\n",
            "\n"
          ]
        }
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.13"
    },
    "colab": {
      "provenance": [],
      "gpuType": "T4",
      "include_colab_link": true
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "99188e0dd50f4d5a874eaec3101ceab6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_b289d32f0876481b97a080552549f7d1",
              "IPY_MODEL_499e568745f743aea5c0b2ac64ef160f",
              "IPY_MODEL_12eafe87e2b84edfa7d72b93361d3e6a"
            ],
            "layout": "IPY_MODEL_61a45966b4654a7db975880b0a193139"
          }
        },
        "b289d32f0876481b97a080552549f7d1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_95909a3c6f6245d1acd056f28bfe5ba8",
            "placeholder": "​",
            "style": "IPY_MODEL_feebbbdfbc3740bba039e95ddf31e4c1",
            "value": "Downloading builder script: "
          }
        },
        "499e568745f743aea5c0b2ac64ef160f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1eed478ce7694c4cb7e2c998963cafcc",
            "max": 2471,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_677a609737064c3c86de6503530b6268",
            "value": 2471
          }
        },
        "12eafe87e2b84edfa7d72b93361d3e6a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_3736022fd7b5476dbfe03fee20d19f09",
            "placeholder": "​",
            "style": "IPY_MODEL_0846105c8d484656848ff41c571eb71f",
            "value": " 6.33k/? [00:00&lt;00:00, 501kB/s]"
          }
        },
        "61a45966b4654a7db975880b0a193139": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "95909a3c6f6245d1acd056f28bfe5ba8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "feebbbdfbc3740bba039e95ddf31e4c1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "1eed478ce7694c4cb7e2c998963cafcc": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "677a609737064c3c86de6503530b6268": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "3736022fd7b5476dbfe03fee20d19f09": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0846105c8d484656848ff41c571eb71f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "aae03eda481e488eaf1ef5b0610cdc0b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "VBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "VBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "VBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_13acbfe9fc7a4ff4a6ec0a8be56c4f62",
              "IPY_MODEL_e7bcd126a8384f94a027e130d09f2346",
              "IPY_MODEL_ee0ac8c2fac44c02b28844bb5bc6822a",
              "IPY_MODEL_d1444855a9cc4726a37cf9f184be8409"
            ],
            "layout": "IPY_MODEL_d2605acf97fa45ad8b157868a831e4f7"
          }
        },
        "034ac69beaf14927bb2800c7e848ecbb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c116ed771b584a1398d9200fd3a8eba0",
            "placeholder": "​",
            "style": "IPY_MODEL_f3a413d8c726471295e252f3ea122f31",
            "value": "<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.svg\nalt='Hugging Face'> <br> Copy a token from <a\nhref=\"https://huggingface.co/settings/tokens\" target=\"_blank\">your Hugging Face\ntokens page</a> and paste it below. <br> Immediately click login after copying\nyour token or it might be stored in plain text in this notebook file. </center>"
          }
        },
        "0c7e1ad3261b4ff4876a9a8654fc9464": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "PasswordModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "PasswordModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "PasswordView",
            "continuous_update": true,
            "description": "Token:",
            "description_tooltip": null,
            "disabled": false,
            "layout": "IPY_MODEL_2e4c44c84e464804b1c757f88e46d506",
            "placeholder": "​",
            "style": "IPY_MODEL_719dfc3140194a839f84e6efe01e01b0",
            "value": ""
          }
        },
        "306bc801a94541999e50632fbc3f4c0f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "CheckboxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "CheckboxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "CheckboxView",
            "description": "Add token as git credential?",
            "description_tooltip": null,
            "disabled": false,
            "indent": true,
            "layout": "IPY_MODEL_7e0e3e6216db4347a7de8f4553dd6282",
            "style": "IPY_MODEL_7c10261a70ee4e7184524cc2590a28c1",
            "value": true
          }
        },
        "36dbf510f6fc446383291658d394066a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ButtonModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ButtonModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ButtonView",
            "button_style": "",
            "description": "Login",
            "disabled": false,
            "icon": "",
            "layout": "IPY_MODEL_5b93a9aedb2f49d88e5e1d9cdb13c751",
            "style": "IPY_MODEL_5ad14d3ac4614dcc903d60dcb4edbb88",
            "tooltip": ""
          }
        },
        "77affba1207543b483febb7acb086298": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c218b57c46fc4f5eb9260fd08c5a94c8",
            "placeholder": "​",
            "style": "IPY_MODEL_e377b67ff2a948e3be138334ed3de886",
            "value": "\n<b>Pro Tip:</b> If you don't already have one, you can create a dedicated\n'notebooks' token with 'write' access, that you can then easily reuse for all\nnotebooks. </center>"
          }
        },
        "d2605acf97fa45ad8b157868a831e4f7": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": "center",
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "flex",
            "flex": null,
            "flex_flow": "column",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "50%"
          }
        },
        "c116ed771b584a1398d9200fd3a8eba0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f3a413d8c726471295e252f3ea122f31": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "2e4c44c84e464804b1c757f88e46d506": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "719dfc3140194a839f84e6efe01e01b0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7e0e3e6216db4347a7de8f4553dd6282": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7c10261a70ee4e7184524cc2590a28c1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "5b93a9aedb2f49d88e5e1d9cdb13c751": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5ad14d3ac4614dcc903d60dcb4edbb88": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ButtonStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ButtonStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "button_color": null,
            "font_weight": ""
          }
        },
        "c218b57c46fc4f5eb9260fd08c5a94c8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e377b67ff2a948e3be138334ed3de886": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "d9a4367195b94edc9d1ba97d7b0a2eb8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "LabelModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "LabelModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "LabelView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_aa9ffe65a42643c69f553f283cdf6772",
            "placeholder": "​",
            "style": "IPY_MODEL_8a95f6bbd81148519ec4639d8b0f4db6",
            "value": "Connecting..."
          }
        },
        "aa9ffe65a42643c69f553f283cdf6772": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8a95f6bbd81148519ec4639d8b0f4db6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "13acbfe9fc7a4ff4a6ec0a8be56c4f62": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "LabelModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "LabelModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "LabelView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_baa7f4149cb446078d05574550f33cda",
            "placeholder": "​",
            "style": "IPY_MODEL_002b9e99a05c46749a052c5770f280a5",
            "value": "Token is valid (permission: write)."
          }
        },
        "e7bcd126a8384f94a027e130d09f2346": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "LabelModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "LabelModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "LabelView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_33a233aa855f4356960aa336d361bfdb",
            "placeholder": "​",
            "style": "IPY_MODEL_b111ac45f0d648a8b89e1bf3d26dc354",
            "value": "Your token has been saved in your configured git credential helpers (store)."
          }
        },
        "ee0ac8c2fac44c02b28844bb5bc6822a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "LabelModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "LabelModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "LabelView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7c0010ffc2b84b7cb1c2e943e849a67b",
            "placeholder": "​",
            "style": "IPY_MODEL_4a0ecc50b9e94408b714c578a0786f03",
            "value": "Your token has been saved to /root/.cache/huggingface/token"
          }
        },
        "d1444855a9cc4726a37cf9f184be8409": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "LabelModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "LabelModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "LabelView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f67c37bb203c4775b5997e50af97fae7",
            "placeholder": "​",
            "style": "IPY_MODEL_8fe86ac5511c400d81c52bd335c96859",
            "value": "Login successful"
          }
        },
        "baa7f4149cb446078d05574550f33cda": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "002b9e99a05c46749a052c5770f280a5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "33a233aa855f4356960aa336d361bfdb": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b111ac45f0d648a8b89e1bf3d26dc354": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7c0010ffc2b84b7cb1c2e943e849a67b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4a0ecc50b9e94408b714c578a0786f03": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "f67c37bb203c4775b5997e50af97fae7": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8fe86ac5511c400d81c52bd335c96859": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c134904ded5f47aabca0dfa5b67049db": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_3c1d79df90544bdd93873579d38116a4",
              "IPY_MODEL_3806b75490704196afde5977db7dca89",
              "IPY_MODEL_7665c9c809964b6eb6b4a9720c5d2dac"
            ],
            "layout": "IPY_MODEL_1785e46465ec42ccbb933604963d5dd6"
          }
        },
        "3c1d79df90544bdd93873579d38116a4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7b7dd023dbc7472998c6aac7eee1490f",
            "placeholder": "​",
            "style": "IPY_MODEL_62a0e81ff19a4162a67e5884ce678f2c",
            "value": "Map: 100%"
          }
        },
        "3806b75490704196afde5977db7dca89": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_24a85f5bd83441b4943299be2d145aff",
            "max": 1307,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_18eaa2134e8e4edeb41340cebbeb8572",
            "value": 1307
          }
        },
        "7665c9c809964b6eb6b4a9720c5d2dac": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_68e1a741d3ec4fc3a26d759ecc357ac6",
            "placeholder": "​",
            "style": "IPY_MODEL_de3e98c73f9547bebdad9b4bcb21d9a1",
            "value": " 1307/1307 [00:01&lt;00:00, 860.18 examples/s]"
          }
        },
        "1785e46465ec42ccbb933604963d5dd6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "7b7dd023dbc7472998c6aac7eee1490f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "62a0e81ff19a4162a67e5884ce678f2c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "24a85f5bd83441b4943299be2d145aff": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "18eaa2134e8e4edeb41340cebbeb8572": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "68e1a741d3ec4fc3a26d759ecc357ac6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "de3e98c73f9547bebdad9b4bcb21d9a1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}