--- a
+++ b/EL_roberta.ipynb
@@ -0,0 +1,1617 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "138778c4",
+   "metadata": {
+    "id": "138778c4"
+   },
+   "outputs": [],
+   "source": [
+    "from tqdm import tqdm\n",
+    "import pandas as pd\n",
+    "from sklearn import metrics\n",
+    "from scipy.spatial.distance import cdist"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "dTSILRD7hIHG",
+   "metadata": {
+    "id": "dTSILRD7hIHG"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Requirement already satisfied: transformers in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (4.24.0)\n",
+      "Requirement already satisfied: tqdm>=4.27 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (4.64.1)\n",
+      "Requirement already satisfied: packaging>=20.0 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (21.3)\n",
+      "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (0.13.1)\n",
+      "Requirement already satisfied: regex!=2019.12.17 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (2022.7.25)\n",
+      "Requirement already satisfied: requests in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (2.27.1)\n",
+      "Requirement already satisfied: huggingface-hub<1.0,>=0.10.0 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (0.10.1)\n",
+      "Requirement already satisfied: filelock in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (3.6.0)\n",
+      "Requirement already satisfied: pyyaml>=5.1 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (6.0)\n",
+      "Requirement already satisfied: numpy>=1.17 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers) (1.21.5)\n",
+      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.10.0->transformers) (4.1.1)\n",
+      "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from packaging>=20.0->transformers) (2.4.7)\n",
+      "Requirement already satisfied: idna<4,>=2.5 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers) (3.3)\n",
+      "Requirement already satisfied: certifi>=2017.4.17 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers) (2021.10.8)\n",
+      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers) (1.26.9)\n",
+      "Requirement already satisfied: charset-normalizer~=2.0.0 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers) (2.0.4)\n"
+     ]
+    }
+   ],
+   "source": [
+    "!pip install transformers"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "70512d4e",
+   "metadata": {
+    "id": "70512d4e"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Requirement already satisfied: parallelformers in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (1.2.7)\n",
+      "Requirement already satisfied: dacite in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from parallelformers) (1.6.0)\n",
+      "Requirement already satisfied: transformers>=4.2 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from parallelformers) (4.24.0)\n",
+      "Requirement already satisfied: torch in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from parallelformers) (1.11.0)\n",
+      "Requirement already satisfied: packaging>=20.0 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (21.3)\n",
+      "Requirement already satisfied: numpy>=1.17 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (1.21.5)\n",
+      "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (0.13.1)\n",
+      "Requirement already satisfied: pyyaml>=5.1 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (6.0)\n",
+      "Requirement already satisfied: regex!=2019.12.17 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (2022.7.25)\n",
+      "Requirement already satisfied: huggingface-hub<1.0,>=0.10.0 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (0.10.1)\n",
+      "Requirement already satisfied: requests in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (2.27.1)\n",
+      "Requirement already satisfied: tqdm>=4.27 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (4.64.1)\n",
+      "Requirement already satisfied: filelock in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from transformers>=4.2->parallelformers) (3.6.0)\n",
+      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.10.0->transformers>=4.2->parallelformers) (4.1.1)\n",
+      "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from packaging>=20.0->transformers>=4.2->parallelformers) (2.4.7)\n",
+      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers>=4.2->parallelformers) (1.26.9)\n",
+      "Requirement already satisfied: idna<4,>=2.5 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers>=4.2->parallelformers) (3.3)\n",
+      "Requirement already satisfied: certifi>=2017.4.17 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers>=4.2->parallelformers) (2021.10.8)\n",
+      "Requirement already satisfied: charset-normalizer~=2.0.0 in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (from requests->transformers>=4.2->parallelformers) (2.0.4)\n"
+     ]
+    }
+   ],
+   "source": [
+    "!pip install parallelformers"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "6b2ace9d",
+   "metadata": {
+    "id": "6b2ace9d"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Requirement already satisfied: faiss-gpu in /home2/sashank.sridhar/miniconda3/envs/TripletLoss/lib/python3.9/site-packages (1.7.2)\r\n"
+     ]
+    }
+   ],
+   "source": [
+    "!pip install faiss-gpu"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5a0d2481",
+   "metadata": {
+    "id": "5a0d2481"
+   },
+   "source": [
+    "Download snomed term-concept file from UMLS website"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "ea498c9d",
+   "metadata": {
+    "id": "ea498c9d"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1569232\n"
+     ]
+    }
+   ],
+   "source": [
+    "snomed_csv = pd.read_csv(\"sct2_Description_Snapshot-en_INT_20220831.txt\", delimiter=\"\\t\")\n",
+    "print(len(snomed_csv))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "DkqmgM5uxfob",
+   "metadata": {
+    "id": "DkqmgM5uxfob"
+   },
+   "outputs": [],
+   "source": [
+    "# from google.colab import drive\n",
+    "# drive.mount('/content/drive')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "d811cd3c",
+   "metadata": {
+    "id": "d811cd3c"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Index(['id', 'effectiveTime', 'active', 'moduleId', 'conceptId',\n",
+       "       'languageCode', 'typeId', 'term', 'caseSignificanceId'],\n",
+       "      dtype='object')"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "snomed_csv.columns"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "3327f0d9",
+   "metadata": {
+    "id": "3327f0d9"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>id</th>\n",
+       "      <th>effectiveTime</th>\n",
+       "      <th>active</th>\n",
+       "      <th>moduleId</th>\n",
+       "      <th>conceptId</th>\n",
+       "      <th>languageCode</th>\n",
+       "      <th>typeId</th>\n",
+       "      <th>term</th>\n",
+       "      <th>caseSignificanceId</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>101013</td>\n",
+       "      <td>20170731</td>\n",
+       "      <td>1</td>\n",
+       "      <td>900000000000207008</td>\n",
+       "      <td>126813005</td>\n",
+       "      <td>en</td>\n",
+       "      <td>900000000000013009</td>\n",
+       "      <td>Neoplasm of anterior aspect of epiglottis</td>\n",
+       "      <td>900000000000448009</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>102018</td>\n",
+       "      <td>20170731</td>\n",
+       "      <td>1</td>\n",
+       "      <td>900000000000207008</td>\n",
+       "      <td>126814004</td>\n",
+       "      <td>en</td>\n",
+       "      <td>900000000000013009</td>\n",
+       "      <td>Neoplasm of junctional region of epiglottis</td>\n",
+       "      <td>900000000000448009</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>103011</td>\n",
+       "      <td>20170731</td>\n",
+       "      <td>1</td>\n",
+       "      <td>900000000000207008</td>\n",
+       "      <td>126815003</td>\n",
+       "      <td>en</td>\n",
+       "      <td>900000000000013009</td>\n",
+       "      <td>Neoplasm of lateral wall of oropharynx</td>\n",
+       "      <td>900000000000448009</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>104017</td>\n",
+       "      <td>20170731</td>\n",
+       "      <td>1</td>\n",
+       "      <td>900000000000207008</td>\n",
+       "      <td>126816002</td>\n",
+       "      <td>en</td>\n",
+       "      <td>900000000000013009</td>\n",
+       "      <td>Neoplasm of posterior wall of oropharynx</td>\n",
+       "      <td>900000000000448009</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>105016</td>\n",
+       "      <td>20170731</td>\n",
+       "      <td>1</td>\n",
+       "      <td>900000000000207008</td>\n",
+       "      <td>126817006</td>\n",
+       "      <td>en</td>\n",
+       "      <td>900000000000013009</td>\n",
+       "      <td>Neoplasm of esophagus</td>\n",
+       "      <td>900000000000448009</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "       id  effectiveTime  active            moduleId  conceptId languageCode  \\\n",
+       "0  101013       20170731       1  900000000000207008  126813005           en   \n",
+       "1  102018       20170731       1  900000000000207008  126814004           en   \n",
+       "2  103011       20170731       1  900000000000207008  126815003           en   \n",
+       "3  104017       20170731       1  900000000000207008  126816002           en   \n",
+       "4  105016       20170731       1  900000000000207008  126817006           en   \n",
+       "\n",
+       "               typeId                                         term  \\\n",
+       "0  900000000000013009    Neoplasm of anterior aspect of epiglottis   \n",
+       "1  900000000000013009  Neoplasm of junctional region of epiglottis   \n",
+       "2  900000000000013009       Neoplasm of lateral wall of oropharynx   \n",
+       "3  900000000000013009     Neoplasm of posterior wall of oropharynx   \n",
+       "4  900000000000013009                        Neoplasm of esophagus   \n",
+       "\n",
+       "   caseSignificanceId  \n",
+       "0  900000000000448009  \n",
+       "1  900000000000448009  \n",
+       "2  900000000000448009  \n",
+       "3  900000000000448009  \n",
+       "4  900000000000448009  "
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "snomed_csv.head()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "eee36071",
+   "metadata": {
+    "id": "eee36071"
+   },
+   "source": [
+    "Process snomed terms"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "fc74afa8",
+   "metadata": {
+    "id": "fc74afa8"
+   },
+   "outputs": [],
+   "source": [
+    "all_ids = snomed_csv['conceptId']\n",
+    "all_names = []\n",
+    "for i in snomed_csv['term']:\n",
+    "    try:\n",
+    "        all_names.append(i.lower())\n",
+    "    except:\n",
+    "        all_names.append('not applicable')\n",
+    "#         print(i)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "ecbc8292",
+   "metadata": {
+    "id": "ecbc8292"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "id                            1491117014\n",
+       "effectiveTime                   20030131\n",
+       "active                                 1\n",
+       "moduleId              900000000000207008\n",
+       "conceptId                      385432009\n",
+       "languageCode                          en\n",
+       "typeId                900000000000013009\n",
+       "term                                 NaN\n",
+       "caseSignificanceId    900000000000020002\n",
+       "Name: 906846, dtype: object"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "snomed_csv.iloc[906846]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "6d11f0d6",
+   "metadata": {
+    "id": "6d11f0d6"
+   },
+   "outputs": [],
+   "source": [
+    "snomed_name_id = [(all_names[i], all_ids[i]) for i in range(len(all_ids))]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "f61e031c",
+   "metadata": {
+    "id": "f61e031c"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "1569232"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(all_ids)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "8b2c1e53",
+   "metadata": {
+    "id": "8b2c1e53"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['neoplasm of anterior aspect of epiglottis',\n",
+       " 'neoplasm of junctional region of epiglottis',\n",
+       " 'neoplasm of lateral wall of oropharynx',\n",
+       " 'neoplasm of posterior wall of oropharynx',\n",
+       " 'neoplasm of esophagus',\n",
+       " 'neoplasm of cervical esophagus',\n",
+       " 'neoplasm of thoracic esophagus',\n",
+       " 'neoplasm of abdominal esophagus',\n",
+       " 'neoplasm of middle third of esophagus',\n",
+       " 'neoplasm of lower third of esophagus']"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "all_names[:10]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "4de928c7",
+   "metadata": {
+    "id": "4de928c7"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "0    126813005\n",
+       "1    126814004\n",
+       "2    126815003\n",
+       "3    126816002\n",
+       "4    126817006\n",
+       "5    126818001\n",
+       "6    126819009\n",
+       "7    126820003\n",
+       "8    126822006\n",
+       "9    126823001\n",
+       "Name: conceptId, dtype: int64"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "all_ids[:10]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0b808263",
+   "metadata": {
+    "id": "0b808263"
+   },
+   "source": [
+    "# load distil-bert"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "a7c7ac5b",
+   "metadata": {
+    "id": "a7c7ac5b"
+   },
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import torch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "d2c96dea",
+   "metadata": {
+    "id": "d2c96dea"
+   },
+   "outputs": [],
+   "source": [
+    "GPU_COUNT = torch.cuda.device_count()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "5c3cfade",
+   "metadata": {
+    "id": "5c3cfade"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "4"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "GPU_COUNT"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "7bd1e1f2",
+   "metadata": {
+    "id": "7bd1e1f2"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "device(type='cuda')"
+      ]
+     },
+     "execution_count": 16,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") ## specify the GPU id's, GPU id's start from 0.\n",
+    "device"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "id": "528023ac",
+   "metadata": {
+    "id": "528023ac"
+   },
+   "outputs": [],
+   "source": [
+    "# from transformers import AutoTokenizer, AutoModel\n",
+    "# tokenizer = AutoTokenizer.from_pretrained(\"cambridgeltl/SapBERT-from-PubMedBERT-fulltext\")\n",
+    "# model = AutoModel.from_pretrained(\"cambridgeltl/SapBERT-from-PubMedBERT-fulltext\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "id": "Lhh12FnfuwMq",
+   "metadata": {
+    "id": "Lhh12FnfuwMq"
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2022-11-15 12:05:39.754921: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
+      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+      "2022-11-15 12:05:40.005646: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+      "2022-11-15 12:05:41.748569: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
+      "2022-11-15 12:05:41.748707: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
+      "2022-11-15 12:05:41.748720: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2083d5e32bda48c998f9d365ec40fef4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading:   0%|          | 0.00/328 [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "66d0f2238efe45729031254c7bbd2fbc",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading:   0%|          | 0.00/798k [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c7a64902f30647bca2f0e83385aa5f2d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6122de951cf44f34bb2d13f38255480c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "aefb9a9b8d674b2192cfe34df4ea4f62",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading:   0%|          | 0.00/239 [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2611f4f9b6364f36a4a0c68a04cb0104",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading:   0%|          | 0.00/674 [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c843f271b56740e399afb63e5f8bac47",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading:   0%|          | 0.00/1.42G [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Some weights of the model checkpoint at raynardj/pmc-med-bio-mlm-roberta-large were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.bias']\n",
+      "- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+      "- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+      "Some weights of RobertaModel were not initialized from the model checkpoint at raynardj/pmc-med-bio-mlm-roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
+      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from transformers import AutoTokenizer, AutoModel\n",
+    "tokenizer = AutoTokenizer.from_pretrained(\"raynardj/pmc-med-bio-mlm-roberta-large\")\n",
+    "model = AutoModel.from_pretrained(\"raynardj/pmc-med-bio-mlm-roberta-large\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "id": "tlzJasirUq6Y",
+   "metadata": {
+    "id": "tlzJasirUq6Y"
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2022-11-15 12:08:52.000654: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
+      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+      "2022-11-15 12:08:52.000654: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
+      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+      "2022-11-15 12:08:52.205237: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
+      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+      "2022-11-15 12:08:52.268140: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+      "2022-11-15 12:08:52.268140: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+      "2022-11-15 12:08:52.471809: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+      "2022-11-15 12:08:52.482557: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
+      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+      "2022-11-15 12:08:52.774190: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+      "2022-11-15 12:08:54.024346: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
+      "2022-11-15 12:08:54.024512: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
+      "2022-11-15 12:08:54.024529: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
+      "2022-11-15 12:08:54.028393: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
+      "2022-11-15 12:08:54.028615: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
+      "2022-11-15 12:08:54.028639: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
+      "2022-11-15 12:08:54.029519: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
+      "2022-11-15 12:08:54.029727: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
+      "2022-11-15 12:08:54.029759: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
+      "2022-11-15 12:08:54.275190: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
+      "2022-11-15 12:08:54.275386: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.2/lib64:/opt/cudnn-7.6.5.32-cuda-10.2/lib64\n",
+      "2022-11-15 12:08:54.275407: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<parallelformers.parallelize.parallelize at 0x1470568aeaf0>"
+      ]
+     },
+     "execution_count": 18,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# model = torch.nn.DataParallel(model)\n",
+    "# model = model.to(device)\n",
+    "from parallelformers import parallelize\n",
+    "parallelize(model, num_gpus=4, fp16=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a3a24048",
+   "metadata": {
+    "id": "a3a24048"
+   },
+   "source": [
+    "Generate embeddings for snomed labels"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "id": "bb0b8655",
+   "metadata": {
+    "id": "bb0b8655"
+   },
+   "outputs": [],
+   "source": [
+    "# all_names1 = all_names[:100]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "5c5ff31c",
+   "metadata": {
+    "id": "5c5ff31c"
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|█████████████████████████████████████| 12260/12260 [23:53<00:00,  8.55it/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "bs = 128\n",
+    "all_reps = []\n",
+    "for i in tqdm(np.arange(0, len(all_names), bs)):\n",
+    "    toks = tokenizer.batch_encode_plus(all_names[i:i+bs],\n",
+    "                                      padding=\"max_length\",\n",
+    "                                      max_length=25,\n",
+    "                                      truncation=True,\n",
+    "                                      return_tensors=\"pt\")\n",
+    "#     print(device)\n",
+    "#     model = model.to(device)\n",
+    "    toks = toks.to(device)\n",
+    "    output = model(**toks)\n",
+    "    cls_rep = output[0][:,0,:]\n",
+    "    \n",
+    "    all_reps.append(cls_rep.cpu().detach().numpy())\n",
+    "all_reps_emb = np.concatenate(all_reps, axis=0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "c1230654",
+   "metadata": {
+    "id": "c1230654"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(1569232, 1024)\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(all_reps_emb.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "00a427c1",
+   "metadata": {
+    "id": "00a427c1"
+   },
+   "outputs": [],
+   "source": [
+    "all_reps_emb = all_reps_emb.astype(np.float32)"
+   ]
+  },
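+  {
+   "cell_type": "markdown",
+   "id": "cache-embeddings-note",
+   "metadata": {},
+   "source": [
+    "Optional: cache the embedding matrix to disk so the ~24-minute encoding pass above does not have to be repeated in a later session (the `.npy` filename below is only a placeholder)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "cache-embeddings-code",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional: save the float32 SNOMED CT embedding matrix (placeholder path).\n",
+    "np.save(\"snomed_roberta_cls_reps.npy\", all_reps_emb)\n",
+    "# In a later session it can be reloaded with:\n",
+    "# all_reps_emb = np.load(\"snomed_roberta_cls_reps.npy\")"
+   ]
+  },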
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "id": "3c884582",
+   "metadata": {
+    "id": "3c884582"
+   },
+   "outputs": [],
+   "source": [
+    "import faiss"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "9d7d069d",
+   "metadata": {
+    "id": "9d7d069d"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "True\n"
+     ]
+    }
+   ],
+   "source": [
+    "d = all_reps_emb.shape[1]\n",
+    "index = faiss.IndexFlatL2(d)   # build the index\n",
+    "print(index.is_trained)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "77b258e0",
+   "metadata": {
+    "id": "77b258e0"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1569232\n"
+     ]
+    }
+   ],
+   "source": [
+    "index.add(all_reps_emb)                  # add vectors to the index\n",
+    "print(index.ntotal)"
+   ]
+  },
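+  {
+   "cell_type": "markdown",
+   "id": "faiss-gpu-note",
+   "metadata": {},
+   "source": [
+    "Optional sketch: the faiss-gpu build installed above can also shard this exact L2 index across the GPUs; the rest of the notebook keeps using the CPU index, and a GPU copy would return the same neighbors, just faster."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "faiss-gpu-sketch",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional (not used below): copy the flat index onto all available GPUs.\n",
+    "# gpu_index = faiss.index_cpu_to_all_gpus(index)\n",
+    "# gpu_index.search(...) then returns the same neighbors as index.search(...)"
+   ]
+  },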
+  {
+   "cell_type": "markdown",
+   "id": "40fe39a4",
+   "metadata": {
+    "id": "40fe39a4"
+   },
+   "source": [
+    "Load ground truth data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "id": "44851e30",
+   "metadata": {
+    "id": "44851e30"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>filename</th>\n",
+       "      <th>mark</th>\n",
+       "      <th>label</th>\n",
+       "      <th>offset1</th>\n",
+       "      <th>offset2</th>\n",
+       "      <th>span</th>\n",
+       "      <th>code</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>es-S0212-71992007000100007-1</td>\n",
+       "      <td>T1</td>\n",
+       "      <td>ENFERMEDAD</td>\n",
+       "      <td>40</td>\n",
+       "      <td>61</td>\n",
+       "      <td>arterial hypertension</td>\n",
+       "      <td>38341003</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>es-S0212-71992007000100007-1</td>\n",
+       "      <td>T2</td>\n",
+       "      <td>ENFERMEDAD</td>\n",
+       "      <td>66</td>\n",
+       "      <td>79</td>\n",
+       "      <td>polyarthrosis</td>\n",
+       "      <td>36186002</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>es-S0212-71992007000100007-1</td>\n",
+       "      <td>T3</td>\n",
+       "      <td>ENFERMEDAD</td>\n",
+       "      <td>1682</td>\n",
+       "      <td>1698</td>\n",
+       "      <td>pleural effusion</td>\n",
+       "      <td>60046008</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>es-S0212-71992007000100007-1</td>\n",
+       "      <td>T4</td>\n",
+       "      <td>ENFERMEDAD</td>\n",
+       "      <td>1859</td>\n",
+       "      <td>1875</td>\n",
+       "      <td>pleural effusion</td>\n",
+       "      <td>60046008</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>es-S0212-71992007000100007-1</td>\n",
+       "      <td>T5</td>\n",
+       "      <td>ENFERMEDAD</td>\n",
+       "      <td>1626</td>\n",
+       "      <td>1648</td>\n",
+       "      <td>lower lobe atelectasis</td>\n",
+       "      <td>46621007</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "                       filename mark       label  offset1  offset2  \\\n",
+       "0  es-S0212-71992007000100007-1   T1  ENFERMEDAD       40       61   \n",
+       "1  es-S0212-71992007000100007-1   T2  ENFERMEDAD       66       79   \n",
+       "2  es-S0212-71992007000100007-1   T3  ENFERMEDAD     1682     1698   \n",
+       "3  es-S0212-71992007000100007-1   T4  ENFERMEDAD     1859     1875   \n",
+       "4  es-S0212-71992007000100007-1   T5  ENFERMEDAD     1626     1648   \n",
+       "\n",
+       "                     span      code  \n",
+       "0   arterial hypertension  38341003  \n",
+       "1           polyarthrosis  36186002  \n",
+       "2        pleural effusion  60046008  \n",
+       "3        pleural effusion  60046008  \n",
+       "4  lower lobe atelectasis  46621007  "
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "entities = pd.read_csv(\"entities.tsv\", delimiter=\"\\t\")\n",
+    "entities.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "id": "8a009c68",
+   "metadata": {
+    "id": "8a009c68",
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "['arterial hypertension', 'polyarthrosis', 'pleural effusion', 'pleural effusion', 'lower lobe atelectasis', 'infectious spondylodiscitis d10-d11', 'pleural effusion', 'brucellosis', 'orchiepididymitis', 'goitre']\n",
+      "0     38341003\n",
+      "1     36186002\n",
+      "2     60046008\n",
+      "3     60046008\n",
+      "4     46621007\n",
+      "5    302935008\n",
+      "6     60046008\n",
+      "7     75702008\n",
+      "8    197983000\n",
+      "9      3716002\n",
+      "Name: code, dtype: object\n"
+     ]
+    }
+   ],
+   "source": [
+    "inp_names = [i.lower() for i in entities['span']]\n",
+    "inp_labels = entities['code']\n",
+    "print(inp_names[:10])\n",
+    "print(inp_labels[:10])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "id": "90bbf268",
+   "metadata": {
+    "id": "90bbf268"
+   },
+   "outputs": [],
+   "source": [
+    "# c=0\n",
+    "# for i in inp_label:\n",
+    "# #     if type(i)!=float:\n",
+    "#     try:\n",
+    "#         [float(i)]\n",
+    "#     except:\n",
+    "#         c+=1\n",
+    "# #         print(i.split('+'))\n",
+    "# c"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "id": "49562d03",
+   "metadata": {
+    "id": "49562d03"
+   },
+   "outputs": [],
+   "source": [
+    "# inp_names1 = inp_names[:10]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e6cf5d29",
+   "metadata": {
+    "id": "e6cf5d29"
+   },
+   "source": [
+    "Generate embeddings for ground truth terms, get their closest snomedct embedding and list out its corresponding snomedct code"
+   ]
+  },
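+  {
+   "cell_type": "markdown",
+   "id": "batched-query-note",
+   "metadata": {},
+   "source": [
+    "The next cell encodes every ground-truth span in a single batch, which is fine at this data size; as an alternative sketch (not run here), the same batching loop used for the SNOMED terms could be reused if the mention list grows."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "batched-query-sketch",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Alternative sketch (not run): encode the queries in batches of bs,\n",
+    "# mirroring the SNOMED encoding loop above.\n",
+    "# query_reps = []\n",
+    "# for i in np.arange(0, len(inp_names), bs):\n",
+    "#     toks = tokenizer.batch_encode_plus(inp_names[i:i+bs], padding=\"max_length\",\n",
+    "#                                        max_length=25, truncation=True,\n",
+    "#                                        return_tensors=\"pt\").to(device)\n",
+    "#     query_reps.append(model(**toks)[0][:, 0, :].cpu().detach().numpy())\n",
+    "# query_cls_rep = np.concatenate(query_reps, axis=0)"
+   ]
+  },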
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "id": "049818b3",
+   "metadata": {
+    "id": "049818b3"
+   },
+   "outputs": [],
+   "source": [
+    "query_toks = tokenizer.batch_encode_plus(list(inp_names),\n",
+    "                                        padding = \"max_length\",\n",
+    "                                        max_length = 25,\n",
+    "                                        truncation=True,\n",
+    "                                        return_tensors=\"pt\")\n",
+    "query_toks = query_toks.to(device)\n",
+    "query_output = model(**query_toks)\n",
+    "query_cls_rep = query_output[0][:,0,:]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "id": "f0ab19b8",
+   "metadata": {
+    "id": "f0ab19b8"
+   },
+   "outputs": [],
+   "source": [
+    "query_cls_rep = query_cls_rep.cpu().detach().numpy()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "id": "3d90a519",
+   "metadata": {
+    "id": "3d90a519"
+   },
+   "outputs": [],
+   "source": [
+    "query_cls_rep = query_cls_rep.astype(np.float32)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "id": "184cd570",
+   "metadata": {
+    "id": "184cd570"
+   },
+   "outputs": [],
+   "source": [
+    "k= 1 # take the 1 closest neighbor"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "id": "ac0965a5",
+   "metadata": {
+    "id": "ac0965a5"
+   },
+   "outputs": [],
+   "source": [
+    "D, I = index.search(query_cls_rep, k)"
+   ]
+  },
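+  {
+   "cell_type": "markdown",
+   "id": "nn-inspect-note",
+   "metadata": {},
+   "source": [
+    "Quick qualitative check (optional): print the nearest SNOMED CT term and concept id for the first few mentions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "nn-inspect-code",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# show the closest SNOMED CT term/code for the first few query mentions\n",
+    "for query, idx in list(zip(inp_names, I[:, 0]))[:5]:\n",
+    "    print(query, '->', all_names[idx], all_ids[idx])"
+   ]
+  },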
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "id": "e7fada1e",
+   "metadata": {
+    "id": "e7fada1e"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[867885]\n",
+      " [ 58583]\n",
+      " [226122]\n",
+      " [226122]\n",
+      " [449744]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(I[:5])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "id": "2145e65a",
+   "metadata": {
+    "id": "2145e65a"
+   },
+   "outputs": [],
+   "source": [
+    "pred_ids = [all_ids[i[0]] for i in I]\n",
+    "# score=sum((pred_ids[i]==inp_label[i])*1 for i in range(len(pred_ids)))\n",
+    "# score/len(inp_label)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "85c1243b",
+   "metadata": {
+    "id": "85c1243b"
+   },
+   "source": [
+    "In ground truth, zero or more than one codes are also present for each term; here only one code is predicted for each query"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "id": "d7476a77",
+   "metadata": {
+    "id": "d7476a77"
+   },
+   "outputs": [],
+   "source": [
+    "p = [[i] for i in pred_ids]\n",
+    "t = []\n",
+    "for i in inp_labels:\n",
+    "    try:\n",
+    "        t.append([int(i)])\n",
+    "    except:\n",
+    "        try:\n",
+    "            t.append([int(j) for j in (i.split('+'))])\n",
+    "        except:\n",
+    "#             print('nomap')\n",
+    "            t.append([])\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 37,
+   "id": "5a676132",
+   "metadata": {
+    "id": "5a676132"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "68109007"
+      ]
+     },
+     "execution_count": 37,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "p[0][0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 38,
+   "id": "424b2281",
+   "metadata": {
+    "id": "424b2281"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "False"
+      ]
+     },
+     "execution_count": 38,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "p[0][0] in t[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 39,
+   "id": "ae535967",
+   "metadata": {
+    "id": "ae535967"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "precision 0.2688721804511278\n",
+      "recall 0.262747979426892\n",
+      "f1 0.26577474530851825\n"
+     ]
+    }
+   ],
+   "source": [
+    "pre = 0\n",
+    "for i in range(len(p)):\n",
+    "    if p[i][0] in t[i]:\n",
+    "        pre+=1\n",
+    "\n",
+    "pre /= len(p)\n",
+    "print('precision', pre)\n",
+    "\n",
+    "\n",
+    "rec = 0\n",
+    "for i in range(len(t)):\n",
+    "    if len(t[i])==1:\n",
+    "        if t[i][0] in p[i]:\n",
+    "            rec+=1\n",
+    "    elif len(t[i])>1:\n",
+    "        for j in range(len(t[i])):\n",
+    "            if t[i][j] in p[i]:\n",
+    "                rec+=1\n",
+    "\n",
+    "rec /= sum(len(i) for i in t)\n",
+    "print('recall', rec)       \n",
+    "\n",
+    "\n",
+    "f1 = 2*pre*rec/(pre+rec+np.finfo(np.float32).eps)\n",
+    "print('f1', f1)"
+   ]
+  },
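+  {
+   "cell_type": "markdown",
+   "id": "eval-helper-note",
+   "metadata": {},
+   "source": [
+    "The precision/recall/F1 block above is repeated verbatim for k = 3 and k = 5 below; a small helper like this hypothetical `evaluate` function (same logic as the inline code) would avoid the duplication, e.g. `pre, rec, f1 = evaluate(p, t)`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "eval-helper-sketch",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def evaluate(p, t):\n",
+    "    \"\"\"p: predicted code list per mention; t: gold code list per mention (may be empty).\"\"\"\n",
+    "    # precision: fraction of mentions whose top-1 prediction is among their gold codes\n",
+    "    pre = sum(p[i][0] in t[i] for i in range(len(p))) / len(p)\n",
+    "    # recall: fraction of all gold codes that appear among the predictions\n",
+    "    rec = sum(c in p[i] for i in range(len(t)) for c in t[i]) / sum(len(x) for x in t)\n",
+    "    f1 = 2 * pre * rec / (pre + rec + np.finfo(np.float32).eps)\n",
+    "    return pre, rec, f1"
+   ]
+  },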
+  {
+   "cell_type": "code",
+   "execution_count": 40,
+   "id": "b9ab4a63",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "k = 3\n",
+    "D, I = index.search(query_cls_rep, k)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 41,
+   "id": "dca0e636",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[867885 865181 379725]\n",
+      " [ 58583 758734 293892]\n",
+      " [226122 518969  96888]\n",
+      " [226122 518969  96888]\n",
+      " [449744 166059  11946]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(I[:5])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 42,
+   "id": "a33e6ba9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pred_ids = [all_ids[i[0]] for i in I]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 43,
+   "id": "6e84be8e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "p = [[i] for i in pred_ids]\n",
+    "t = []\n",
+    "for i in inp_labels:\n",
+    "    try:\n",
+    "        t.append([int(i)])\n",
+    "    except:\n",
+    "        try:\n",
+    "            t.append([int(j) for j in (i.split('+'))])\n",
+    "        except:\n",
+    "#             print('nomap')\n",
+    "            t.append([])\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 44,
+   "id": "e87ce787",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "precision 0.25533834586466164\n",
+      "recall 0.24952240999265246\n",
+      "f1 0.25239681888711424\n"
+     ]
+    }
+   ],
+   "source": [
+    "pre = 0\n",
+    "for i in range(len(p)):\n",
+    "    if p[i][0] in t[i]:\n",
+    "        pre+=1\n",
+    "\n",
+    "pre /= len(p)\n",
+    "print('precision', pre)\n",
+    "\n",
+    "\n",
+    "rec = 0\n",
+    "for i in range(len(t)):\n",
+    "    if len(t[i])==1:\n",
+    "        if t[i][0] in p[i]:\n",
+    "            rec+=1\n",
+    "    elif len(t[i])>1:\n",
+    "        for j in range(len(t[i])):\n",
+    "            if t[i][j] in p[i]:\n",
+    "                rec+=1\n",
+    "\n",
+    "rec /= sum(len(i) for i in t)\n",
+    "print('recall', rec)       \n",
+    "\n",
+    "\n",
+    "f1 = 2*pre*rec/(pre+rec+np.finfo(np.float32).eps)\n",
+    "print('f1', f1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 45,
+   "id": "e50dd30e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "k = 5\n",
+    "D, I = index.search(query_cls_rep, k)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 46,
+   "id": "38f5b4bc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pred_ids = [all_ids[i[0]] for i in I]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 47,
+   "id": "799627a2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "p = [[i] for i in pred_ids]\n",
+    "t = []\n",
+    "for i in inp_labels:\n",
+    "    try:\n",
+    "        t.append([int(i)])\n",
+    "    except:\n",
+    "        try:\n",
+    "            t.append([int(j) for j in (i.split('+'))])\n",
+    "        except:\n",
+    "#             print('nomap')\n",
+    "            t.append([])\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 48,
+   "id": "5ae9ead8",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "precision 0.2518796992481203\n",
+      "recall 0.2461425422483468\n",
+      "f1 0.24897801546831108\n"
+     ]
+    }
+   ],
+   "source": [
+    "pre = 0\n",
+    "for i in range(len(p)):\n",
+    "    if p[i][0] in t[i]:\n",
+    "        pre+=1\n",
+    "\n",
+    "pre /= len(p)\n",
+    "print('precision', pre)\n",
+    "\n",
+    "\n",
+    "rec = 0\n",
+    "for i in range(len(t)):\n",
+    "    if len(t[i])==1:\n",
+    "        if t[i][0] in p[i]:\n",
+    "            rec+=1\n",
+    "    elif len(t[i])>1:\n",
+    "        for j in range(len(t[i])):\n",
+    "            if t[i][j] in p[i]:\n",
+    "                rec+=1\n",
+    "\n",
+    "rec /= sum(len(i) for i in t)\n",
+    "print('recall', rec)       \n",
+    "\n",
+    "\n",
+    "f1 = 2*pre*rec/(pre+rec+np.finfo(np.float32).eps)\n",
+    "print('f1', f1)"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "collapsed_sections": [],
+   "machine_shape": "hm",
+   "provenance": [
+    {
+     "file_id": "1oudre73ErxBNqU5Dw6jISmnRyQ99pLBu",
+     "timestamp": 1668423723333
+    },
+    {
+     "file_id": "14SK4V1zyuaUOFhPIxRTzTnuMj0WTBpqR",
+     "timestamp": 1668422445307
+    }
+   ]
+  },
+  "gpuClass": "premium",
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}