{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Evaluation of the Zero-shot approach for NER**"
]
},
{
"cell_type": "code",
"execution_count": 641,
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"import os\n",
"import torch\n",
"from transformers import AutoTokenizer\n",
"from datasets import Dataset, DatasetDict, load_metric\n",
"import pandas as pd\n",
"import evaluate"
]
},
{
"cell_type": "code",
"execution_count": 642,
"metadata": {},
"outputs": [],
"source": [
"# paths\n",
"root_path = \"..\"\n",
"# root_path = \"./drive/MyDrive/HandsOnNLP\" # for google colab\n",
"data_path = f'{root_path}/data'\n",
"annotations_path = f'{data_path}/Annotations_Mistral_Prompt_2'\n",
"chia_bio_path = f'{data_path}/chia/chia_bio'"
]
},
{
"cell_type": "code",
"execution_count": 643,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"200"
]
},
"execution_count": 643,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ann_files = os.listdir(annotations_path)\n",
"len(ann_files)"
]
},
{
"cell_type": "code",
"execution_count": 644,
"metadata": {},
"outputs": [],
"source": [
"true_ann = {} # list with real annotations\n",
"mistral_ann = {} # list with mistral annotations"
]
},
{
"cell_type": "code",
"execution_count": 645,
"metadata": {},
"outputs": [],
"source": [
"# dict for the entities (entity to int value)\n",
"simple_ent = {\"Condition\", \"Value\", \"Drug\", \"Procedure\", \"Measurement\", \"Temporal\", \"Observation\", \"Person\", \"Mood\", \"Pregnancy_considerations\", \"Device\"}\n",
"sel_ent = {\n",
" \"O\": 0,\n",
" \"B-Condition\": 1,\n",
" \"I-Condition\": 2,\n",
" \"B-Value\": 3,\n",
" \"I-Value\": 4,\n",
" \"B-Drug\": 5,\n",
" \"I-Drug\": 6,\n",
" \"B-Procedure\": 7,\n",
" \"I-Procedure\": 8,\n",
" \"B-Measurement\": 9,\n",
" \"I-Measurement\": 10,\n",
" \"B-Temporal\": 11,\n",
" \"I-Temporal\": 12,\n",
" \"B-Observation\": 13,\n",
" \"I-Observation\": 14,\n",
" \"B-Person\": 15,\n",
" \"I-Person\": 16,\n",
" \"B-Mood\": 17,\n",
" \"I-Mood\": 18,\n",
" \"B-Pregnancy_considerations\": 19,\n",
" \"I-Pregnancy_considerations\": 20,\n",
" \"B-Device\": 21,\n",
" \"I-Device\": 22\n",
"}\n",
"\n",
"entities_list = list(sel_ent.keys())\n",
"sel_ent_inv = {v: k for k, v in sel_ent.items()}"
]
},
{
"cell_type": "code",
"execution_count": 647,
"metadata": {},
"outputs": [],
"source": [
"def parse_ann2bio(sentence, pattern, pattern1, pattern2):\n",
" if sentence[-1] == \"\\n\":\n",
" sentence = sentence[:-2] # remove the \\n and a final point wrongly added\n",
" else:\n",
" sentence = sentence[:-1] # remove the final point wrongly added\n",
" \n",
" # find the entities\n",
" occurrences = re.finditer(pattern, sentence)\n",
" indexes = [(match.start(), match.end()) for match in occurrences]\n",
"\n",
" annotation = []\n",
" i = 0\n",
" # create the bio list\n",
" for beg, end in indexes:\n",
" if beg > i:\n",
" annotation.extend([(word, \"O\") for word in sentence[i:beg].split()])\n",
" entity = sentence[beg:end]\n",
" entity_name = re.search(pattern1, entity).group(1)\n",
" entity = entity.replace(f'<{entity_name}>', \"\").replace(f'</{entity_name}>', \"\")\n",
" split_entity = entity.split()\n",
" annotation.append((split_entity[0], \"B-\" + entity_name))\n",
" annotation.extend([(word, \"I-\" + entity_name) for word in split_entity[1:]])\n",
" i = end\n",
" annotation.extend([(word, \"O\") for word in sentence[i:].split()])\n",
"\n",
" # check punctuation sign in tokens and put them as individual tokens\n",
" ps = r'(\\.|\\,|\\:|\\;|\\!|\\?|\\-|\\(|\\)|\\[|\\]|\\{|\\}|\\\")'\n",
" new_annotation = []\n",
" for i,(word, tag) in enumerate(annotation):\n",
" if re.search(ps, word):\n",
" # find the ocurrences of the punctuation signs\n",
" occurrences = re.finditer(ps, word)\n",
" indexes = [(match.start(), match.end()) for match in occurrences]\n",
" # create the new tokens\n",
" last = 0\n",
" for j, (beg, end) in enumerate(indexes):\n",
" if beg > last:\n",
" new_annotation.append((word[last:beg], tag))\n",
" if tag != \"O\":\n",
" label = f'I-{tag.split(\"-\")[1]}'\n",
" else:\n",
" label = \"O\"\n",
" if end < len(word) or (i < len(annotation) - 1 and annotation[i+1][1] == label):\n",
" new_annotation.append((word[beg:end], label))\n",
" else:\n",
" new_annotation.append((word[beg:end], 'O')) \n",
" last = end\n",
" if last < len(word):\n",
" new_annotation.append((word[last:], label)) \n",
" \n",
" else:\n",
" new_annotation.append((word, tag))\n",
"\n",
" \n",
" return new_annotation"
]
},
{
"cell_type": "code",
"execution_count": 648,
"metadata": {},
"outputs": [],
"source": [
"pattern1 = r'<(Person|Condition|Value|Drug|Procedure|Measurement|Temporal|Observation|Mood|Pregnancy_considerations|Device)>'\n",
"pattern2 = r'</(Person|Condition|Value|Drug|Procedure|Measurement|Temporal|Observation|Mood|Pregnancy_considerations|Device)>'\n",
"pattern = f'{pattern1}.*?{pattern2}'"
]
},
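{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small hypothetical check of the regexes (not part of the original pipeline): `pattern` matches a whole `<Tag>...</Tag>` span non-greedily, so adjacent entities are matched separately rather than swallowed into one span."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# made-up sentence with two annotated drugs; the non-greedy '.*?' keeps the matches separate\n",
"example = \"<Drug>aspirin</Drug> and <Drug>ibuprofen</Drug>\"\n",
"[m.group(0) for m in re.finditer(pattern, example)]"
]
},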
{
"cell_type": "code",
"execution_count": 649,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"200"
]
},
"execution_count": 649,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# get BIO annotations for mistral outputs\n",
"for file in ann_files:\n",
" mistral_ann[file] = []\n",
" with open(f\"{annotations_path}/{file}\", \"r\") as f:\n",
" sentences = [line for line in f.readlines() if line != \"\\n\" and line != \" \\n\" and line != '']\n",
"\n",
" for sentence in sentences:\n",
" mistral_ann[file].append(parse_ann2bio(sentence, pattern, pattern1, pattern2))\n",
"len(mistral_ann)"
]
},
{
"cell_type": "code",
"execution_count": 650,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Severely', 'O'),\n",
" ('to', 'O'),\n",
" ('isolate', 'O'),\n",
" ('for', 'O'),\n",
" ('procedure', 'B-Procedure')]"
]
},
"execution_count": 650,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sent = \"Severely to isolate for <Procedure>procedure</Procedure>.\"\n",
"parse_ann2bio(sent, pattern, pattern1, pattern2)"
]
},
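{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another hedged sanity check on a made-up sentence (not from the data): this one exercises the punctuation-splitting branch of `parse_ann2bio`, where brackets and periods attached to a token become tokens of their own."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hypothetical input to inspect how attached punctuation is split off\n",
"sent = \"Patients with <Condition>type 2 diabetes</Condition> (uncontrolled).\"\n",
"parse_ann2bio(sent, pattern, pattern1)"
]
},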
{
"cell_type": "code",
"execution_count": 651,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"200"
]
},
"execution_count": 651,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# read real annotations from chia_bio\n",
"for file in ann_files:\n",
" true_ann[file] = []\n",
" with open(f\"{chia_bio_path}/{file}\", \"r\") as fd:\n",
" sentences_ann = fd.read().split(\"\\n\\n\")\n",
" sentences_ann = [sentence for sentence in sentences_ann if sentence != \"\" and sentence != '\\n']\n",
" for sentence in sentences_ann:\n",
" true_ann[file].append(sentence)\n",
"len(true_ann)"
]
},
{
"cell_type": "code",
"execution_count": 652,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Error in file NCT03132259_exc.bio.txt\n",
"True: 12, Mistral: 0\n",
"0.005\n"
]
}
],
"source": [
"i = 0\n",
"corrupted_files = []\n",
"for file in ann_files:\n",
" if len(true_ann[file]) != len(mistral_ann[file]):\n",
" i += 1\n",
" print(f\"Error in file {file}\")\n",
" print(f\"True: {len(true_ann[file])}, Mistral: {len(mistral_ann[file])}\")\n",
" corrupted_files.append(file)\n",
"print(i/len(ann_files))"
]
},
{
"cell_type": "code",
"execution_count": 653,
"metadata": {},
"outputs": [],
"source": [
"# remove corructed file\n",
"for file in corrupted_files:\n",
" del true_ann[file]\n",
" del mistral_ann[file]\n",
" ann_files.remove(file)"
]
},
{
"cell_type": "code",
"execution_count": 654,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'NCT02322203_inc.bio.txt'"
]
},
"execution_count": 654,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ann_files[0]"
]
},
{
"cell_type": "code",
"execution_count": 655,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"199"
]
},
"execution_count": 655,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"true_ann_aux = {}\n",
"\n",
"for file in ann_files:\n",
" true_ann_aux[file] = []\n",
" for i in range(len(true_ann[file])):\n",
" annotation = []\n",
" lines = true_ann[file][i].split(\"\\n\")\n",
" for line in lines:\n",
" if line != \"\":\n",
" spt_line = line.split()\n",
" annotation.append((spt_line[0], spt_line[-1]))\n",
" new_annotation = []\n",
" ps = r'(\\.|\\,|\\:|\\;|\\!|\\?|\\-|\\(|\\)|\\[|\\]|\\{|\\}|\\\")'\n",
" for i,(word, tag) in enumerate(annotation):\n",
" if re.search(ps, word):\n",
" # find the ocurrences of the punctuation signs\n",
" occurrences = re.finditer(ps, word)\n",
" indexes = [(match.start(), match.end()) for match in occurrences]\n",
" # create the new tokens\n",
" last = 0\n",
" for j, (beg, end) in enumerate(indexes):\n",
" if beg > last:\n",
" new_annotation.append((word[last:beg], tag))\n",
" if tag != \"O\":\n",
" label = f'I-{tag.split(\"-\")[1]}'\n",
" else:\n",
" label = \"O\"\n",
" if end < len(word) or (i < len(annotation) - 1 and annotation[i+1][1] == label):\n",
" new_annotation.append((word[beg:end], label))\n",
" else:\n",
" new_annotation.append((word[beg:end], 'O')) \n",
" last = end\n",
" if last < len(word):\n",
" new_annotation.append((word[last:], label))\n",
" else:\n",
" new_annotation.append((word, tag))\n",
" true_ann_aux[file].append(new_annotation)\n",
"true_ann = true_ann_aux\n",
"len(true_ann)"
]
},
{
"cell_type": "code",
"execution_count": 656,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Subject', 'O'),\n",
" ('understands', 'O'),\n",
" ('the', 'O'),\n",
" ('investigational', 'O'),\n",
" ('nature', 'O'),\n",
" ('of', 'O'),\n",
" ('the', 'O'),\n",
" ('study', 'O'),\n",
" ('and', 'O'),\n",
" ('provides', 'O'),\n",
" ('written', 'O'),\n",
" (',', 'O'),\n",
" ('informed', 'B-Mood'),\n",
" ('consent', 'I-Mood'),\n",
" ('.', 'O')]"
]
},
"execution_count": 656,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mistral_ann[ann_files[0]][1]"
]
},
{
"cell_type": "code",
"execution_count": 658,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Subject', 'O'),\n",
" ('understands', 'O'),\n",
" ('the', 'O'),\n",
" ('investigational', 'O'),\n",
" ('nature', 'O'),\n",
" ('of', 'O'),\n",
" ('the', 'O'),\n",
" ('study', 'O'),\n",
" ('and', 'O'),\n",
" ('provides', 'O'),\n",
" ('written', 'O'),\n",
" (',', 'O'),\n",
" ('informed', 'O'),\n",
" ('consent', 'O'),\n",
" ('.', 'O')]"
]
},
"execution_count": 658,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"true_ann[ann_files[0]][1]"
]
},
{
"cell_type": "code",
"execution_count": 661,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1147"
]
},
"execution_count": 661,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mistral_ann_dict = []\n",
"\n",
"for file in ann_files:\n",
" for i in range(len(mistral_ann[file])):\n",
" dict_sent = {\"tokens\": [], \"ner_tags\": [], \"file\": file, \"index\": i}\n",
" for word, tag in mistral_ann[file][i]:\n",
" dict_sent[\"tokens\"].append(word)\n",
" # add the int representation of the entity\n",
" dict_sent[\"ner_tags\"].append(sel_ent[tag])\n",
" mistral_ann_dict.append(dict_sent)\n",
"len(mistral_ann_dict)"
]
},
{
"cell_type": "code",
"execution_count": 662,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1147"
]
},
"execution_count": 662,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"true_ann_dict = []\n",
"\n",
"for file in ann_files:\n",
" for i in range(len(true_ann[file])):\n",
" dict_sent = {\"tokens\": [], \"ner_tags\": [], \"file\": file, \"index\": i}\n",
" for word, tag in true_ann[file][i]:\n",
" dict_sent[\"tokens\"].append(word)\n",
" # add the int representation of the entity\n",
" dict_sent[\"ner_tags\"].append(sel_ent[tag])\n",
" true_ann_dict.append(dict_sent)\n",
"len(true_ann_dict)"
]
},
{
"cell_type": "code",
"execution_count": 663,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')"
]
},
{
"cell_type": "code",
"execution_count": 664,
"metadata": {},
"outputs": [],
"source": [
"# tokenize and align the labels in the dataset\n",
"def tokenize_and_align_labels(sentence, flag = 'I'):\n",
" \"\"\"\n",
" Tokenize the sentence and align the labels\n",
" inputs:\n",
" sentence: dict, the sentence from the dataset\n",
" flag: str, the flag to indicate how to deal with the labels for subwords\n",
" - 'I': use the label of the first subword for all subwords but as intermediate (I-ENT)\n",
" - 'B': use the label of the first subword for all subwords as beginning (B-ENT)\n",
" - None: use -100 for subwords\n",
" outputs:\n",
" tokenized_sentence: dict, the tokenized sentence now with a field for the labels\n",
" \"\"\"\n",
" tokenized_sentence = tokenizer(sentence['tokens'], is_split_into_words=True, truncation=True)\n",
"\n",
" labels = []\n",
" for i, labels_s in enumerate(sentence['ner_tags']):\n",
" word_ids = tokenized_sentence.word_ids(batch_index=i)\n",
" previous_word_idx = None\n",
" label_ids = []\n",
" for word_idx in word_ids:\n",
" # if the word_idx is None, assign -100\n",
" if word_idx is None:\n",
" label_ids.append(-100)\n",
" # if it is a new word, assign the corresponding label\n",
" elif word_idx != previous_word_idx:\n",
" label_ids.append(labels_s[word_idx])\n",
" # if it is the same word, check the flag to assign\n",
" else:\n",
" if flag == 'I':\n",
" if entities_list[labels_s[word_idx]].startswith('I'):\n",
" label_ids.append(labels_s[word_idx])\n",
" else:\n",
" label_ids.append(labels_s[word_idx] + 1)\n",
" elif flag == 'B':\n",
" label_ids.append(labels_s[word_idx])\n",
" elif flag == None:\n",
" label_ids.append(-100)\n",
" previous_word_idx = word_idx\n",
" labels.append(label_ids)\n",
" tokenized_sentence['labels'] = labels\n",
" return tokenized_sentence"
]
},
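{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick illustrative run on a toy batch (hypothetical input, not part of the pipeline), assuming the long medical term below is split into several subwords by the XLM-R tokenizer: print each subword next to its aligned label to see the `-100` on special tokens and the B- to I- promotion on trailing subwords."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# one toy sentence: 'hypercholesterolemia treated' labelled B-Condition / O\n",
"toy = {\"tokens\": [[\"hypercholesterolemia\", \"treated\"]], \"ner_tags\": [[1, 0]]}\n",
"out = tokenize_and_align_labels(toy)\n",
"for tok, lab in zip(tokenizer.convert_ids_to_tokens(out[\"input_ids\"][0]), out[\"labels\"][0]):\n",
"    print(tok, lab if lab == -100 else entities_list[lab])"
]
},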
{
"cell_type": "code",
"execution_count": 665,
"metadata": {},
"outputs": [],
"source": [
"mis_df = pd.DataFrame(mistral_ann_dict)\n",
"true_df = pd.DataFrame(true_ann_dict)"
]
},
{
"cell_type": "code",
"execution_count": 666,
"metadata": {},
"outputs": [],
"source": [
"mistral_ann_dataset = Dataset.from_pandas(mis_df)\n",
"true_ann_dataset = Dataset.from_pandas(true_df)"
]
},
{
"cell_type": "code",
"execution_count": 667,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Dataset({\n",
" features: ['tokens', 'ner_tags', 'file', 'index'],\n",
" num_rows: 1147\n",
" }),\n",
" Dataset({\n",
" features: ['tokens', 'ner_tags', 'file', 'index'],\n",
" num_rows: 1147\n",
" }))"
]
},
"execution_count": 667,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mistral_ann_dataset, true_ann_dataset"
]
},
{
"cell_type": "code",
"execution_count": 668,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map: 100%|██████████| 1147/1147 [00:00<00:00, 15088.51 examples/s]\n",
"Map: 100%|██████████| 1147/1147 [00:00<00:00, 18242.88 examples/s]\n"
]
}
],
"source": [
"mistral_ann_dataset = mistral_ann_dataset.map(tokenize_and_align_labels, batched=True)\n",
"true_ann_dataset = true_ann_dataset.map(tokenize_and_align_labels, batched=True)"
]
},
{
"cell_type": "code",
"execution_count": 669,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Dataset({\n",
" features: ['tokens', 'ner_tags', 'file', 'index', 'input_ids', 'attention_mask', 'labels'],\n",
" num_rows: 1147\n",
" }),\n",
" Dataset({\n",
" features: ['tokens', 'ner_tags', 'file', 'index', 'input_ids', 'attention_mask', 'labels'],\n",
" num_rows: 1147\n",
" }))"
]
},
"execution_count": 669,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mistral_ann_dataset, true_ann_dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Evaluation of the annotations made by Mistral using seqeval**"
]
},
{
"cell_type": "code",
"execution_count": 670,
"metadata": {},
"outputs": [],
"source": [
"seqeval = evaluate.load('seqeval')"
]
},
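{
"cell_type": "markdown",
"metadata": {},
"source": [
"`seqeval` scores at the entity level, not the token level: a predicted span counts as a true positive only if both its boundaries and its type match a reference span exactly. A minimal hypothetical example (not data from this notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the two-token predicted span overlaps the one-token reference span but is not an exact match,\n",
"# so it counts as both a false positive and a false negative for 'Drug'\n",
"toy_preds = [[\"B-Drug\", \"I-Drug\", \"O\"]]\n",
"toy_refs = [[\"B-Drug\", \"O\", \"O\"]]\n",
"seqeval.compute(predictions=toy_preds, references=toy_refs)"
]
},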
{
"cell_type": "code",
"execution_count": 671,
"metadata": {},
"outputs": [],
"source": [
"def compute_metrics(p):\n",
" \"\"\"\n",
" Compute the metrics for the model\n",
" inputs:\n",
" p: tuple, the predictions and the ground true\n",
" outputs:\n",
" dict: the metrics\n",
" \"\"\"\n",
" predictions, ground_true = p\n",
"\n",
" # Remove ignored index (special tokens)\n",
" predictions_labels = []\n",
" true_labels = []\n",
"\n",
" for preds, labels in zip(predictions, ground_true):\n",
" preds_labels = []\n",
" labels_true = []\n",
" for pred, label in zip(preds, labels):\n",
" if label != -100:\n",
" if pred == -100:\n",
" pred = 0\n",
" preds_labels.append(entities_list[pred])\n",
" labels_true.append(entities_list[label])\n",
" predictions_labels.append(preds_labels) \n",
" true_labels.append(labels_true)\n",
"\n",
" # predictions_labels = [\n",
" # [entities_list[p] for (p, l) in zip(prediction, ground_true) if l != -100]\n",
" # for prediction, label in zip(predictions, ground_true)\n",
" # ]\n",
" # true_labels = [\n",
" # [entities_list[l] for (p, l) in zip(prediction, ground_true) if l != -100]\n",
" # for prediction, label in zip(predictions, ground_true)\n",
" # ]\n",
" # print(predictions_labels[0])\n",
" # print(true_labels[0])\n",
"\n",
" results = seqeval.compute(predictions=predictions_labels, references=true_labels)\n",
" return results"
]
},
{
"cell_type": "code",
"execution_count": 672,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"NCT00806273_exc.bio.txt 1\n",
"NCT00806273_exc.bio.txt\n"
]
}
],
"source": [
"print(mistral_ann_dataset['file'][14], mistral_ann_dataset['index'][14])\n",
"print(true_ann_dataset['file'][14])"
]
},
{
"cell_type": "code",
"execution_count": 680,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Men', 'and', 'women', '<Age>35', 'to', '70', 'years', 'of', 'age</Age>']\n",
"['Men', 'and', 'women', '35', 'to', '70', 'years', 'of', 'age']\n"
]
}
],
"source": [
"print(mistral_ann_dataset['tokens'][1140])\n",
"print(true_ann_dataset['tokens'][1140])"
]
},
{
"cell_type": "code",
"execution_count": 674,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"14\n",
"15\n",
"24\n",
"33\n",
"37\n",
"57\n",
"71\n",
"72\n",
"90\n",
"92\n",
"101\n",
"102\n",
"109\n",
"110\n",
"111\n",
"127\n",
"142\n",
"155\n",
"162\n",
"169\n",
"176\n",
"188\n",
"194\n",
"201\n",
"205\n",
"209\n",
"210\n",
"226\n",
"229\n",
"230\n",
"231\n",
"232\n",
"233\n",
"241\n",
"244\n",
"251\n",
"271\n",
"273\n",
"274\n",
"307\n",
"310\n",
"320\n",
"324\n",
"332\n",
"337\n",
"339\n",
"349\n",
"350\n",
"363\n",
"369\n",
"372\n",
"383\n",
"403\n",
"404\n",
"408\n",
"412\n",
"417\n",
"421\n",
"432\n",
"447\n",
"448\n",
"463\n",
"464\n",
"465\n",
"470\n",
"475\n",
"487\n",
"491\n",
"500\n",
"512\n",
"513\n",
"530\n",
"536\n",
"548\n",
"567\n",
"573\n",
"578\n",
"581\n",
"596\n",
"600\n",
"611\n",
"623\n",
"631\n",
"633\n",
"646\n",
"654\n",
"655\n",
"656\n",
"671\n",
"681\n",
"686\n",
"694\n",
"707\n",
"708\n",
"719\n",
"725\n",
"727\n",
"731\n",
"742\n",
"745\n",
"756\n",
"757\n",
"759\n",
"763\n",
"767\n",
"769\n",
"777\n",
"780\n",
"782\n",
"783\n",
"788\n",
"789\n",
"794\n",
"795\n",
"800\n",
"801\n",
"805\n",
"806\n",
"809\n",
"816\n",
"817\n",
"820\n",
"822\n",
"823\n",
"828\n",
"831\n",
"837\n",
"855\n",
"869\n",
"873\n",
"877\n",
"903\n",
"912\n",
"924\n",
"943\n",
"945\n",
"959\n",
"966\n",
"972\n",
"974\n",
"977\n",
"980\n",
"981\n",
"983\n",
"994\n",
"998\n",
"999\n",
"1008\n",
"1010\n",
"1012\n",
"1013\n",
"1015\n",
"1016\n",
"1022\n",
"1026\n",
"1035\n",
"1047\n",
"1052\n",
"1060\n",
"1063\n",
"1066\n",
"1068\n",
"1076\n",
"1077\n",
"1082\n",
"1090\n",
"1093\n",
"1094\n",
"1099\n",
"1100\n",
"1111\n",
"1113\n",
"1119\n",
"1136\n",
"1140\n"
]
}
],
"source": [
"for i in range(len(mistral_ann_dataset)):\n",
" if len(mistral_ann_dataset['labels'][i]) != len(true_ann_dataset['labels'][i]):\n",
" print(i)"
]
},
{
"cell_type": "code",
"execution_count": 675,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/miniconda3/envs/TER/lib/python3.10/site-packages/seqeval/metrics/v1.py:57: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n"
]
},
{
"data": {
"text/plain": [
"{'Condition': {'precision': 0.5477680433875678,\n",
" 'recall': 0.663466397170288,\n",
" 'f1': 0.6000914076782451,\n",
" 'number': 3958},\n",
" 'Device': {'precision': 0.2777777777777778,\n",
" 'recall': 0.15151515151515152,\n",
" 'f1': 0.19607843137254904,\n",
" 'number': 33},\n",
" 'Drug': {'precision': 0.5060975609756098,\n",
" 'recall': 0.5684931506849316,\n",
" 'f1': 0.535483870967742,\n",
" 'number': 292},\n",
" 'Measurement': {'precision': 0.13314447592067988,\n",
" 'recall': 0.14968152866242038,\n",
" 'f1': 0.1409295352323838,\n",
" 'number': 314},\n",
" 'Mood': {'precision': 0.00684931506849315,\n",
" 'recall': 0.01694915254237288,\n",
" 'f1': 0.00975609756097561,\n",
" 'number': 59},\n",
" 'Observation': {'precision': 0.05454545454545454,\n",
" 'recall': 0.019736842105263157,\n",
" 'f1': 0.028985507246376812,\n",
" 'number': 152},\n",
" 'Person': {'precision': 0.08108108108108109,\n",
" 'recall': 0.05056179775280899,\n",
" 'f1': 0.06228373702422144,\n",
" 'number': 178},\n",
" 'Pregnancy_considerations': {'precision': 0.0,\n",
" 'recall': 0.0,\n",
" 'f1': 0.0,\n",
" 'number': 19},\n",
" 'Procedure': {'precision': 0.3416149068322981,\n",
" 'recall': 0.3448275862068966,\n",
" 'f1': 0.343213728549142,\n",
" 'number': 319},\n",
" 'Temporal': {'precision': 0.03225806451612903,\n",
" 'recall': 0.008771929824561403,\n",
" 'f1': 0.01379310344827586,\n",
" 'number': 228},\n",
" 'Value': {'precision': 0.22857142857142856,\n",
" 'recall': 0.021164021164021163,\n",
" 'f1': 0.0387409200968523,\n",
" 'number': 378},\n",
" 'overall_precision': 0.4783097686375321,\n",
" 'overall_recall': 0.5020236087689713,\n",
" 'overall_f1': 0.4898798749382919,\n",
" 'overall_accuracy': 0.5684853881648183}"
]
},
"execution_count": 675,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"compute_metrics((mistral_ann_dataset['labels'], true_ann_dataset['labels']))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Kernel-env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}