959 lines (958 with data), 113.8 kB
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TODO"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
" - Embedding for all the lines of the document\n",
" <!-- - Embeddings for all concepts -->\n",
" <!-- - Each concept has a list of neighboring concepts based on similarity (e.g. cosine similarity) -->\n",
" <!-- - The searched term will be embedded and compared to all concepts -->\n",
" - The searched term will be embedded and compared to all lines of the corpus (with hashing to accelerate)\n",
" <!-- - Return patients having the neighboring concepts of the searched term -->\n",
" - Return the patients whose documents are highly similar to the searched term"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# When the kernel was started inside a directory named 'notebooks',\n",
"# move one level up before anything else runs.\n",
"path = %pwd\n",
"if os.path.basename(path) == 'notebooks':\n",
"    %cd .."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# %pip install -U sentence-transformers -q"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Importing"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# ----------------------------------- tech ----------------------------------- #\n",
"import os\n",
"import glob\n",
"import pickle\n",
"\n",
"# ---------------------------- Display and friends --------------------------- #\n",
"from tqdm import tqdm\n",
"from matplotlib import pyplot as plt\n",
"\n",
"# ------------------------- Transformers and friends ------------------------- #\n",
"from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel\n",
"from sentence_transformers import SentenceTransformer, util\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import numpy as np\n",
"\n",
"# ------------------------ Classification and friends ------------------------ #\n",
"from scipy.cluster.hierarchy import dendrogram\n",
"from sklearn.cluster import AgglomerativeClustering, KMeans\n",
"from sklearn.manifold import TSNE\n",
"\n",
"# ----------------------------------- local ---------------------------------- #\n",
"from data_preprocessing import Get_and_process_data\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Configurations"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at logs/scibert_20_epochs_64_batch_99_train_split were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of BertModel were not initialized from the model checkpoint at logs/scibert_20_epochs_64_batch_99_train_split and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"# Number of consecutive lines joined into one chunk before embedding.\n",
"lines_per_tokenization = 5\n",
"# Separator between the file name and the chunk index in the embedding keys.\n",
"filename_split_key = \"__at__\"\n",
"\n",
"# Not referenced by the other cells of this notebook; falls back to CPU so that\n",
"# any `.to(device)` added later also works on a machine without a GPU.\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"# Load the tokenizer and the encoder from the selected checkpoint.\n",
"# model_checkpoint = \"sentence-transformers/multi-qa-MiniLM-L6-cos-v1\"\n",
"# model_checkpoint = \"gsarti/scibert-nli\"\n",
"model_checkpoint = \"logs/scibert_20_epochs_64_batch_99_train_split\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
"model = AutoModel.from_pretrained(model_checkpoint)\n",
"\n",
"# When the tokenizer carries no maximum length, `truncation=True` in the helper\n",
"# functions does nothing (the library warns about it) and over-long inputs reach\n",
"# the model untruncated. Cap it at the number of positions the model supports;\n",
"# a tokenizer that already has a tighter limit is left alone.\n",
"max_positions = getattr(model.config, \"max_position_embeddings\", None)\n",
"if max_positions is not None and tokenizer.model_max_length > max_positions:\n",
"    tokenizer.model_max_length = max_positions\n",
"\n",
"data_path = \"../data/train/txt\"\n",
"embeddings_path = data_path + os.sep + \"embeddings\"\n",
"# Cosine similarity with the module's default arguments.\n",
"similarity = torch.nn.CosineSimilarity()\n",
"\n",
"# exist_ok removes the check-then-create race and keeps the cell re-runnable.\n",
"os.makedirs(embeddings_path, exist_ok=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### utils"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"#Mean Pooling - Take average of all tokens\n",
"def mean_pooling(model_output, attention_mask):\n",
" token_embeddings = model_output.last_hidden_state #First element of model_output contains all token embeddings\n",
" input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n",
" return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n",
"\n",
"\n",
"def encode(texts, tokenizer = tokenizer, model= model):\n",
"    \"\"\"Embed text(s) for inference: the model call runs under `torch.no_grad()`.\n",
"\n",
"    Args:\n",
"        texts: input handed straight to `tokenizer`.\n",
"        tokenizer: callable tokenizer; defaults to the notebook-level `tokenizer`.\n",
"        model: encoder called on the tokenized batch; defaults to the notebook-level `model`.\n",
"\n",
"    Returns:\n",
"        Mean-pooled embeddings, L2-normalised along dim 1.\n",
"    \"\"\"\n",
"    batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')\n",
"\n",
"    with torch.no_grad():\n",
"        outputs = model(**batch, return_dict=True)\n",
"\n",
"    # Pool the token embeddings, then scale every row to unit length.\n",
"    pooled = mean_pooling(outputs, batch['attention_mask'])\n",
"    return F.normalize(pooled, p=2, dim=1)\n",
"\n",
"def find_cluster(query_emb, clustered_data, similarity=similarity):\n",
"    \"\"\"Return the key of the cluster whose center scores highest against the query.\n",
"\n",
"    Args:\n",
"        query_emb: query embedding handed to `similarity`.\n",
"        clustered_data: mapping of cluster key -> dict holding a \"center\" entry.\n",
"        similarity: callable scoring (query_emb, center); defaults to the\n",
"            notebook-level `similarity`.\n",
"\n",
"    Returns:\n",
"        The winning cluster key, or None when no center reaches a score of -1.\n",
"        On equal scores the cluster visited last wins.\n",
"    \"\"\"\n",
"    winner, top_score = None, -1\n",
"    for key, entry in clustered_data.items():\n",
"        candidate = similarity(query_emb, entry[\"center\"])\n",
"        if candidate >= top_score:\n",
"            winner, top_score = key, candidate\n",
"    return winner\n",
"\n",
"def text_splitter(text, lines_per_tokenization=lines_per_tokenization):\n",
"    \"\"\"Split a document into chunks of consecutive lines.\n",
"\n",
"    Args:\n",
"        text: document content; it is split on newline characters.\n",
"        lines_per_tokenization: number of lines joined into one chunk.\n",
"\n",
"    Returns:\n",
"        List of chunks, each made of its lines re-joined with a newline. The\n",
"        last chunk may hold fewer lines; it is dropped only when it is blank.\n",
"    \"\"\"\n",
"    lines = text.split(\"\\n\")\n",
"\n",
"    # Step through the lines one chunk at a time. Counting chunks with floor\n",
"    # division discarded the trailing `len(lines) % lines_per_tokenization`\n",
"    # lines, so the end of a document (or a whole short document) was never\n",
"    # embedded.\n",
"    texts = [\n",
"        \"\\n\".join(lines[start:start + lines_per_tokenization])\n",
"        for start in range(0, len(lines), lines_per_tokenization)\n",
"    ]\n",
"\n",
"    # A text ending with a newline leaves a blank remainder: nothing to embed.\n",
"    if texts and len(lines) % lines_per_tokenization and not texts[-1].strip():\n",
"        texts.pop()\n",
"\n",
"    return texts\n",
"\n",
"def semantic_search_base(query_emb, doc_emb, docs):\n",
" #Compute dot score between query and all document embeddings\n",
" scores = torch.mm(query_emb, doc_emb.transpose(0, 1))[0].cpu().tolist()\n",
"\n",
" #Combine docs & scores\n",
" doc_score_pairs = list(zip(docs, scores))\n",
"\n",
" #Sort by decreasing score\n",
" doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)\n",
" print(doc_score_pairs)\n",
" #Output passages & scores\n",
" for doc, score in doc_score_pairs:\n",
" print(\"==> \",score) \n",
" print(doc)\n",
" \n",
"def forward(texts, tokenizer= tokenizer, model= model):\n",
"    \"\"\"Embed text(s) without switching autograd off around the model call.\n",
"\n",
"    Args:\n",
"        texts: input handed straight to `tokenizer`.\n",
"        tokenizer: callable tokenizer; defaults to the notebook-level `tokenizer`.\n",
"        model: encoder called on the tokenized batch; defaults to the notebook-level `model`.\n",
"\n",
"    Returns:\n",
"        Mean-pooled embeddings, L2-normalised along dim 1.\n",
"    \"\"\"\n",
"    batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')\n",
"    outputs = model(**batch, return_dict=True)\n",
"\n",
"    # Pool the token embeddings, then scale every row to unit length.\n",
"    pooled = mean_pooling(outputs, batch['attention_mask'])\n",
"    return F.normalize(pooled, p=2, dim=1)\n",
"\n",
"\n",
"def forward_doc(texts, tokenizer= tokenizer, model= model, no_grad= False):\n",
"    \"\"\"Embed every chunk that `text_splitter` produces for a document.\n",
"\n",
"    Args:\n",
"        texts: full document text; it is split into chunks by `text_splitter`.\n",
"        tokenizer: callable tokenizer; defaults to the notebook-level `tokenizer`.\n",
"        model: encoder called on the tokenized chunks; defaults to the notebook-level `model`.\n",
"        no_grad: when true, the model call runs with gradients disabled.\n",
"\n",
"    Returns:\n",
"        Mean-pooled chunk embeddings, L2-normalised along dim 1.\n",
"    \"\"\"\n",
"    chunks = text_splitter(texts)\n",
"    batch = tokenizer(chunks, padding=True, truncation=True, return_tensors='pt')\n",
"\n",
"    # Gradients are switched off only on request; otherwise the autograd mode\n",
"    # that is active at call time is left as it is.\n",
"    with torch.set_grad_enabled(torch.is_grad_enabled() and not no_grad):\n",
"        outputs = model(**batch, return_dict=True)\n",
"\n",
"    # One pooled vector per chunk.\n",
"    # NOTE: an easier approach would be another mean pooling over the chunks\n",
"    # to obtain a single vector for the whole document.\n",
"    pooled = mean_pooling(outputs, batch['attention_mask'])\n",
"    return F.normalize(pooled, p=2, dim=1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Testing Inference from checkpoint"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"model =model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('Around 9 Million people live in London', 0.9155025482177734), ('London is known for its financial district', 0.8950100541114807)]\n",
"==> 0.9155025482177734\n",
"Around 9 Million people live in London\n",
"==> 0.8950100541114807\n",
"London is known for its financial district\n"
]
}
],
"source": [
"# Sentences we want sentence embeddings for\n",
"query = \"How many people live in London?\"\n",
"docs = [\"Around 9 Million people live in London\", \"London is known for its financial district\"]\n",
"\n",
"#Encode query and docs\n",
"query_emb = encode(query)\n",
"doc_emb = encode(docs)\n",
"\n",
"semantic_search_base(query_emb, doc_emb, docs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"0.915637195110321 Around 9 Million people live in London\n",
"\n",
"\n",
"0.49475765228271484 London is known for its financial district"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Testing training"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"encoded_input = tokenizer(query, padding=True, truncation=True, return_tensors='pt')\n",
"model_output = model(**encoded_input, return_dict=True)\n",
"# model_output"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 9])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"encoded_input[\"input_ids\"].shape"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 9, 768])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_output.last_hidden_state.shape"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 768])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_output.pooler_output.shape"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"q shape : torch.Size([1, 768])\n",
"a shape : torch.Size([1, 768])\n"
]
}
],
"source": [
"# model.train()\n",
"\n",
"query = \"How many people live in London?\"\n",
"answer = \"Around 9 Million people live in London\"\n",
"\n",
"loss_fn = torch.nn.MSELoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)\n",
"\n",
"q = forward(query)\n",
"print(\"q shape :\", q.shape)\n",
"a = forward(answer)\n",
"print(\"a shape :\", a.shape)\n",
"\n",
"loss = loss_fn(a,q)\n",
"\n",
"optimizer.zero_grad()\n",
"# loss.backward()\n",
"# optimizer.step()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Getting data"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([14, 768])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sample document, resolved through `data_path` so it follows the configuration cell.\n",
"sample_doc_path = data_path + os.sep + \"018636330_DH.txt\"\n",
"with open(sample_doc_path) as f:\n",
"    doc = f.read()\n",
"\n",
"# The embeddings are only inspected and searched here, so no autograd graph is needed.\n",
"doc_emb = forward_doc(doc, no_grad=True)\n",
"doc_emb.shape"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('C5-6 disc herniation with cord compression and myelopathy .\\nPRINCIPAL PROCEDURE :\\nMicroscopic anterior cervical diskectomy at C5-6 and fusion .\\nHISTORY OF PRESENT ILLNESS :\\nThe patient is a 63-year-old female with a three-year history of bilateral hand numbness and occasional weakness .', 0.5795029997825623), ('Within the past year , these symptoms have progressively gotten worse , to encompass also her feet .\\nShe had a workup by her neurologist and an MRI revealed a C5-6 disc herniation with cord compression and a T2 signal change at that level .\\nPAST MEDICAL HISTORY :\\nSignificant for hypertension , hyperlipidemia .\\nMEDICATIONS ON ADMISSION :', 0.4698948562145233), ('She occasionally drinks alcohol .\\nPHYSICAL EXAMINATION :\\nShe had 5/5 strength in bilateral upper and lower extremities .\\nShe had a Hoffman 's sign greater on the right than the left and she had 10 beats of clonus in the right foot and 3-5 beats in the left foot .\\nShe had hyperreflexia in both the bilateral upper and lower extremities .', 0.35985130071640015), ('HOSPITAL COURSE :\\nThe patient tolerated a C5-6 ACDF by Dr. Miezetri Gach quite well .\\nShe had a postoperative CT scan that revealed partial decompression of the spinal canal and good placement of her hardware .\\nImmediately postop , her exam only improved slightly in her hyperreflexia .\\nShe was ambulating by postoperative day number two .', 0.2649998962879181), ('Discharge Summary\\nSigned\\nDIS\\nReport Status :\\nSigned', 0.13975289463996887), ('DISCHARGE SUMMARY\\nNAME :\\nKOTE , OA\\nUNIT NUMBER :\\n509-22-30', 0.11684238910675049), ('ADMISSION DATE :\\n06/02/2005\\nDISCHARGE DATE :\\n06/05/2005\\nPRINCIPAL DIAGNOSIS :', 0.11604052782058716), ('Electronically Signed MIEZETRI NIMIRY POP , M.D. 
08/01/2005 18:50\\n_____________________________ MIEZETRI NIMIRY POP , M.D.\\nTR :\\nqg\\nDD :', 0.09493331611156464), ('Lipitor , Flexeril , hydrochlorothiazide and Norvasc .\\nALLERGIES :\\nShe has no known drug allergy .\\nSOCIAL HISTORY :\\nShe smokes one pack per day x45 years .', 0.09043849259614944), ('07/30/2005\\nTD :\\n07/31/2005 9:43 A 123524\\ncc :\\nMIEZETRI NIMIRY POP , M.D.', 0.08962328732013702), ('018636330 DH\\n5425710\\n123524\\n0144918\\n6/2/2005 12:00:00 AM', 0.08391078561544418), ('4. Lipitor , 10 mg PO daily .\\n5. Hydrochlorothiazide , 25 mg PO daily .\\n6. Norvasc , 5 mg PO daily .\\nLALIND KOTE , M.D.\\nDICTATING FOR :', 0.068220354616642), ('She tolerated a regular diet .\\nHer pain was under good control with PO pain medications and she was deemed suitable for discharge .\\nDISCHARGE ORDERS :\\nThe patient was asked to call Dr. Miezetri Gach 's office for a follow-up appointment and wound check .\\nShe is asked to call with any fevers , chills , increasing weakness or numbness or any bowel and bladder disruption .', 0.03576041758060455), ('DISCHARGE MEDICATIONS :\\nShe was discharged on the following medications .\\n1. Colace , 100 mg PO bid .\\n2. Zantac , 150 mg PO bid .\\n3. Percocet , 5/325 , 1-2 tabs PO q4-6h prn pain .', 0.029356611892580986)]\n",
"==> 0.5795029997825623\n",
"C5-6 disc herniation with cord compression and myelopathy .\n",
"PRINCIPAL PROCEDURE :\n",
"Microscopic anterior cervical diskectomy at C5-6 and fusion .\n",
"HISTORY OF PRESENT ILLNESS :\n",
"The patient is a 63-year-old female with a three-year history of bilateral hand numbness and occasional weakness .\n",
"==> 0.4698948562145233\n",
"Within the past year , these symptoms have progressively gotten worse , to encompass also her feet .\n",
"She had a workup by her neurologist and an MRI revealed a C5-6 disc herniation with cord compression and a T2 signal change at that level .\n",
"PAST MEDICAL HISTORY :\n",
"Significant for hypertension , hyperlipidemia .\n",
"MEDICATIONS ON ADMISSION :\n",
"==> 0.35985130071640015\n",
"She occasionally drinks alcohol .\n",
"PHYSICAL EXAMINATION :\n",
"She had 5/5 strength in bilateral upper and lower extremities .\n",
"She had a Hoffman 's sign greater on the right than the left and she had 10 beats of clonus in the right foot and 3-5 beats in the left foot .\n",
"She had hyperreflexia in both the bilateral upper and lower extremities .\n",
"==> 0.2649998962879181\n",
"HOSPITAL COURSE :\n",
"The patient tolerated a C5-6 ACDF by Dr. Miezetri Gach quite well .\n",
"She had a postoperative CT scan that revealed partial decompression of the spinal canal and good placement of her hardware .\n",
"Immediately postop , her exam only improved slightly in her hyperreflexia .\n",
"She was ambulating by postoperative day number two .\n",
"==> 0.13975289463996887\n",
"Discharge Summary\n",
"Signed\n",
"DIS\n",
"Report Status :\n",
"Signed\n",
"==> 0.11684238910675049\n",
"DISCHARGE SUMMARY\n",
"NAME :\n",
"KOTE , OA\n",
"UNIT NUMBER :\n",
"509-22-30\n",
"==> 0.11604052782058716\n",
"ADMISSION DATE :\n",
"06/02/2005\n",
"DISCHARGE DATE :\n",
"06/05/2005\n",
"PRINCIPAL DIAGNOSIS :\n",
"==> 0.09493331611156464\n",
"Electronically Signed MIEZETRI NIMIRY POP , M.D. 08/01/2005 18:50\n",
"_____________________________ MIEZETRI NIMIRY POP , M.D.\n",
"TR :\n",
"qg\n",
"DD :\n",
"==> 0.09043849259614944\n",
"Lipitor , Flexeril , hydrochlorothiazide and Norvasc .\n",
"ALLERGIES :\n",
"She has no known drug allergy .\n",
"SOCIAL HISTORY :\n",
"She smokes one pack per day x45 years .\n",
"==> 0.08962328732013702\n",
"07/30/2005\n",
"TD :\n",
"07/31/2005 9:43 A 123524\n",
"cc :\n",
"MIEZETRI NIMIRY POP , M.D.\n",
"==> 0.08391078561544418\n",
"018636330 DH\n",
"5425710\n",
"123524\n",
"0144918\n",
"6/2/2005 12:00:00 AM\n",
"==> 0.068220354616642\n",
"4. Lipitor , 10 mg PO daily .\n",
"5. Hydrochlorothiazide , 25 mg PO daily .\n",
"6. Norvasc , 5 mg PO daily .\n",
"LALIND KOTE , M.D.\n",
"DICTATING FOR :\n",
"==> 0.03576041758060455\n",
"She tolerated a regular diet .\n",
"Her pain was under good control with PO pain medications and she was deemed suitable for discharge .\n",
"DISCHARGE ORDERS :\n",
"The patient was asked to call Dr. Miezetri Gach 's office for a follow-up appointment and wound check .\n",
"She is asked to call with any fevers , chills , increasing weakness or numbness or any bowel and bladder disruption .\n",
"==> 0.029356611892580986\n",
"DISCHARGE MEDICATIONS :\n",
"She was discharged on the following medications .\n",
"1. Colace , 100 mg PO bid .\n",
"2. Zantac , 150 mg PO bid .\n",
"3. Percocet , 5/325 , 1-2 tabs PO q4-6h prn pain .\n"
]
}
],
"source": [
"c_emb= encode(\"bilateral hand numbness\")\n",
"semantic_search_base(c_emb, doc_emb, text_splitter(doc))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Saving embeddings"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Encoding documents: 100%|##########| 170/170 [06:48<00:00, 2.41s/it]\n"
]
}
],
"source": [
"# Embed every .txt file found directly under `data_path`, chunk by chunk.\n",
"# Keys are built as <file name without extension> + filename_split_key + <chunk index>.\n",
"all_docs = {}\n",
"# Sorted so the insertion order (and everything derived from it) is reproducible.\n",
"text_files = sorted(glob.glob(data_path + os.sep + \"*.txt\"))\n",
"for file in tqdm(text_files, \"Encoding documents\", ascii=True):\n",
"    # `doc_text` rather than `doc`, so the sample document loaded in the\n",
"    # \"Getting data\" section is not overwritten by this loop.\n",
"    with open(file) as f:\n",
"        doc_text = f.read()\n",
"    # splitext removes only the final extension; split(\".\")[0] cut names that\n",
"    # contain dots, which could make different files share the same keys.\n",
"    file_name = os.path.splitext(os.path.basename(file))[0]\n",
"    embeddings = forward_doc(doc_text, no_grad=True)\n",
"    for i,emb in enumerate(embeddings):\n",
"        all_docs[file_name+filename_split_key+str(i)] = emb.unsqueeze(0)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"with open(embeddings_path + os.sep + \"all_docs.pkl\", \"wb\") as f:\n",
" pickle.dump(all_docs, f)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# with open(embeddings_path + os.sep + \"all_docs.pkl\", \"rb\") as f:\n",
"# all_docs = pickle.load(f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Classify the embeddings"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Hierarchical clustering could be used to group the embeddings, which would make the search very efficient. For simplicity, however, we only perform K-means clustering."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3218, 768)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Keys and embeddings of `all_docs`, both in the dictionary's iteration order.\n",
"sample_names_list = list(all_docs.keys())\n",
"sample_values_list = list(all_docs.values())\n",
"# Flatten every embedding to 1-D and stack them, one row per embedding.\n",
"sample = np.array([emb.numpy().reshape(-1) for emb in sample_values_list])\n",
"sample.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Test hierarchical clustering"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"clustering = AgglomerativeClustering(distance_threshold=0.7, n_clusters=None).fit(sample)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"def plot_dendrogram(model, **kwargs):\n",
" # Create linkage matrix and then plot the dendrogram\n",
"\n",
" # create the counts of samples under each node\n",
" counts = np.zeros(model.children_.shape[0])\n",
" n_samples = len(model.labels_)\n",
" for i, merge in enumerate(model.children_):\n",
" current_count = 0\n",
" for child_idx in merge:\n",
" if child_idx < n_samples:\n",
" current_count += 1 # leaf node\n",
" else:\n",
" current_count += counts[child_idx - n_samples]\n",
" counts[i] = current_count\n",
"\n",
" linkage_matrix = np.column_stack(\n",
" [model.children_, model.distances_, counts]\n",
" ).astype(float)\n",
"\n",
" # Plot the corresponding dendrogram\n",
" dendrogram(linkage_matrix, **kwargs)\n",
"\n",
"plt.title(\"Hierarchical Clustering Dendrogram\")\n",
"# plot the top five levels of the dendrogram (p=5)\n",
"plot_dendrogram(clustering, truncate_mode=\"level\", p=5)\n",
"plt.xlabel(\"Number of points in node (or index of point if no parenthesis).\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Test K-means clustering"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# Fixed seed: K-means starts from randomly chosen centroids, so without it the\n",
"# cluster ids (and the pickled clustered_data) can change from run to run.\n",
"clustering = KMeans(n_clusters = 10, random_state = 0).fit(sample)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/mus5900/anaconda3/envs/nlp/lib/python3.9/site-packages/sklearn/manifold/_t_sne.py:780: FutureWarning: The default initialization in TSNE will change from 'random' to 'pca' in 1.2.\n",
" warnings.warn(\n",
"/home/mus5900/anaconda3/envs/nlp/lib/python3.9/site-packages/sklearn/manifold/_t_sne.py:790: FutureWarning: The default learning rate in TSNE will change from 200.0 to 'auto' in 1.2.\n",
" warnings.warn(\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Scatter plot using TSNE\n",
"def plot_clutering(sample, labels=None):\n",
"    \"\"\"Project `sample` to 2-D with t-SNE and scatter-plot it coloured by label.\n",
"\n",
"    Args:\n",
"        sample: data handed to `TSNE.fit_transform`.\n",
"        labels: per-point colour values; when omitted, the labels of the\n",
"            notebook-level `clustering` are used, so existing calls behave as before.\n",
"    \"\"\"\n",
"    if labels is None:\n",
"        labels = clustering.labels_\n",
"    new_sample = TSNE(n_components=2).fit_transform(sample)\n",
"    plt.scatter(new_sample[:, 0], new_sample[:, 1], c=labels)\n",
"    plt.show()\n",
"    # # plot in 3D\n",
"    # new_sample_3D = TSNE(n_components=3).fit_transform(sample)\n",
"    # fig = plt.figure()\n",
"    # ax = fig.add_subplot(111, projection='3d')\n",
"    # ax.scatter(new_sample_3D[:, 0], new_sample_3D[:, 1], new_sample_3D[:, 2], c=clustering.labels_)\n",
"    # plt.show()\n",
"plot_clutering(sample)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cluster 0 -> 212 element\n",
"cluster 1 -> 862 element\n",
"cluster 2 -> 123 element\n",
"cluster 3 -> 157 element\n",
"cluster 4 -> 345 element\n",
"cluster 5 -> 269 element\n",
"cluster 6 -> 142 element\n",
"cluster 7 -> 521 element\n",
"cluster 8 -> 366 element\n",
"cluster 9 -> 221 element\n"
]
}
],
"source": [
"# Count all cluster sizes in one pass; the number of clusters comes from the\n",
"# fitted model instead of a literal that has to be kept in sync by hand.\n",
"cluster_sizes = np.bincount(clustering.labels_, minlength=clustering.n_clusters)\n",
"for i, size in enumerate(cluster_sizes):\n",
"    print(\"cluster\", i , \"->\" , size, \"element\")"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# One entry per cluster: its center as a (1, n) tensor and an (initially empty)\n",
"# mapping of chunk name -> embedding.\n",
"clustered_data = {\n",
"    idx: {\"center\": torch.tensor(centroid.reshape(1, -1)), \"elements\": {}}\n",
"    for idx, centroid in enumerate(clustering.cluster_centers_)\n",
"}\n",
"\n",
"# labels_ is paired positionally with sample_names_list: file every chunk under its cluster.\n",
"for chunk_name, label in zip(sample_names_list, clustering.labels_):\n",
"    clustered_data[label][\"elements\"][chunk_name] = all_docs[chunk_name]"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"with open(embeddings_path + os.sep + \"clustered_data.pkl\", \"wb\") as f:\n",
" pickle.dump(clustered_data, f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Search"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"with open(embeddings_path + os.sep + \"clustered_data.pkl\", \"rb\") as f:\n",
" clustered_data = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"center = clustered_data[0][\"center\"]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.7883])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"similarity(query_emb, center)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class Buffer_best_k:\n",
" def __init__(self, k, initia_value=-float(\"inf\")):\n",
" self.k = k\n",
" self.values = [initia_value] * self.k\n",
" self.data = [None] * self.k\n",
" def new_val(self, value, data=None):\n",
" for i in range(self.k):\n",
" if self.values[i] < value:\n",
" self.values[i+1:] = self.values[i:-1]\n",
" self.data[i+1:] = self.data[i:-1]\n",
" self.values[i] = value\n",
" self.data[i] = data\n",
" return True\n",
" return False\n",
" def get_data(self):\n",
" return self.data\n",
" def get_values(self):\n",
" return self.values \n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[tensor([0.8351]), tensor([0.8331]), tensor([0.8309]), tensor([0.8288]), tensor([0.8287])]\n",
"['record-177__at__12', '920798564__at__13', 'record-144__at__7', 'record-66__at__13', '176318078_a__at__3']\n"
]
}
],
"source": [
"# query = \"DIGOXIN and AMIODARONE HCL\"\n",
"query = \"hello\"\n",
"query_emb = encode(query)\n",
"# Only the cluster selected by find_cluster is scanned.\n",
"cluster = find_cluster(query_emb, clustered_data)\n",
"\n",
"# Keeps the five best-scoring chunk names.\n",
"buffer = Buffer_best_k(k=5)\n",
"\n",
"# The loop variable is deliberately not called `doc_emb`: that name holds the\n",
"# document embedding matrix from the \"Getting data\" section, and overwriting it\n",
"# made a later re-run of that section's search cell silently score one row only.\n",
"for name, chunk_emb in clustered_data[cluster][\"elements\"].items():\n",
"    score = similarity(query_emb, chunk_emb)\n",
"    buffer.new_val(score, name)\n",
"\n",
"print(buffer.get_values())\n",
"print(buffer.get_data())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# model_checkpoint = \"allenai/scibert_scivocab_uncased\"\n",
"# batch_size = 32\n",
"# tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
"\n",
"\n",
"# data_loader = Get_and_process_data(tokenizer, train_split=0.95, add_unlabeled=True)\n",
"# D = data_loader.get_dataset()\n",
"# label_list = data_loader.get_label_list()"
]
}
],
"metadata": {
"interpreter": {
"hash": "d650a3b94c458c273a536969f20832049d87778dafdd87df0ff4ffcada142abe"
},
"kernelspec": {
"display_name": "Python 3.9.7 64-bit ('nlp': conda)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}