[c0f169]: / notebooks / 12_Model_9__FineTune_Pretrained_DistilBert-Base-Uncased.ipynb

Download this file

995 lines (994 with data), 39.7 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dd00dccc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from tensorflow.keras.preprocessing.sequence import pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "49a7c1a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TFAutoModelForTokenClassification, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1ad519f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from transformers import AutoTokenizer, TFAutoModelForTokenClassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "759b4437",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append('../')\n",
    "from config import entity_to_acronyms, acronyms_to_entities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "77d21eb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_dir = '../models'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d27bee4",
   "metadata": {},
   "source": [
    "## Model Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "65344d21",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "All PyTorch model weights were used when initializing TFDistilBertForTokenClassification.\n",
      "\n",
      "All the weights of TFDistilBertForTokenClassification were initialized from the PyTorch model.\n",
      "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForTokenClassification for predictions without further training.\n"
     ]
    }
   ],
   "source": [
    "model_name = \"d4data/biomedical-ner-all\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = TFAutoModelForTokenClassification.from_pretrained(model_name, from_pt=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0b1c9ed",
   "metadata": {},
   "source": [
    "## Define training parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0c4fccc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 32\n",
    "NUM_EPOCHS = 2\n",
    "LEARNING_RATE = 1e-5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8536a1ce",
   "metadata": {},
   "source": [
    "## Prepare the dataset for fine-tuning the pretrained DistilBERT (base, uncased) model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "54ba3f65",
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_LENGTH = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5d7f2098",
   "metadata": {},
   "outputs": [],
   "source": [
    "bio_files_dir = '../data/bio_data_files'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b4d20de3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "def read_file(file_path):\n",
    "    \"\"\"Read one tab-separated BIO file into sentence and label strings.\n",
    "\n",
    "    Every non-blank line must hold a word, a tab and a tag; a blank line\n",
    "    closes the current sentence. Tags other than 'O' keep their first two\n",
    "    characters and have the remainder looked up in ``acronyms_to_entities``.\n",
    "\n",
    "    Args:\n",
    "        file_path: Path of the file to read (decoded as UTF-8).\n",
    "\n",
    "    Returns:\n",
    "        A ``(sentences, labels)`` tuple of two equally long lists. Each item\n",
    "        is the space-joined words (or tags) of one sentence.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a non-blank line does not split into exactly two\n",
    "            tab-separated fields.\n",
    "    \"\"\"\n",
    "    sentences = []\n",
    "    labels = []\n",
    "    words = []\n",
    "    tags = []\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        # Iterate lazily instead of materialising the whole file first.\n",
    "        for line in f:\n",
    "            fields = line.strip()\n",
    "            if not fields:\n",
    "                # Whitespace-only lines count as separators too; otherwise\n",
    "                # they would reach the two-field unpacking below and raise.\n",
    "                if words:\n",
    "                    sentences.append(\" \".join(words))\n",
    "                    labels.append(\" \".join(tags))\n",
    "                    words = []\n",
    "                    tags = []\n",
    "                continue\n",
    "            word, tag = fields.split(\"\\t\")\n",
    "            if tag != 'O':\n",
    "                # Keep the two-character prefix, expand the acronym after it.\n",
    "                tag = tag[:2] + acronyms_to_entities[tag[2:]]\n",
    "            words.append(word)\n",
    "            tags.append(tag)\n",
    "    # Flush the last sentence when the file does not end with a blank line.\n",
    "    if words:\n",
    "        sentences.append(\" \".join(words))\n",
    "        labels.append(\" \".join(tags))\n",
    "    return sentences, labels\n",
    "\n",
    "def _encode_split(sentences, labels):\n",
    "    \"\"\"Tokenize one split and align its word-level tags with the tokens.\n",
    "\n",
    "    Relies on the notebook-level ``tokenizer``, ``model`` and ``MAX_LENGTH``.\n",
    "\n",
    "    Args:\n",
    "        sentences: Space-joined sentences as returned by ``read_file``.\n",
    "        labels: Space-joined tag strings, one per sentence.\n",
    "\n",
    "    Returns:\n",
    "        Dict with the keys \"input_ids\", \"attention_mask\" and \"labels\".\n",
    "        \"labels\" is an int32 tensor with one row per sentence and\n",
    "        MAX_LENGTH columns.\n",
    "    \"\"\"\n",
    "    if not sentences:\n",
    "        # Nothing to tokenize: return consistently shaped empty tensors.\n",
    "        empty = tf.zeros((0, MAX_LENGTH), dtype=tf.int32)\n",
    "        return {\"input_ids\": empty, \"attention_mask\": empty, \"labels\": empty}\n",
    "\n",
    "    label2id = model.config.label2id\n",
    "    # The Keras loss compiled later in this notebook is given no ignore index,\n",
    "    # so special and padding tokens are labelled with the real 'O' class.\n",
    "    outside_id = label2id['O']\n",
    "\n",
    "    # Tokenize once (not once per key) and pass pre-split words so that\n",
    "    # word_ids() can map every sub-word token back to its source word.\n",
    "    # NOTE(review): word_ids() needs a fast tokenizer - confirm AutoTokenizer returns one here.\n",
    "    encoding = tokenizer(\n",
    "        [sentence.split(\" \") for sentence in sentences],\n",
    "        is_split_into_words=True,\n",
    "        truncation=True,\n",
    "        max_length=MAX_LENGTH,\n",
    "        padding='max_length',\n",
    "        return_tensors=\"tf\",\n",
    "    )\n",
    "\n",
    "    aligned = np.full((len(sentences), MAX_LENGTH), outside_id, dtype=\"int32\")\n",
    "    for row, label in enumerate(labels):\n",
    "        tag_ids = [label2id[tag] for tag in label.split()]\n",
    "        for col, word_index in enumerate(encoding.word_ids(batch_index=row)):\n",
    "            # word_index is None for special/padding tokens; every piece of a\n",
    "            # word receives that word's tag.\n",
    "            if word_index is not None and len(tag_ids) > word_index:\n",
    "                aligned[row, col] = tag_ids[word_index]\n",
    "\n",
    "    return {\n",
    "        \"input_ids\": encoding[\"input_ids\"],\n",
    "        \"attention_mask\": encoding[\"attention_mask\"],\n",
    "        \"labels\": tf.convert_to_tensor(aligned),\n",
    "    }\n",
    "\n",
    "\n",
    "def prepare_data(directory_path):\n",
    "    \"\"\"Read all files in a directory and build train/validation/test inputs.\n",
    "\n",
    "    Files are visited in sorted name order and assigned by position:\n",
    "    index % 5 == 0 goes to validation, index % 5 == 1 to test and every\n",
    "    other file to training.\n",
    "\n",
    "    Args:\n",
    "        directory_path: Directory whose files are parsed with ``read_file``.\n",
    "\n",
    "    Returns:\n",
    "        ``(train_data, val_data, test_data)``, each a dict with the keys\n",
    "        \"input_ids\", \"attention_mask\" and \"labels\".\n",
    "    \"\"\"\n",
    "    train_sentences, train_labels = [], []\n",
    "    val_sentences, val_labels = [], []\n",
    "    test_sentences, test_labels = [], []\n",
    "    # os.listdir() returns names in arbitrary order; sorting keeps the split\n",
    "    # identical from run to run.\n",
    "    for i, filename in enumerate(sorted(os.listdir(directory_path))):\n",
    "        file_path = os.path.join(directory_path, filename)\n",
    "        sentences, labels = read_file(file_path)\n",
    "        if i % 5 == 0:  # 20% of the files for validation\n",
    "            val_sentences.extend(sentences)\n",
    "            val_labels.extend(labels)\n",
    "        elif i % 5 == 1:  # 20% of the files for testing\n",
    "            test_sentences.extend(sentences)\n",
    "            test_labels.extend(labels)\n",
    "        else:  # 60% of the files for training\n",
    "            train_sentences.extend(sentences)\n",
    "            train_labels.extend(labels)\n",
    "\n",
    "    train_data = _encode_split(train_sentences, train_labels)\n",
    "    val_data = _encode_split(val_sentences, val_labels)\n",
    "    test_data = _encode_split(test_sentences, test_labels)\n",
    "    return train_data, val_data, test_data\n",
    "\n",
    "train_data, val_data, test_data =  prepare_data(bio_files_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e7223979",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TRAINING DATA\n",
      "The shape of input ids tensor of train data is (2696, 200)\n",
      "The shape of attention masks tensor of train data is (2696, 200)\n",
      "The shape of labels tensor of train data is (2696, 200)\n",
      "\n",
      "VALIDATION DATA\n",
      "The shape of input ids tensor of validation data is (907, 200)\n",
      "The shape of attention masks tensor of validation data is (907, 200)\n",
      "The shape of labels tensor of validation data is (907, 200)\n",
      "\n",
      "TEST DATA\n",
      "The shape of input ids tensor of test data is (938, 200)\n",
      "The shape of attention masks tensor of test data is (938, 200)\n",
      "The shape of labels tensor of test data is (938, 200)\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Report the tensor shapes of every split.\n",
    "for heading, split_name, split in (\n",
    "    (\"TRAINING DATA\", \"train\", train_data),\n",
    "    (\"\\nVALIDATION DATA\", \"validation\", val_data),\n",
    "    (\"\\nTEST DATA\", \"test\", test_data),\n",
    "):\n",
    "    print(heading)\n",
    "    print(f\"The shape of input ids tensor of {split_name} data is {split['input_ids'].shape}\")\n",
    "    print(f\"The shape of attention masks tensor of {split_name} data is {split['attention_mask'].shape}\")\n",
    "    print(f\"The shape of labels tensor of {split_name} data is {split['labels'].shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8da8e675",
   "metadata": {},
   "source": [
    "## Creating Tensorflow Datasets from preprocessed data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9afac0ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _to_batched_dataset(split):\n",
    "    \"\"\"Slice one split dict into (input_ids, attention_mask, labels) batches.\"\"\"\n",
    "    tensors = (split[\"input_ids\"], split[\"attention_mask\"], split[\"labels\"])\n",
    "    return tf.data.Dataset.from_tensor_slices(tensors).batch(BATCH_SIZE)\n",
    "\n",
    "# Build one batched TensorFlow dataset per split\n",
    "train_dataset = _to_batched_dataset(train_data)\n",
    "val_dataset = _to_batched_dataset(val_data)\n",
    "test_dataset = _to_batched_dataset(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9d9b141b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TRAINING DATASET\n",
      "Number of batches in train dataset: 85\n",
      "Shape of the batches: (TensorSpec(shape=(None, 200), dtype=tf.int32, name=None), TensorSpec(shape=(None, 200), dtype=tf.int32, name=None), TensorSpec(shape=(None, 200), dtype=tf.int32, name=None))\n",
      "\n",
      "VALIDATION DATASET\n",
      "Number of batches in validation dataset: 29\n",
      "Shape of the batches: (TensorSpec(shape=(None, 200), dtype=tf.int32, name=None), TensorSpec(shape=(None, 200), dtype=tf.int32, name=None), TensorSpec(shape=(None, 200), dtype=tf.int32, name=None))\n",
      "\n",
      "TEST DATASET\n",
      "Number of batches in test dataset: 30\n",
      "Shape of the batches: (TensorSpec(shape=(None, 200), dtype=tf.int32, name=None), TensorSpec(shape=(None, 200), dtype=tf.int32, name=None), TensorSpec(shape=(None, 200), dtype=tf.int32, name=None))\n"
     ]
    }
   ],
   "source": [
    "# Report the batch count and element spec of every dataset.\n",
    "for heading, split_name, dataset in (\n",
    "    (\"TRAINING DATASET\", \"train\", train_dataset),\n",
    "    (\"\\nVALIDATION DATASET\", \"validation\", val_dataset),\n",
    "    (\"\\nTEST DATASET\", \"test\", test_dataset),\n",
    "):\n",
    "    print(heading)\n",
    "    print(f\"Number of batches in {split_name} dataset: {len(dataset)}\")\n",
    "    print(f\"Shape of the batches: {dataset.element_spec}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "a3d031cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the optimizer, loss function and metrics for training\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)\n",
    "loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
    "metrics = [tf.keras.metrics.SparseCategoricalAccuracy(\"accuracy\")]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d399bd7b",
   "metadata": {},
   "source": [
    "## Compile the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "3e4d2350",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(optimizer=optimizer, loss=loss, metrics=metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca040c6b",
   "metadata": {},
   "source": [
    "## Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c30d6938",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-04-09 11:10:51.904811: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "85/85 [==============================] - 487s 6s/step - loss: 0.9095 - accuracy: 0.9121 - val_loss: 0.5841 - val_accuracy: 0.9479\n",
      "Epoch 2/2\n",
      "85/85 [==============================] - 530s 6s/step - loss: 0.4405 - accuracy: 0.9464 - val_loss: 0.3286 - val_accuracy: 0.9486\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x2c408c8e0>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fine-tune on the training tensors and validate after every epoch.\n",
    "# batch_size is passed explicitly so that BATCH_SIZE from the configuration\n",
    "# cell controls training instead of the implicit Keras default.\n",
    "history = model.fit(\n",
    "    x=[train_data['input_ids'], train_data['attention_mask']],\n",
    "    y=train_data['labels'],\n",
    "    validation_data=([val_data['input_ids'], val_data['attention_mask']], val_data['labels']),\n",
    "    batch_size=BATCH_SIZE,\n",
    "    epochs=NUM_EPOCHS,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2727a82",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Plot training and validation accuracy per epoch on one labelled figure.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(history.history['accuracy'], label='train')\n",
    "ax.plot(history.history['val_accuracy'], label='validation')\n",
    "ax.set(title='Accuracy per epoch', xlabel='Epoch', ylabel='Accuracy')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "c71f7ac0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "85/85 [==============================] - 131s 2s/step - loss: 0.3379 - accuracy: 0.9468\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.3379139304161072, 0.9468490481376648]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate the model on the train data\n",
    "model.evaluate(\n",
    "    x = [train_data['input_ids'], train_data['attention_mask']],\n",
    "    y = train_data['labels']\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "fc2fd1dc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "30/30 [==============================] - 45s 1s/step - loss: 0.3349 - accuracy: 0.9477\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.3348983824253082, 0.9477131962776184]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate the model on the test data\n",
    "model.evaluate(\n",
    "    x = [test_data['input_ids'], test_data['attention_mask']],\n",
    "    y = test_data['labels']\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ce210bb",
   "metadata": {},
   "source": [
    "## Save the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "cf1b7c7c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.regularization.dropout.Dropout object at 0x2d2480070>, because it is not built.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.regularization.dropout.Dropout object at 0x2d2480070>, because it is not built.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.regularization.dropout.Dropout object at 0x2d2463670>, because it is not built.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.regularization.dropout.Dropout object at 0x2d2463670>, because it is not built.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.regularization.dropout.Dropout object at 0x31714a610>, because it is not built.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.regularization.dropout.Dropout object at 0x31714a610>, because it is not built.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.regularization.dropout.Dropout object at 0x3162851f0>, because it is not built.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.regularization.dropout.Dropout object at 0x3162851f0>, because it is not built.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.regularization.dropout.Dropout object at 0x322fe5250>, because it is not built.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.regularization.dropout.Dropout object at 0x322fe5250>, because it is not built.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.regularization.dropout.Dropout object at 0x31d0f2ca0>, because it is not built.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.regularization.dropout.Dropout object at 0x31d0f2ca0>, because it is not built.\n",
      "WARNING:absl:Found untraced functions such as embeddings_layer_call_fn, embeddings_layer_call_and_return_conditional_losses, transformer_layer_call_fn, transformer_layer_call_and_return_conditional_losses, LayerNorm_layer_call_fn while saving (showing 5 of 164). These functions will not be directly callable after loading.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../models/model_9/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../models/model_9/assets\n"
     ]
    }
   ],
   "source": [
    "model.save(os.path.join(model_dir, 'model_9'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "875dcc59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model\n",
    "loaded_model = tf.keras.models.load_model(os.path.join(model_dir, 'model_9'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "97c36209",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the default serving signature (the callable used for inference) from the loaded model\n",
    "loaded_model_1 = loaded_model.signatures['serving_default']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdce15f8",
   "metadata": {},
   "source": [
    "## Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "29e44363",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"tf_distil_bert_for_token_classification_1\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " distilbert (Custom>TFDistil  multiple                 66362880  \n",
      " BertMainLayer)                                                  \n",
      "                                                                 \n",
      " dropout_58 (Dropout)        multiple                  0         \n",
      "                                                                 \n",
      " classifier (Dense)          multiple                  64596     \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 66,427,476\n",
      "Trainable params: 66,427,476\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "loaded_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0d5d8e91",
   "metadata": {},
   "outputs": [],
   "source": [
    "import spacy\n",
    "from spacy import displacy\n",
    "\n",
    "def display_pred(text, entities):\n",
    "    \"\"\"Render predicted entities over ``text`` with displaCy.\n",
    "\n",
    "    Args:\n",
    "        text: The raw text that the entity offsets refer to.\n",
    "        entities: Iterable of ``(start, end, label)`` character spans.\n",
    "\n",
    "    Returns:\n",
    "        The value returned by ``displacy.render``. Presumably this is None\n",
    "        inside Jupyter (where the markup is displayed directly) and the HTML\n",
    "        string elsewhere - confirm against the spaCy documentation.\n",
    "    \"\"\"\n",
    "    # The spaCy pipeline is only used for tokenization here, so its own NER\n",
    "    # component is disabled.\n",
    "    nlp = spacy.load(\"en_core_web_sm\", disable=['ner'])\n",
    "    doc = nlp(text)\n",
    "\n",
    "    # Convert the character offsets into spaCy spans.\n",
    "    spans = []\n",
    "    for start, end, label in entities:\n",
    "        span = doc.char_span(start, end, label=label)\n",
    "        # Skip offsets that spaCy cannot map to tokens (char_span gives None)\n",
    "        # as well as empty spans.\n",
    "        if span is not None and len(span):\n",
    "            spans.append(span)\n",
    "    # Assign all entities at once. filter_spans drops overlapping spans,\n",
    "    # which doc.ents would otherwise reject with a ValueError.\n",
    "    doc.ents = spacy.util.filter_spans(spans)\n",
    "\n",
    "    colors = {\"Activity\": \"#f9d5e5\",\n",
    "              \"Administration\": \"#f7a399\",\n",
    "              \"Age\": \"#f6c3d0\",\n",
    "              \"Area\": \"#fde2e4\",\n",
    "              \"Biological_attribute\": \"#d5f5e3\",\n",
    "              \"Biological_structure\": \"#9ddfd3\",\n",
    "              \"Clinical_event\": \"#77c5d5\",\n",
    "              \"Color\": \"#a0ced9\",\n",
    "              \"Coreference\": \"#e3b5a4\",\n",
    "              \"Date\": \"#f1f0d2\",\n",
    "              \"Detailed_description\": \"#ffb347\",\n",
    "              \"Diagnostic_procedure\": \"#c5b4e3\",\n",
    "              \"Disease_disorder\": \"#c4b7ea\",\n",
    "              \"Distance\": \"#bde0fe\",\n",
    "              \"Dosage\": \"#b9e8d8\",\n",
    "              \"Duration\": \"#ffdfba\",\n",
    "              \"Family_history\": \"#e6ccb2\",\n",
    "              \"Frequency\": \"#e9d8a6\",\n",
    "              \"Height\": \"#f2eecb\",\n",
    "              \"History\": \"#e2f0cb\",\n",
    "              \"Lab_value\": \"#f4b3c2\",\n",
    "              \"Mass\": \"#f4c4c3\",\n",
    "              \"Medication\": \"#f9d5e5\",\n",
    "              \"Nonbiological_location\": \"#f7a399\",\n",
    "              \"Occupation\": \"#f6c3d0\",\n",
    "              \"Other_entity\": \"#d5f5e3\",\n",
    "              \"Other_event\": \"#9ddfd3\",\n",
    "              \"Outcome\": \"#77c5d5\",\n",
    "              \"Personal_background\": \"#a0ced9\",\n",
    "              \"Qualitative_concept\": \"#e3b5a4\",\n",
    "              \"Quantitative_concept\": \"#f1f0d2\",\n",
    "              \"Severity\": \"#ffb347\",\n",
    "              \"Sex\": \"#c5b4e3\",\n",
    "              \"Shape\": \"#c4b7ea\",\n",
    "              \"Sign_symptom\": \"#bde0fe\",\n",
    "              \"Subject\": \"#b9e8d8\",\n",
    "              \"Texture\": \"#ffdfba\",\n",
    "              \"Therapeutic_procedure\": \"#e6ccb2\",\n",
    "              \"Time\": \"#e9d8a6\",\n",
    "              \"Volume\": \"#f2eecb\",\n",
    "              \"Weight\": \"#e2f0cb\"}\n",
    "    options = {\"compact\": True, \"bg\": \"#F8F8F8\",\n",
    "               \"ents\": list(colors.keys()),\n",
    "               \"colors\": colors}\n",
    "\n",
    "    # Render the entity visualization and hand the result back to the caller\n",
    "    # instead of dropping it in an unused local.\n",
    "    return displacy.render(doc, style=\"ent\", options=options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "29a7371e",
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"A 57-year-old man presented to the emergency department with a 2-day history of worsening shortness of breath and chest pain. He reported no recent travel or sick contacts. His medical history was significant for hypertension, dyslipidemia, and type 2 diabetes mellitus. On examination, he was tachycardic and tachypneic, with oxygen saturation of 88% on room air. Chest radiography revealed bilateral opacities consistent with pulmonary edema. The patient was admitted to the intensive care unit for management of acute decompensated heart failure. He was started on intravenous diuretics and inotropic support with dobutamine. Over the next several days, his symptoms improved and he was discharged to home with instructions to follow up with his primary care provider in 1 week.\"\n",
    "# Tokenize the input text; the offset mapping ties every token back to its\n",
    "# character span in `text`.\n",
    "encoded = tokenizer.encode_plus(text, return_tensors=\"tf\", return_offsets_mapping=True)\n",
    "\n",
    "input_ids = encoded['input_ids']\n",
    "attention_mask = encoded['attention_mask']\n",
    "\n",
    "# Same tensors grouped by keyword; not used further in this cell.\n",
    "inputs = {\n",
    "    'input_ids': input_ids,\n",
    "    'attention_mask': attention_mask\n",
    "}\n",
    "\n",
    "offsets = encoded['offset_mapping'][0].numpy()\n",
    "\n",
    "# Get the model predictions\n",
    "outputs = loaded_model_1(input_ids=input_ids, attention_mask=attention_mask)['logits']\n",
    "predictions = tf.argmax(outputs, axis=-1)\n",
    "\n",
    "# Convert the predicted label ids to label names\n",
    "# NOTE(review): id2label is read from the in-memory `model`, not from the\n",
    "# reloaded SavedModel - on a fresh kernel the model-definition cell must run first.\n",
    "predicted_labels = [model.config.id2label[prediction] for prediction in predictions[0].numpy()]\n",
    "\n",
    "# Merge consecutive tokens of the same entity type into (start, end, type)\n",
    "# character spans.\n",
    "entities = []\n",
    "for (start, end), label in zip(offsets, predicted_labels):\n",
    "    if label == 'O' or start == end:\n",
    "        # Skip non-entities and tokens with an empty offset span (presumably\n",
    "        # special tokens such as [CLS]/[SEP] - verify against the tokenizer\n",
    "        # docs), which would otherwise produce empty entities.\n",
    "        continue\n",
    "    tag = label[2:]\n",
    "    if entities:\n",
    "        prev_start, prev_end, prev_tag = entities[-1]\n",
    "        # Extend the previous span when this token has the same type and\n",
    "        # starts right at, or one character after, the previous end.\n",
    "        if prev_tag == tag and (prev_end == start or prev_end + 1 == start):\n",
    "            entities[-1] = (prev_start, end, tag)\n",
    "            continue\n",
    "    entities.append((start, end, tag))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "79ab6298",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<span class=\"tex2jax_ignore\"><div class=\"entities\" style=\"line-height: 2.5; direction: ltr\">A \n",
       "<mark class=\"entity\" style=\"background: #f6c3d0; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    57-year-old\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Age</span>\n",
       "</mark>\n",
       " \n",
       "<mark class=\"entity\" style=\"background: #c5b4e3; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    man\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Sex</span>\n",
       "</mark>\n",
       " \n",
       "<mark class=\"entity\" style=\"background: #77c5d5; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    presented\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Clinical_event</span>\n",
       "</mark>\n",
       " to the \n",
       "<mark class=\"entity\" style=\"background: #f7a399; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    emergency department\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Nonbiological_location</span>\n",
       "</mark>\n",
       " with a \n",
       "<mark class=\"entity\" style=\"background: #ffdfba; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    2-day\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Duration</span>\n",
       "</mark>\n",
       " history of worsening \n",
       "<mark class=\"entity\" style=\"background: #bde0fe; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    shortness of breath\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Sign_symptom</span>\n",
       "</mark>\n",
       " and \n",
       "<mark class=\"entity\" style=\"background: #9ddfd3; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    chest\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Biological_structure</span>\n",
       "</mark>\n",
       " pain. He reported \n",
       "<mark class=\"entity\" style=\"background: #e2f0cb; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    no recent travel\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">History</span>\n",
       "</mark>\n",
       " or sick contacts. His medical history was significant for hypertension, dyslipidemia, and \n",
       "<mark class=\"entity\" style=\"background: #e2f0cb; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    type 2 diabetes\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">History</span>\n",
       "</mark>\n",
       " mellitus. On examination, he was tachycardic and tachypneic, with oxygen saturation of \n",
       "<mark class=\"entity\" style=\"background: #f4b3c2; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    88\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Lab_value</span>\n",
       "</mark>\n",
       "% on \n",
       "<mark class=\"entity\" style=\"background: #ffb347; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    room\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Detailed_description</span>\n",
       "</mark>\n",
       " air. \n",
       "<mark class=\"entity\" style=\"background: #9ddfd3; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    Chest\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Biological_structure</span>\n",
       "</mark>\n",
       " \n",
       "<mark class=\"entity\" style=\"background: #c5b4e3; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    radiography\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Diagnostic_procedure</span>\n",
       "</mark>\n",
       " revealed \n",
       "<mark class=\"entity\" style=\"background: #ffb347; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    bilateral\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Detailed_description</span>\n",
       "</mark>\n",
       " \n",
       "<mark class=\"entity\" style=\"background: #bde0fe; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    opacities\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Sign_symptom</span>\n",
       "</mark>\n",
       " consistent with \n",
       "<mark class=\"entity\" style=\"background: #9ddfd3; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    pulmonary\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Biological_structure</span>\n",
       "</mark>\n",
       " edema. The patient was \n",
       "<mark class=\"entity\" style=\"background: #77c5d5; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    admitted\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Clinical_event</span>\n",
       "</mark>\n",
       " to the \n",
       "<mark class=\"entity\" style=\"background: #f7a399; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    intensive care unit\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Nonbiological_location</span>\n",
       "</mark>\n",
       " for management of acute decompensated heart failure. He was started on intravenous diuretics and \n",
       "<mark class=\"entity\" style=\"background: #f9d5e5; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    inotropic support\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Medication</span>\n",
       "</mark>\n",
       " with dobutamine. \n",
       "<mark class=\"entity\" style=\"background: #ffdfba; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    Over the next several days\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Duration</span>\n",
       "</mark>\n",
       ", his \n",
       "<mark class=\"entity\" style=\"background: #bde0fe; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    symptoms\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Sign_symptom</span>\n",
       "</mark>\n",
       " improved and he was \n",
       "<mark class=\"entity\" style=\"background: #77c5d5; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    discharged\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Clinical_event</span>\n",
       "</mark>\n",
       " to \n",
       "<mark class=\"entity\" style=\"background: #f7a399; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    home\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Nonbiological_location</span>\n",
       "</mark>\n",
       " with instructions to \n",
       "<mark class=\"entity\" style=\"background: #77c5d5; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    follow\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Clinical_event</span>\n",
       "</mark>\n",
       " up with his \n",
       "<mark class=\"entity\" style=\"background: #f7a399; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
       "    primary care provider\n",
       "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">Nonbiological_location</span>\n",
       "</mark>\n",
       " in 1 week.</div></span>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display_pred(text, entities)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21049c58",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}