[ce2cbf]: / EL_pubmedbert.ipynb

Download this file

1329 lines (1328 with data), 44.6 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "138778c4",
   "metadata": {
    "id": "138778c4"
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "from sklearn import metrics\n",
    "from scipy.spatial.distance import cdist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dTSILRD7hIHG",
   "metadata": {
    "id": "dTSILRD7hIHG"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: transformers in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (4.24.0)\n",
      "Requirement already satisfied: filelock in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (3.6.0)\n",
      "Requirement already satisfied: packaging>=20.0 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (21.3)\n",
      "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (0.13.1)\n",
      "Requirement already satisfied: huggingface-hub<1.0,>=0.10.0 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (0.10.1)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (2022.7.25)\n",
      "Requirement already satisfied: numpy>=1.17 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (1.21.5)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (6.0)\n",
      "Requirement already satisfied: tqdm>=4.27 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (4.64.1)\n",
      "Requirement already satisfied: requests in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (2.27.1)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.10.0->transformers) (4.1.1)\n",
      "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from packaging>=20.0->transformers) (2.4.7)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers) (3.3)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers) (2021.10.8)\n",
      "Requirement already satisfied: charset-normalizer~=2.0.0 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers) (2.0.4)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers) (1.26.9)\n"
     ]
    }
   ],
   "source": [
    "!pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "70512d4e",
   "metadata": {
    "id": "70512d4e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: parallelformers in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (1.2.7)\n",
      "Requirement already satisfied: dacite in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from parallelformers) (1.6.0)\n",
      "Requirement already satisfied: torch in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from parallelformers) (1.11.0)\n",
      "Requirement already satisfied: transformers>=4.2 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from parallelformers) (4.24.0)\n",
      "Requirement already satisfied: requests in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (2.27.1)\n",
      "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (0.13.1)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (2022.7.25)\n",
      "Requirement already satisfied: huggingface-hub<1.0,>=0.10.0 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (0.10.1)\n",
      "Requirement already satisfied: tqdm>=4.27 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (4.64.1)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (6.0)\n",
      "Requirement already satisfied: packaging>=20.0 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (21.3)\n",
      "Requirement already satisfied: filelock in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (3.6.0)\n",
      "Requirement already satisfied: numpy>=1.17 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (1.21.5)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.10.0->transformers>=4.2->parallelformers) (4.1.1)\n",
      "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from packaging>=20.0->transformers>=4.2->parallelformers) (2.4.7)\n",
      "Requirement already satisfied: charset-normalizer~=2.0.0 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers>=4.2->parallelformers) (2.0.4)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers>=4.2->parallelformers) (3.3)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers>=4.2->parallelformers) (1.26.9)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers>=4.2->parallelformers) (2021.10.8)\n"
     ]
    }
   ],
   "source": [
    "!pip install parallelformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6b2ace9d",
   "metadata": {
    "id": "6b2ace9d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: faiss-gpu in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (1.7.2)\r\n"
     ]
    }
   ],
   "source": [
    "!pip install faiss-gpu"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a0d2481",
   "metadata": {
    "id": "5a0d2481"
   },
   "source": [
    "Download snomed term-concept file from UMLS website"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ea498c9d",
   "metadata": {
    "id": "ea498c9d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1569232\n"
     ]
    }
   ],
   "source": [
    "snomed_csv = pd.read_csv(\"sct2_Description_Snapshot-en_INT_20220831.txt\", delimiter=\"\\t\")\n",
    "print(len(snomed_csv))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "DkqmgM5uxfob",
   "metadata": {
    "id": "DkqmgM5uxfob"
   },
   "outputs": [],
   "source": [
    "# from google.colab import drive\n",
    "# drive.mount('/content/drive')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d811cd3c",
   "metadata": {
    "id": "d811cd3c"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['id', 'effectiveTime', 'active', 'moduleId', 'conceptId',\n",
       "       'languageCode', 'typeId', 'term', 'caseSignificanceId'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "snomed_csv.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3327f0d9",
   "metadata": {
    "id": "3327f0d9"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>effectiveTime</th>\n",
       "      <th>active</th>\n",
       "      <th>moduleId</th>\n",
       "      <th>conceptId</th>\n",
       "      <th>languageCode</th>\n",
       "      <th>typeId</th>\n",
       "      <th>term</th>\n",
       "      <th>caseSignificanceId</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>101013</td>\n",
       "      <td>20170731</td>\n",
       "      <td>1</td>\n",
       "      <td>900000000000207008</td>\n",
       "      <td>126813005</td>\n",
       "      <td>en</td>\n",
       "      <td>900000000000013009</td>\n",
       "      <td>Neoplasm of anterior aspect of epiglottis</td>\n",
       "      <td>900000000000448009</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>102018</td>\n",
       "      <td>20170731</td>\n",
       "      <td>1</td>\n",
       "      <td>900000000000207008</td>\n",
       "      <td>126814004</td>\n",
       "      <td>en</td>\n",
       "      <td>900000000000013009</td>\n",
       "      <td>Neoplasm of junctional region of epiglottis</td>\n",
       "      <td>900000000000448009</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>103011</td>\n",
       "      <td>20170731</td>\n",
       "      <td>1</td>\n",
       "      <td>900000000000207008</td>\n",
       "      <td>126815003</td>\n",
       "      <td>en</td>\n",
       "      <td>900000000000013009</td>\n",
       "      <td>Neoplasm of lateral wall of oropharynx</td>\n",
       "      <td>900000000000448009</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>104017</td>\n",
       "      <td>20170731</td>\n",
       "      <td>1</td>\n",
       "      <td>900000000000207008</td>\n",
       "      <td>126816002</td>\n",
       "      <td>en</td>\n",
       "      <td>900000000000013009</td>\n",
       "      <td>Neoplasm of posterior wall of oropharynx</td>\n",
       "      <td>900000000000448009</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>105016</td>\n",
       "      <td>20170731</td>\n",
       "      <td>1</td>\n",
       "      <td>900000000000207008</td>\n",
       "      <td>126817006</td>\n",
       "      <td>en</td>\n",
       "      <td>900000000000013009</td>\n",
       "      <td>Neoplasm of esophagus</td>\n",
       "      <td>900000000000448009</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       id  effectiveTime  active            moduleId  conceptId languageCode  \\\n",
       "0  101013       20170731       1  900000000000207008  126813005           en   \n",
       "1  102018       20170731       1  900000000000207008  126814004           en   \n",
       "2  103011       20170731       1  900000000000207008  126815003           en   \n",
       "3  104017       20170731       1  900000000000207008  126816002           en   \n",
       "4  105016       20170731       1  900000000000207008  126817006           en   \n",
       "\n",
       "               typeId                                         term  \\\n",
       "0  900000000000013009    Neoplasm of anterior aspect of epiglottis   \n",
       "1  900000000000013009  Neoplasm of junctional region of epiglottis   \n",
       "2  900000000000013009       Neoplasm of lateral wall of oropharynx   \n",
       "3  900000000000013009     Neoplasm of posterior wall of oropharynx   \n",
       "4  900000000000013009                        Neoplasm of esophagus   \n",
       "\n",
       "   caseSignificanceId  \n",
       "0  900000000000448009  \n",
       "1  900000000000448009  \n",
       "2  900000000000448009  \n",
       "3  900000000000448009  \n",
       "4  900000000000448009  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "snomed_csv.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eee36071",
   "metadata": {
    "id": "eee36071"
   },
   "source": [
    "Process snomed terms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fc74afa8",
   "metadata": {
    "id": "fc74afa8"
   },
   "outputs": [],
   "source": [
    "all_ids = snomed_csv['conceptId']\n",
    "all_names = []\n",
    "for i in snomed_csv['term']:\n",
    "    try:\n",
    "        all_names.append(i.lower())\n",
    "    except:\n",
    "        all_names.append('not applicable')\n",
    "#         print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ecbc8292",
   "metadata": {
    "id": "ecbc8292"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "id                            1491117014\n",
       "effectiveTime                   20030131\n",
       "active                                 1\n",
       "moduleId              900000000000207008\n",
       "conceptId                      385432009\n",
       "languageCode                          en\n",
       "typeId                900000000000013009\n",
       "term                                 NaN\n",
       "caseSignificanceId    900000000000020002\n",
       "Name: 906846, dtype: object"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "snomed_csv.iloc[906846]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6d11f0d6",
   "metadata": {
    "id": "6d11f0d6"
   },
   "outputs": [],
   "source": [
    "snomed_name_id = [(all_names[i], all_ids[i]) for i in range(len(all_ids))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f61e031c",
   "metadata": {
    "id": "f61e031c"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1569232"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(all_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8b2c1e53",
   "metadata": {
    "id": "8b2c1e53"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['neoplasm of anterior aspect of epiglottis',\n",
       " 'neoplasm of junctional region of epiglottis',\n",
       " 'neoplasm of lateral wall of oropharynx',\n",
       " 'neoplasm of posterior wall of oropharynx',\n",
       " 'neoplasm of esophagus',\n",
       " 'neoplasm of cervical esophagus',\n",
       " 'neoplasm of thoracic esophagus',\n",
       " 'neoplasm of abdominal esophagus',\n",
       " 'neoplasm of middle third of esophagus',\n",
       " 'neoplasm of lower third of esophagus']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_names[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "4de928c7",
   "metadata": {
    "id": "4de928c7"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    126813005\n",
       "1    126814004\n",
       "2    126815003\n",
       "3    126816002\n",
       "4    126817006\n",
       "5    126818001\n",
       "6    126819009\n",
       "7    126820003\n",
       "8    126822006\n",
       "9    126823001\n",
       "Name: conceptId, dtype: int64"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_ids[:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b808263",
   "metadata": {
    "id": "0b808263"
   },
   "source": [
    "# load pubmedbert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a7c7ac5b",
   "metadata": {
    "id": "a7c7ac5b"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d2c96dea",
   "metadata": {
    "id": "d2c96dea"
   },
   "outputs": [],
   "source": [
    "GPU_COUNT = torch.cuda.device_count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "5c3cfade",
   "metadata": {
    "id": "5c3cfade"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "GPU_COUNT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7bd1e1f2",
   "metadata": {
    "id": "7bd1e1f2"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda')"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") ## specify the GPU id's, GPU id's start from 0.\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "528023ac",
   "metadata": {
    "id": "528023ac"
   },
   "outputs": [],
   "source": [
    "# from transformers import AutoTokenizer, AutoModel\n",
    "# tokenizer = AutoTokenizer.from_pretrained(\"cambridgeltl/SapBERT-from-PubMedBERT-fulltext\")\n",
    "# model = AutoModel.from_pretrained(\"cambridgeltl/SapBERT-from-PubMedBERT-fulltext\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "Lhh12FnfuwMq",
   "metadata": {
    "id": "Lhh12FnfuwMq"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-11-15 11:09:16.635924: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2022-11-15 11:09:16.898184: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2022-11-15 11:09:18.830720: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
      "2022-11-15 11:09:18.830877: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
      "2022-11-15 11:09:18.830892: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6f2fe6e76459468cb502634e09fadc37",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f889047f747d45ecb0602cbe37b891ea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/385 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dfd44f0b70de46a49826a8cd74c184da",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/225k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4c9a209544b54c7691798d76c563af26",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModel\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\")\n",
    "model = AutoModel.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\")\n",
    "# model = AutoModelForMaskedLM.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "tlzJasirUq6Y",
   "metadata": {
    "id": "tlzJasirUq6Y"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-11-15 11:10:10.057137: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2022-11-15 11:10:10.057137: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2022-11-15 11:10:10.057137: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2022-11-15 11:10:10.071504: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2022-11-15 11:10:10.318816: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2022-11-15 11:10:10.320732: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2022-11-15 11:10:10.334204: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2022-11-15 11:10:10.341789: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2022-11-15 11:10:12.013349: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
      "2022-11-15 11:10:12.013347: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
      "2022-11-15 11:10:12.013377: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
      "2022-11-15 11:10:12.013433: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
      "2022-11-15 11:10:12.013604: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
      "2022-11-15 11:10:12.013609: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
      "2022-11-15 11:10:12.013630: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
      "2022-11-15 11:10:12.013636: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
      "2022-11-15 11:10:12.013659: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
      "2022-11-15 11:10:12.013679: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
      "2022-11-15 11:10:12.013705: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
      "2022-11-15 11:10:12.013700: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<parallelformers.parallelize.parallelize at 0x149786bd6a60>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# model = torch.nn.DataParallel(model)\n",
    "# model.to(device)\n",
    "from parallelformers import parallelize\n",
    "parallelize(model, num_gpus=4, fp16=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3a24048",
   "metadata": {
    "id": "a3a24048"
   },
   "source": [
    "Generate embeddings for snomed labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb0b8655",
   "metadata": {
    "id": "bb0b8655"
   },
   "outputs": [],
   "source": [
    "# all_names1 = all_names[:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "5c5ff31c",
   "metadata": {
    "id": "5c5ff31c"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 12260/12260 [13:26<00:00, 15.20it/s]\n"
     ]
    }
   ],
   "source": [
    "bs = 128\n",
    "all_reps = []\n",
    "for i in tqdm(np.arange(0, len(all_names), bs)):\n",
    "    toks = tokenizer.batch_encode_plus(all_names[i:i+bs],\n",
    "                                      padding=\"max_length\",\n",
    "                                      max_length=25,\n",
    "                                      truncation=True,\n",
    "                                      return_tensors=\"pt\")\n",
    "    toks = toks.to(device)\n",
    "    output = model(**toks)\n",
    "    cls_rep = output[0][:,0,:]\n",
    "    \n",
    "    all_reps.append(cls_rep.cpu().detach().numpy())\n",
    "all_reps_emb = np.concatenate(all_reps, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "c1230654",
   "metadata": {
    "id": "c1230654"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1569232, 768)\n"
     ]
    }
   ],
   "source": [
    "print(all_reps_emb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "00a427c1",
   "metadata": {
    "id": "00a427c1"
   },
   "outputs": [],
   "source": [
    "all_reps_emb = all_reps_emb.astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "3c884582",
   "metadata": {
    "id": "3c884582"
   },
   "outputs": [],
   "source": [
    "import faiss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "9d7d069d",
   "metadata": {
    "id": "9d7d069d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "d = all_reps_emb.shape[1]\n",
    "index = faiss.IndexFlatL2(d)   # build the index\n",
    "print(index.is_trained)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "77b258e0",
   "metadata": {
    "id": "77b258e0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1569232\n"
     ]
    }
   ],
   "source": [
    "index.add(all_reps_emb)                  # add vectors to the index\n",
    "print(index.ntotal)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40fe39a4",
   "metadata": {
    "id": "40fe39a4"
   },
   "source": [
    "Load ground truth data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "44851e30",
   "metadata": {
    "id": "44851e30"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>filename</th>\n",
       "      <th>mark</th>\n",
       "      <th>label</th>\n",
       "      <th>offset1</th>\n",
       "      <th>offset2</th>\n",
       "      <th>span</th>\n",
       "      <th>code</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>es-S0212-71992007000100007-1</td>\n",
       "      <td>T1</td>\n",
       "      <td>ENFERMEDAD</td>\n",
       "      <td>40</td>\n",
       "      <td>61</td>\n",
       "      <td>arterial hypertension</td>\n",
       "      <td>38341003</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>es-S0212-71992007000100007-1</td>\n",
       "      <td>T2</td>\n",
       "      <td>ENFERMEDAD</td>\n",
       "      <td>66</td>\n",
       "      <td>79</td>\n",
       "      <td>polyarthrosis</td>\n",
       "      <td>36186002</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>es-S0212-71992007000100007-1</td>\n",
       "      <td>T3</td>\n",
       "      <td>ENFERMEDAD</td>\n",
       "      <td>1682</td>\n",
       "      <td>1698</td>\n",
       "      <td>pleural effusion</td>\n",
       "      <td>60046008</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>es-S0212-71992007000100007-1</td>\n",
       "      <td>T4</td>\n",
       "      <td>ENFERMEDAD</td>\n",
       "      <td>1859</td>\n",
       "      <td>1875</td>\n",
       "      <td>pleural effusion</td>\n",
       "      <td>60046008</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>es-S0212-71992007000100007-1</td>\n",
       "      <td>T5</td>\n",
       "      <td>ENFERMEDAD</td>\n",
       "      <td>1626</td>\n",
       "      <td>1648</td>\n",
       "      <td>lower lobe atelectasis</td>\n",
       "      <td>46621007</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                       filename mark       label  offset1  offset2  \\\n",
       "0  es-S0212-71992007000100007-1   T1  ENFERMEDAD       40       61   \n",
       "1  es-S0212-71992007000100007-1   T2  ENFERMEDAD       66       79   \n",
       "2  es-S0212-71992007000100007-1   T3  ENFERMEDAD     1682     1698   \n",
       "3  es-S0212-71992007000100007-1   T4  ENFERMEDAD     1859     1875   \n",
       "4  es-S0212-71992007000100007-1   T5  ENFERMEDAD     1626     1648   \n",
       "\n",
       "                     span      code  \n",
       "0   arterial hypertension  38341003  \n",
       "1           polyarthrosis  36186002  \n",
       "2        pleural effusion  60046008  \n",
       "3        pleural effusion  60046008  \n",
       "4  lower lobe atelectasis  46621007  "
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "entities = pd.read_csv(\"entities.tsv\", delimiter=\"\\t\")\n",
    "entities.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "8a009c68",
   "metadata": {
    "id": "8a009c68",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['arterial hypertension', 'polyarthrosis', 'pleural effusion', 'pleural effusion', 'lower lobe atelectasis', 'infectious spondylodiscitis d10-d11', 'pleural effusion', 'brucellosis', 'orchiepididymitis', 'goitre']\n",
      "0     38341003\n",
      "1     36186002\n",
      "2     60046008\n",
      "3     60046008\n",
      "4     46621007\n",
      "5    302935008\n",
      "6     60046008\n",
      "7     75702008\n",
      "8    197983000\n",
      "9      3716002\n",
      "Name: code, dtype: object\n"
     ]
    }
   ],
   "source": [
    "inp_names = [i.lower() for i in entities['span']]\n",
    "inp_labels = entities['code']\n",
    "print(inp_names[:10])\n",
    "print(inp_labels[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "90bbf268",
   "metadata": {
    "id": "90bbf268"
   },
   "outputs": [],
   "source": [
    "# c=0\n",
    "# for i in inp_label:\n",
    "# #     if type(i)!=float:\n",
    "#     try:\n",
    "#         [float(i)]\n",
    "#     except:\n",
    "#         c+=1\n",
    "# #         print(i.split('+'))\n",
    "# c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "49562d03",
   "metadata": {
    "id": "49562d03"
   },
   "outputs": [],
   "source": [
    "# inp_names1 = inp_names[:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6cf5d29",
   "metadata": {
    "id": "e6cf5d29"
   },
   "source": [
    "Generate embeddings for ground truth terms, get their closest snomedct embedding and list out its corresponding snomedct code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "049818b3",
   "metadata": {
    "id": "049818b3"
   },
   "outputs": [],
   "source": [
    "query_toks = tokenizer.batch_encode_plus(list(inp_names),\n",
    "                                        padding = \"max_length\",\n",
    "                                        max_length = 25,\n",
    "                                        truncation=True,\n",
    "                                        return_tensors=\"pt\")\n",
    "query_toks = query_toks.to(device)\n",
    "query_output = model(**query_toks)\n",
    "query_cls_rep = query_output[0][:,0,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "f0ab19b8",
   "metadata": {
    "id": "f0ab19b8"
   },
   "outputs": [],
   "source": [
    "query_cls_rep = query_cls_rep.cpu().detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "3d90a519",
   "metadata": {
    "id": "3d90a519"
   },
   "outputs": [],
   "source": [
    "query_cls_rep = query_cls_rep.astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "id": "184cd570",
   "metadata": {
    "id": "184cd570"
   },
   "outputs": [],
   "source": [
    "k= 1 # take the 1 closest neighbor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "id": "ac0965a5",
   "metadata": {
    "id": "ac0965a5"
   },
   "outputs": [],
   "source": [
    "D, I = index.search(query_cls_rep, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cfb69ca",
   "metadata": {
    "id": "4cfb69ca"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "2145e65a",
   "metadata": {
    "id": "2145e65a"
   },
   "outputs": [],
   "source": [
    "pred_ids = [all_ids[i[0]] for i in I]\n",
    "# score=sum((pred_ids[i]==inp_label[i])*1 for i in range(len(pred_ids)))\n",
    "# score/len(inp_label)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85c1243b",
   "metadata": {
    "id": "85c1243b"
   },
   "source": [
    "In ground truth, zero or more than one codes are also present for each term; here only one code is predicted for each query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "id": "d7476a77",
   "metadata": {
    "id": "d7476a77"
   },
   "outputs": [],
   "source": [
    "p = [[i] for i in pred_ids]\n",
    "t = []\n",
    "for i in inp_labels:\n",
    "    try:\n",
    "        t.append([int(i)])\n",
    "    except:\n",
    "        try:\n",
    "            t.append([int(j) for j in (i.split('+'))])\n",
    "        except:\n",
    "#             print('nomap')\n",
    "            t.append([])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "5a676132",
   "metadata": {
    "id": "5a676132"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "48146000"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "424b2281",
   "metadata": {
    "id": "424b2281"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p[0][0] in t[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "ae535967",
   "metadata": {
    "id": "ae535967"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "precision 0.293984962406015\n",
      "recall 0.2872887582659809\n",
      "f1 0.290598231001568\n"
     ]
    }
   ],
   "source": [
    "pre = 0\n",
    "for i in range(len(p)):\n",
    "    if p[i][0] in t[i]:\n",
    "        pre+=1\n",
    "\n",
    "pre /= len(p)\n",
    "print('precision', pre)\n",
    "\n",
    "\n",
    "rec = 0\n",
    "for i in range(len(t)):\n",
    "    if len(t[i])==1:\n",
    "        if t[i][0] in p[i]:\n",
    "            rec+=1\n",
    "    elif len(t[i])>1:\n",
    "        for j in range(len(t[i])):\n",
    "            if t[i][j] in p[i]:\n",
    "                rec+=1\n",
    "\n",
    "rec /= sum(len(i) for i in t)\n",
    "print('recall', rec)       \n",
    "\n",
    "\n",
    "f1 = 2*pre*rec/(pre+rec+np.finfo(np.float32).eps)\n",
    "print('f1', f1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e963520",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "machine_shape": "hm",
   "provenance": [
    {
     "file_id": "14SK4V1zyuaUOFhPIxRTzTnuMj0WTBpqR",
     "timestamp": 1668422445307
    }
   ]
  },
  "gpuClass": "premium",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}