--- a +++ b/LLM-Zero-shot_approach/evaluation.ipynb @@ -0,0 +1,1061 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Evaluation of the Zero-shot approach for NER**" + ] + }, + { + "cell_type": "code", + "execution_count": 641, + "metadata": {}, + "outputs": [], + "source": [ + "# imports\n", + "import os\n", + "import torch\n", + "from transformers import AutoTokenizer\n", + "from datasets import Dataset, DatasetDict, load_metric\n", + "import pandas as pd\n", + "import evaluate" + ] + }, + { + "cell_type": "code", + "execution_count": 642, + "metadata": {}, + "outputs": [], + "source": [ + "# paths\n", + "root_path = \"..\"\n", + "# root_path = \"./drive/MyDrive/HandsOnNLP\" # for google colab\n", + "data_path = f'{root_path}/data'\n", + "annotations_path = f'{data_path}/Annotations_Mistral_Prompt_2'\n", + "chia_bio_path = f'{data_path}/chia/chia_bio'" + ] + }, + { + "cell_type": "code", + "execution_count": 643, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "200" + ] + }, + "execution_count": 643, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ann_files = os.listdir(annotations_path)\n", + "len(ann_files)" + ] + }, + { + "cell_type": "code", + "execution_count": 644, + "metadata": {}, + "outputs": [], + "source": [ + "true_ann = {} # list with real annotations\n", + "mistral_ann = {} # list with mistral annotations" + ] + }, + { + "cell_type": "code", + "execution_count": 645, + "metadata": {}, + "outputs": [], + "source": [ + "# dict for the entities (entity to int value)\n", + "simple_ent = {\"Condition\", \"Value\", \"Drug\", \"Procedure\", \"Measurement\", \"Temporal\", \"Observation\", \"Person\", \"Mood\", \"Pregnancy_considerations\", \"Device\"}\n", + "sel_ent = {\n", + " \"O\": 0,\n", + " \"B-Condition\": 1,\n", + " \"I-Condition\": 2,\n", + " \"B-Value\": 3,\n", + " \"I-Value\": 4,\n", + " \"B-Drug\": 5,\n", + " \"I-Drug\": 6,\n", + " \"B-Procedure\": 7,\n", + " \"I-Procedure\": 8,\n", + " \"B-Measurement\": 9,\n", + " \"I-Measurement\": 10,\n", + " \"B-Temporal\": 11,\n", + " \"I-Temporal\": 12,\n", + " \"B-Observation\": 13,\n", + " \"I-Observation\": 14,\n", + " \"B-Person\": 15,\n", + " \"I-Person\": 16,\n", + " \"B-Mood\": 17,\n", + " \"I-Mood\": 18,\n", + " \"B-Pregnancy_considerations\": 19,\n", + " \"I-Pregnancy_considerations\": 20,\n", + " \"B-Device\": 21,\n", + " \"I-Device\": 22\n", + "}\n", + "\n", + "entities_list = list(sel_ent.keys())\n", + "sel_ent_inv = {v: k for k, v in sel_ent.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": 646, + "metadata": {}, + "outputs": [], + "source": [ + "import re" + ] + }, + { + "cell_type": "code", + "execution_count": 647, + "metadata": {}, + "outputs": [], + "source": [ + "def parse_ann2bio(sentence, pattern, pattern1, pattern2):\n", + " if sentence[-1] == \"\\n\":\n", + " sentence = sentence[:-2] # remove the \\n and a final point wrongly added\n", + " else:\n", + " sentence = sentence[:-1] # remove the final point wrongly added\n", + " \n", + " # find the entities\n", + " occurrences = re.finditer(pattern, sentence)\n", + " indexes = [(match.start(), match.end()) for match in occurrences]\n", + "\n", + " annotation = []\n", + " i = 0\n", + " # create the bio list\n", + " for beg, end in indexes:\n", + " if beg > i:\n", + " annotation.extend([(word, \"O\") for word in sentence[i:beg].split()])\n", + " entity = sentence[beg:end]\n", + " entity_name = re.search(pattern1, entity).group(1)\n", + " entity = entity.replace(f'<{entity_name}>', \"\").replace(f'</{entity_name}>', \"\")\n", + " split_entity = entity.split()\n", + " annotation.append((split_entity[0], \"B-\" + entity_name))\n", + " annotation.extend([(word, \"I-\" + entity_name) for word in split_entity[1:]])\n", + " i = end\n", + " annotation.extend([(word, \"O\") for word in sentence[i:].split()])\n", + "\n", + " # check punctuation sign in tokens and put them as individual tokens\n", + " ps = r'(\\.|\\,|\\:|\\;|\\!|\\?|\\-|\\(|\\)|\\[|\\]|\\{|\\}|\\\")'\n", + " new_annotation = []\n", + " for i,(word, tag) in enumerate(annotation):\n", + " if re.search(ps, word):\n", + " # find the ocurrences of the punctuation signs\n", + " occurrences = re.finditer(ps, word)\n", + " indexes = [(match.start(), match.end()) for match in occurrences]\n", + " # create the new tokens\n", + " last = 0\n", + " for j, (beg, end) in enumerate(indexes):\n", + " if beg > last:\n", + " new_annotation.append((word[last:beg], tag))\n", + " if tag != \"O\":\n", + " label = f'I-{tag.split(\"-\")[1]}'\n", + " else:\n", + " label = \"O\"\n", + " if end < len(word) or (i < len(annotation) - 1 and annotation[i+1][1] == label):\n", + " new_annotation.append((word[beg:end], label))\n", + " else:\n", + " new_annotation.append((word[beg:end], 'O')) \n", + " last = end\n", + " if last < len(word):\n", + " new_annotation.append((word[last:], label)) \n", + " \n", + " else:\n", + " new_annotation.append((word, tag))\n", + "\n", + " \n", + " return new_annotation" + ] + }, + { + "cell_type": "code", + "execution_count": 648, + "metadata": {}, + "outputs": [], + "source": [ + "pattern1 = r'<(Person|Condition|Value|Drug|Procedure|Measurement|Temporal|Observation|Mood|Pregnancy_considerations|Device)>'\n", + "pattern2 = r'</(Person|Condition|Value|Drug|Procedure|Measurement|Temporal|Observation|Mood|Pregnancy_considerations|Device)>'\n", + "pattern = f'{pattern1}.*?{pattern2}'" + ] + }, + { + "cell_type": "code", + "execution_count": 649, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "200" + ] + }, + "execution_count": 649, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# get BIO annotations for mistral outputs\n", + "for file in ann_files:\n", + " mistral_ann[file] = []\n", + " with open(f\"{annotations_path}/{file}\", \"r\") as f:\n", + " sentences = [line for line in f.readlines() if line != \"\\n\" and line != \" \\n\" and line != '']\n", + "\n", + " for sentence in sentences:\n", + " mistral_ann[file].append(parse_ann2bio(sentence, pattern, pattern1, pattern2))\n", + "len(mistral_ann)" + ] + }, + { + "cell_type": "code", + "execution_count": 650, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('Severely', 'O'),\n", + " ('to', 'O'),\n", + " ('isolate', 'O'),\n", + " ('for', 'O'),\n", + " ('procedure', 'B-Procedure')]" + ] + }, + "execution_count": 650, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sent = \"Severely to isolate for <Procedure>procedure</Procedure>.\"\n", + "parse_ann2bio(sent, pattern, pattern1, pattern2)" + ] + }, + { + "cell_type": "code", + "execution_count": 651, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "200" + ] + }, + "execution_count": 651, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# read real annotations from chia_bio\n", + "for file in ann_files:\n", + " true_ann[file] = []\n", + " with open(f\"{chia_bio_path}/{file}\", \"r\") as fd:\n", + " sentences_ann = fd.read().split(\"\\n\\n\")\n", + " sentences_ann = [sentence for sentence in sentences_ann if sentence != \"\" and sentence != '\\n']\n", + " for sentence in sentences_ann:\n", + " true_ann[file].append(sentence)\n", + "len(true_ann)" + ] + }, + { + "cell_type": "code", + "execution_count": 652, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error in file NCT03132259_exc.bio.txt\n", + "True: 12, Mistral: 0\n", + "0.005\n" + ] + } + ], + "source": [ + "i = 0\n", + "corrupted_files = []\n", + "for file in ann_files:\n", + " if len(true_ann[file]) != len(mistral_ann[file]):\n", + " i += 1\n", + " print(f\"Error in file {file}\")\n", + " print(f\"True: {len(true_ann[file])}, Mistral: {len(mistral_ann[file])}\")\n", + " corrupted_files.append(file)\n", + "print(i/len(ann_files))" + ] + }, + { + "cell_type": "code", + "execution_count": 653, + "metadata": {}, + "outputs": [], + "source": [ + "# remove corructed file\n", + "for file in corrupted_files:\n", + " del true_ann[file]\n", + " del mistral_ann[file]\n", + " ann_files.remove(file)" + ] + }, + { + "cell_type": "code", + "execution_count": 654, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'NCT02322203_inc.bio.txt'" + ] + }, + "execution_count": 654, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ann_files[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 655, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "199" + ] + }, + "execution_count": 655, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "true_ann_aux = {}\n", + "\n", + "for file in ann_files:\n", + " true_ann_aux[file] = []\n", + " for i in range(len(true_ann[file])):\n", + " annotation = []\n", + " lines = true_ann[file][i].split(\"\\n\")\n", + " for line in lines:\n", + " if line != \"\":\n", + " spt_line = line.split()\n", + " annotation.append((spt_line[0], spt_line[-1]))\n", + " new_annotation = []\n", + " ps = r'(\\.|\\,|\\:|\\;|\\!|\\?|\\-|\\(|\\)|\\[|\\]|\\{|\\}|\\\")'\n", + " for i,(word, tag) in enumerate(annotation):\n", + " if re.search(ps, word):\n", + " # find the ocurrences of the punctuation signs\n", + " occurrences = re.finditer(ps, word)\n", + " indexes = [(match.start(), match.end()) for match in occurrences]\n", + " # create the new tokens\n", + " last = 0\n", + " for j, (beg, end) in enumerate(indexes):\n", + " if beg > last:\n", + " new_annotation.append((word[last:beg], tag))\n", + " if tag != \"O\":\n", + " label = f'I-{tag.split(\"-\")[1]}'\n", + " else:\n", + " label = \"O\"\n", + " if end < len(word) or (i < len(annotation) - 1 and annotation[i+1][1] == label):\n", + " new_annotation.append((word[beg:end], label))\n", + " else:\n", + " new_annotation.append((word[beg:end], 'O')) \n", + " last = end\n", + " if last < len(word):\n", + " new_annotation.append((word[last:], label))\n", + " else:\n", + " new_annotation.append((word, tag))\n", + " true_ann_aux[file].append(new_annotation)\n", + "true_ann = true_ann_aux\n", + "len(true_ann)" + ] + }, + { + "cell_type": "code", + "execution_count": 656, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('Subject', 'O'),\n", + " ('understands', 'O'),\n", + " ('the', 'O'),\n", + " ('investigational', 'O'),\n", + " ('nature', 'O'),\n", + " ('of', 'O'),\n", + " ('the', 'O'),\n", + " ('study', 'O'),\n", + " ('and', 'O'),\n", + " ('provides', 'O'),\n", + " ('written', 'O'),\n", + " (',', 'O'),\n", + " ('informed', 'B-Mood'),\n", + " ('consent', 'I-Mood'),\n", + " ('.', 'O')]" + ] + }, + "execution_count": 656, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mistral_ann[ann_files[0]][1]" + ] + }, + { + "cell_type": "code", + "execution_count": 658, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('Subject', 'O'),\n", + " ('understands', 'O'),\n", + " ('the', 'O'),\n", + " ('investigational', 'O'),\n", + " ('nature', 'O'),\n", + " ('of', 'O'),\n", + " ('the', 'O'),\n", + " ('study', 'O'),\n", + " ('and', 'O'),\n", + " ('provides', 'O'),\n", + " ('written', 'O'),\n", + " (',', 'O'),\n", + " ('informed', 'O'),\n", + " ('consent', 'O'),\n", + " ('.', 'O')]" + ] + }, + "execution_count": 658, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "true_ann[ann_files[0]][1]" + ] + }, + { + "cell_type": "code", + "execution_count": 661, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1147" + ] + }, + "execution_count": 661, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mistral_ann_dict = []\n", + "\n", + "for file in ann_files:\n", + " for i in range(len(mistral_ann[file])):\n", + " dict_sent = {\"tokens\": [], \"ner_tags\": [], \"file\": file, \"index\": i}\n", + " for word, tag in mistral_ann[file][i]:\n", + " dict_sent[\"tokens\"].append(word)\n", + " # add the int representation of the entity\n", + " dict_sent[\"ner_tags\"].append(sel_ent[tag])\n", + " mistral_ann_dict.append(dict_sent)\n", + "len(mistral_ann_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": 662, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1147" + ] + }, + "execution_count": 662, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "true_ann_dict = []\n", + "\n", + "for file in ann_files:\n", + " for i in range(len(true_ann[file])):\n", + " dict_sent = {\"tokens\": [], \"ner_tags\": [], \"file\": file, \"index\": i}\n", + " for word, tag in true_ann[file][i]:\n", + " dict_sent[\"tokens\"].append(word)\n", + " # add the int representation of the entity\n", + " dict_sent[\"ner_tags\"].append(sel_ent[tag])\n", + " true_ann_dict.append(dict_sent)\n", + "len(true_ann_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": 663, + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')" + ] + }, + { + "cell_type": "code", + "execution_count": 664, + "metadata": {}, + "outputs": [], + "source": [ + "# tokenize and align the labels in the dataset\n", + "def tokenize_and_align_labels(sentence, flag = 'I'):\n", + " \"\"\"\n", + " Tokenize the sentence and align the labels\n", + " inputs:\n", + " sentence: dict, the sentence from the dataset\n", + " flag: str, the flag to indicate how to deal with the labels for subwords\n", + " - 'I': use the label of the first subword for all subwords but as intermediate (I-ENT)\n", + " - 'B': use the label of the first subword for all subwords as beginning (B-ENT)\n", + " - None: use -100 for subwords\n", + " outputs:\n", + " tokenized_sentence: dict, the tokenized sentence now with a field for the labels\n", + " \"\"\"\n", + " tokenized_sentence = tokenizer(sentence['tokens'], is_split_into_words=True, truncation=True)\n", + "\n", + " labels = []\n", + " for i, labels_s in enumerate(sentence['ner_tags']):\n", + " word_ids = tokenized_sentence.word_ids(batch_index=i)\n", + " previous_word_idx = None\n", + " label_ids = []\n", + " for word_idx in word_ids:\n", + " # if the word_idx is None, assign -100\n", + " if word_idx is None:\n", + " label_ids.append(-100)\n", + " # if it is a new word, assign the corresponding label\n", + " elif word_idx != previous_word_idx:\n", + " label_ids.append(labels_s[word_idx])\n", + " # if it is the same word, check the flag to assign\n", + " else:\n", + " if flag == 'I':\n", + " if entities_list[labels_s[word_idx]].startswith('I'):\n", + " label_ids.append(labels_s[word_idx])\n", + " else:\n", + " label_ids.append(labels_s[word_idx] + 1)\n", + " elif flag == 'B':\n", + " label_ids.append(labels_s[word_idx])\n", + " elif flag == None:\n", + " label_ids.append(-100)\n", + " previous_word_idx = word_idx\n", + " labels.append(label_ids)\n", + " tokenized_sentence['labels'] = labels\n", + " return tokenized_sentence" + ] + }, + { + "cell_type": "code", + "execution_count": 665, + "metadata": {}, + "outputs": [], + "source": [ + "mis_df = pd.DataFrame(mistral_ann_dict)\n", + "true_df = pd.DataFrame(true_ann_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": 666, + "metadata": {}, + "outputs": [], + "source": [ + "mistral_ann_dataset = Dataset.from_pandas(mis_df)\n", + "true_ann_dataset = Dataset.from_pandas(true_df)" + ] + }, + { + "cell_type": "code", + "execution_count": 667, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Dataset({\n", + " features: ['tokens', 'ner_tags', 'file', 'index'],\n", + " num_rows: 1147\n", + " }),\n", + " Dataset({\n", + " features: ['tokens', 'ner_tags', 'file', 'index'],\n", + " num_rows: 1147\n", + " }))" + ] + }, + "execution_count": 667, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mistral_ann_dataset, true_ann_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 668, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|██████████| 1147/1147 [00:00<00:00, 15088.51 examples/s]\n", + "Map: 100%|██████████| 1147/1147 [00:00<00:00, 18242.88 examples/s]\n" + ] + } + ], + "source": [ + "mistral_ann_dataset = mistral_ann_dataset.map(tokenize_and_align_labels, batched=True)\n", + "true_ann_dataset = true_ann_dataset.map(tokenize_and_align_labels, batched=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 669, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Dataset({\n", + " features: ['tokens', 'ner_tags', 'file', 'index', 'input_ids', 'attention_mask', 'labels'],\n", + " num_rows: 1147\n", + " }),\n", + " Dataset({\n", + " features: ['tokens', 'ner_tags', 'file', 'index', 'input_ids', 'attention_mask', 'labels'],\n", + " num_rows: 1147\n", + " }))" + ] + }, + "execution_count": 669, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mistral_ann_dataset, true_ann_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Evaluation of the annotations made by Mistral using seqeval**" + ] + }, + { + "cell_type": "code", + "execution_count": 670, + "metadata": {}, + "outputs": [], + "source": [ + "seqeval = evaluate.load('seqeval')" + ] + }, + { + "cell_type": "code", + "execution_count": 671, + "metadata": {}, + "outputs": [], + "source": [ + "def compute_metrics(p):\n", + " \"\"\"\n", + " Compute the metrics for the model\n", + " inputs:\n", + " p: tuple, the predictions and the ground true\n", + " outputs:\n", + " dict: the metrics\n", + " \"\"\"\n", + " predictions, ground_true = p\n", + "\n", + " # Remove ignored index (special tokens)\n", + " predictions_labels = []\n", + " true_labels = []\n", + "\n", + " for preds, labels in zip(predictions, ground_true):\n", + " preds_labels = []\n", + " labels_true = []\n", + " for pred, label in zip(preds, labels):\n", + " if label != -100:\n", + " if pred == -100:\n", + " pred = 0\n", + " preds_labels.append(entities_list[pred])\n", + " labels_true.append(entities_list[label])\n", + " predictions_labels.append(preds_labels) \n", + " true_labels.append(labels_true)\n", + "\n", + " # predictions_labels = [\n", + " # [entities_list[p] for (p, l) in zip(prediction, ground_true) if l != -100]\n", + " # for prediction, label in zip(predictions, ground_true)\n", + " # ]\n", + " # true_labels = [\n", + " # [entities_list[l] for (p, l) in zip(prediction, ground_true) if l != -100]\n", + " # for prediction, label in zip(predictions, ground_true)\n", + " # ]\n", + " # print(predictions_labels[0])\n", + " # print(true_labels[0])\n", + "\n", + " results = seqeval.compute(predictions=predictions_labels, references=true_labels)\n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": 672, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NCT00806273_exc.bio.txt 1\n", + "NCT00806273_exc.bio.txt\n" + ] + } + ], + "source": [ + "print(mistral_ann_dataset['file'][14], mistral_ann_dataset['index'][14])\n", + "print(true_ann_dataset['file'][14])" + ] + }, + { + "cell_type": "code", + "execution_count": 680, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['Men', 'and', 'women', '<Age>35', 'to', '70', 'years', 'of', 'age</Age>']\n", + "['Men', 'and', 'women', '35', 'to', '70', 'years', 'of', 'age']\n" + ] + } + ], + "source": [ + "print(mistral_ann_dataset['tokens'][1140])\n", + "print(true_ann_dataset['tokens'][1140])" + ] + }, + { + "cell_type": "code", + "execution_count": 674, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "14\n", + "15\n", + "24\n", + "33\n", + "37\n", + "57\n", + "71\n", + "72\n", + "90\n", + "92\n", + "101\n", + "102\n", + "109\n", + "110\n", + "111\n", + "127\n", + "142\n", + "155\n", + "162\n", + "169\n", + "176\n", + "188\n", + "194\n", + "201\n", + "205\n", + "209\n", + "210\n", + "226\n", + "229\n", + "230\n", + "231\n", + "232\n", + "233\n", + "241\n", + "244\n", + "251\n", + "271\n", + "273\n", + "274\n", + "307\n", + "310\n", + "320\n", + "324\n", + "332\n", + "337\n", + "339\n", + "349\n", + "350\n", + "363\n", + "369\n", + "372\n", + "383\n", + "403\n", + "404\n", + "408\n", + "412\n", + "417\n", + "421\n", + "432\n", + "447\n", + "448\n", + "463\n", + "464\n", + "465\n", + "470\n", + "475\n", + "487\n", + "491\n", + "500\n", + "512\n", + "513\n", + "530\n", + "536\n", + "548\n", + "567\n", + "573\n", + "578\n", + "581\n", + "596\n", + "600\n", + "611\n", + "623\n", + "631\n", + "633\n", + "646\n", + "654\n", + "655\n", + "656\n", + "671\n", + "681\n", + "686\n", + "694\n", + "707\n", + "708\n", + "719\n", + "725\n", + "727\n", + "731\n", + "742\n", + "745\n", + "756\n", + "757\n", + "759\n", + "763\n", + "767\n", + "769\n", + "777\n", + "780\n", + "782\n", + "783\n", + "788\n", + "789\n", + "794\n", + "795\n", + "800\n", + "801\n", + "805\n", + "806\n", + "809\n", + "816\n", + "817\n", + "820\n", + "822\n", + "823\n", + "828\n", + "831\n", + "837\n", + "855\n", + "869\n", + "873\n", + "877\n", + "903\n", + "912\n", + "924\n", + "943\n", + "945\n", + "959\n", + "966\n", + "972\n", + "974\n", + "977\n", + "980\n", + "981\n", + "983\n", + "994\n", + "998\n", + "999\n", + "1008\n", + "1010\n", + "1012\n", + "1013\n", + "1015\n", + "1016\n", + "1022\n", + "1026\n", + "1035\n", + "1047\n", + "1052\n", + "1060\n", + "1063\n", + "1066\n", + "1068\n", + "1076\n", + "1077\n", + "1082\n", + "1090\n", + "1093\n", + "1094\n", + "1099\n", + "1100\n", + "1111\n", + "1113\n", + "1119\n", + "1136\n", + "1140\n" + ] + } + ], + "source": [ + "for i in range(len(mistral_ann_dataset)):\n", + " if len(mistral_ann_dataset['labels'][i]) != len(true_ann_dataset['labels'][i]):\n", + " print(i)" + ] + }, + { + "cell_type": "code", + "execution_count": 675, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/miniconda3/envs/TER/lib/python3.10/site-packages/seqeval/metrics/v1.py:57: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, msg_start, len(result))\n" + ] + }, + { + "data": { + "text/plain": [ + "{'Condition': {'precision': 0.5477680433875678,\n", + " 'recall': 0.663466397170288,\n", + " 'f1': 0.6000914076782451,\n", + " 'number': 3958},\n", + " 'Device': {'precision': 0.2777777777777778,\n", + " 'recall': 0.15151515151515152,\n", + " 'f1': 0.19607843137254904,\n", + " 'number': 33},\n", + " 'Drug': {'precision': 0.5060975609756098,\n", + " 'recall': 0.5684931506849316,\n", + " 'f1': 0.535483870967742,\n", + " 'number': 292},\n", + " 'Measurement': {'precision': 0.13314447592067988,\n", + " 'recall': 0.14968152866242038,\n", + " 'f1': 0.1409295352323838,\n", + " 'number': 314},\n", + " 'Mood': {'precision': 0.00684931506849315,\n", + " 'recall': 0.01694915254237288,\n", + " 'f1': 0.00975609756097561,\n", + " 'number': 59},\n", + " 'Observation': {'precision': 0.05454545454545454,\n", + " 'recall': 0.019736842105263157,\n", + " 'f1': 0.028985507246376812,\n", + " 'number': 152},\n", + " 'Person': {'precision': 0.08108108108108109,\n", + " 'recall': 0.05056179775280899,\n", + " 'f1': 0.06228373702422144,\n", + " 'number': 178},\n", + " 'Pregnancy_considerations': {'precision': 0.0,\n", + " 'recall': 0.0,\n", + " 'f1': 0.0,\n", + " 'number': 19},\n", + " 'Procedure': {'precision': 0.3416149068322981,\n", + " 'recall': 0.3448275862068966,\n", + " 'f1': 0.343213728549142,\n", + " 'number': 319},\n", + " 'Temporal': {'precision': 0.03225806451612903,\n", + " 'recall': 0.008771929824561403,\n", + " 'f1': 0.01379310344827586,\n", + " 'number': 228},\n", + " 'Value': {'precision': 0.22857142857142856,\n", + " 'recall': 0.021164021164021163,\n", + " 'f1': 0.0387409200968523,\n", + " 'number': 378},\n", + " 'overall_precision': 0.4783097686375321,\n", + " 'overall_recall': 0.5020236087689713,\n", + " 'overall_f1': 0.4898798749382919,\n", + " 'overall_accuracy': 0.5684853881648183}" + ] + }, + "execution_count": 675, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "compute_metrics((mistral_ann_dataset['labels'], true_ann_dataset['labels']))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Kernel-env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}