[357738]: / LLM-Zero-shot_approach / evaluation.ipynb

Download this file

1062 lines (1061 with data), 29.2 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Evaluation of the Zero-shot approach for NER**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 641,
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "import os\n",
    "import torch\n",
    "from transformers import AutoTokenizer\n",
    "from datasets import Dataset, DatasetDict, load_metric\n",
    "import pandas as pd\n",
    "import evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 642,
   "metadata": {},
   "outputs": [],
   "source": [
    "# paths\n",
    "root_path = \"..\"\n",
    "# root_path = \"./drive/MyDrive/HandsOnNLP\" # for google colab\n",
    "data_path = f'{root_path}/data'\n",
    "annotations_path = f'{data_path}/Annotations_Mistral_Prompt_2'\n",
    "chia_bio_path = f'{data_path}/chia/chia_bio'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 643,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "200"
      ]
     },
     "execution_count": 643,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ann_files = os.listdir(annotations_path)\n",
    "len(ann_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 644,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_ann = {} # list with real annotations\n",
    "mistral_ann = {} # list with mistral annotations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 645,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dict for the entities (entity to int value)\n",
    "simple_ent = {\"Condition\", \"Value\", \"Drug\", \"Procedure\", \"Measurement\", \"Temporal\", \"Observation\", \"Person\", \"Mood\", \"Pregnancy_considerations\", \"Device\"}\n",
    "sel_ent = {\n",
    "    \"O\": 0,\n",
    "    \"B-Condition\": 1,\n",
    "    \"I-Condition\": 2,\n",
    "    \"B-Value\": 3,\n",
    "    \"I-Value\": 4,\n",
    "    \"B-Drug\": 5,\n",
    "    \"I-Drug\": 6,\n",
    "    \"B-Procedure\": 7,\n",
    "    \"I-Procedure\": 8,\n",
    "    \"B-Measurement\": 9,\n",
    "    \"I-Measurement\": 10,\n",
    "    \"B-Temporal\": 11,\n",
    "    \"I-Temporal\": 12,\n",
    "    \"B-Observation\": 13,\n",
    "    \"I-Observation\": 14,\n",
    "    \"B-Person\": 15,\n",
    "    \"I-Person\": 16,\n",
    "    \"B-Mood\": 17,\n",
    "    \"I-Mood\": 18,\n",
    "    \"B-Pregnancy_considerations\": 19,\n",
    "    \"I-Pregnancy_considerations\": 20,\n",
    "    \"B-Device\": 21,\n",
    "    \"I-Device\": 22\n",
    "}\n",
    "\n",
    "entities_list = list(sel_ent.keys())\n",
    "sel_ent_inv = {v: k for k, v in sel_ent.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 646,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 647,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_ann2bio(sentence, pattern, pattern1, pattern2):\n",
    "    if sentence[-1] == \"\\n\":\n",
    "        sentence = sentence[:-2] # remove the \\n and a final point wrongly added\n",
    "    else:\n",
    "        sentence = sentence[:-1] # remove the final point wrongly added\n",
    "    \n",
    "    # find the entities\n",
    "    occurrences = re.finditer(pattern, sentence)\n",
    "    indexes = [(match.start(), match.end()) for match in occurrences]\n",
    "\n",
    "    annotation = []\n",
    "    i = 0\n",
    "    # create the bio list\n",
    "    for beg, end in indexes:\n",
    "        if beg > i:\n",
    "            annotation.extend([(word, \"O\") for word in sentence[i:beg].split()])\n",
    "        entity = sentence[beg:end]\n",
    "        entity_name = re.search(pattern1, entity).group(1)\n",
    "        entity = entity.replace(f'<{entity_name}>', \"\").replace(f'</{entity_name}>', \"\")\n",
    "        split_entity = entity.split()\n",
    "        annotation.append((split_entity[0], \"B-\" + entity_name))\n",
    "        annotation.extend([(word, \"I-\" + entity_name) for word in split_entity[1:]])\n",
    "        i = end\n",
    "    annotation.extend([(word, \"O\") for word in sentence[i:].split()])\n",
    "\n",
    "    # check punctuation sign in tokens and put them as individual tokens\n",
    "    ps = r'(\\.|\\,|\\:|\\;|\\!|\\?|\\-|\\(|\\)|\\[|\\]|\\{|\\}|\\\")'\n",
    "    new_annotation = []\n",
    "    for i,(word, tag) in enumerate(annotation):\n",
    "        if re.search(ps, word):\n",
    "            # find the ocurrences of the punctuation signs\n",
    "            occurrences = re.finditer(ps, word)\n",
    "            indexes = [(match.start(), match.end()) for match in occurrences]\n",
    "            # create the new tokens\n",
    "            last = 0\n",
    "            for j, (beg, end) in enumerate(indexes):\n",
    "                if beg > last:\n",
    "                    new_annotation.append((word[last:beg], tag))\n",
    "                if tag != \"O\":\n",
    "                    label = f'I-{tag.split(\"-\")[1]}'\n",
    "                else:\n",
    "                    label = \"O\"\n",
    "                if end < len(word) or (i < len(annotation) - 1 and annotation[i+1][1] == label):\n",
    "                    new_annotation.append((word[beg:end], label))\n",
    "                else:\n",
    "                    new_annotation.append((word[beg:end], 'O')) \n",
    "                last = end\n",
    "            if last < len(word):\n",
    "                new_annotation.append((word[last:], label))   \n",
    "                \n",
    "        else:\n",
    "            new_annotation.append((word, tag))\n",
    "\n",
    "    \n",
    "    return new_annotation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 648,
   "metadata": {},
   "outputs": [],
   "source": [
    "pattern1 = r'<(Person|Condition|Value|Drug|Procedure|Measurement|Temporal|Observation|Mood|Pregnancy_considerations|Device)>'\n",
    "pattern2 = r'</(Person|Condition|Value|Drug|Procedure|Measurement|Temporal|Observation|Mood|Pregnancy_considerations|Device)>'\n",
    "pattern = f'{pattern1}.*?{pattern2}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 649,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "200"
      ]
     },
     "execution_count": 649,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# get BIO annotations for mistral outputs\n",
    "for file in ann_files:\n",
    "    mistral_ann[file] = []\n",
    "    with open(f\"{annotations_path}/{file}\", \"r\") as f:\n",
    "        sentences = [line for line in f.readlines() if line != \"\\n\" and line != \" \\n\" and line != '']\n",
    "\n",
    "    for sentence in sentences:\n",
    "        mistral_ann[file].append(parse_ann2bio(sentence, pattern, pattern1, pattern2))\n",
    "len(mistral_ann)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 650,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('Severely', 'O'),\n",
       " ('to', 'O'),\n",
       " ('isolate', 'O'),\n",
       " ('for', 'O'),\n",
       " ('procedure', 'B-Procedure')]"
      ]
     },
     "execution_count": 650,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sent = \"Severely to isolate for <Procedure>procedure</Procedure>.\"\n",
    "parse_ann2bio(sent, pattern, pattern1, pattern2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 651,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "200"
      ]
     },
     "execution_count": 651,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# read real annotations from chia_bio\n",
    "for file in ann_files:\n",
    "    true_ann[file] = []\n",
    "    with open(f\"{chia_bio_path}/{file}\", \"r\") as fd:\n",
    "        sentences_ann = fd.read().split(\"\\n\\n\")\n",
    "    sentences_ann = [sentence for sentence in sentences_ann if sentence != \"\" and sentence != '\\n']\n",
    "    for sentence in sentences_ann:\n",
    "        true_ann[file].append(sentence)\n",
    "len(true_ann)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 652,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error in file NCT03132259_exc.bio.txt\n",
      "True: 12, Mistral: 0\n",
      "0.005\n"
     ]
    }
   ],
   "source": [
    "i = 0\n",
    "corrupted_files = []\n",
    "for file in ann_files:\n",
    "    if len(true_ann[file]) != len(mistral_ann[file]):\n",
    "        i += 1\n",
    "        print(f\"Error in file {file}\")\n",
    "        print(f\"True: {len(true_ann[file])}, Mistral: {len(mistral_ann[file])}\")\n",
    "        corrupted_files.append(file)\n",
    "print(i/len(ann_files))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 653,
   "metadata": {},
   "outputs": [],
   "source": [
    "# remove corructed file\n",
    "for file in corrupted_files:\n",
    "    del true_ann[file]\n",
    "    del mistral_ann[file]\n",
    "    ann_files.remove(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 654,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'NCT02322203_inc.bio.txt'"
      ]
     },
     "execution_count": 654,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ann_files[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 655,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "199"
      ]
     },
     "execution_count": 655,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "true_ann_aux = {}\n",
    "\n",
    "for file in ann_files:\n",
    "    true_ann_aux[file] = []\n",
    "    for i in range(len(true_ann[file])):\n",
    "        annotation = []\n",
    "        lines = true_ann[file][i].split(\"\\n\")\n",
    "        for line in lines:\n",
    "            if line != \"\":\n",
    "                spt_line = line.split()\n",
    "                annotation.append((spt_line[0], spt_line[-1]))\n",
    "        new_annotation = []\n",
    "        ps = r'(\\.|\\,|\\:|\\;|\\!|\\?|\\-|\\(|\\)|\\[|\\]|\\{|\\}|\\\")'\n",
    "        for i,(word, tag) in enumerate(annotation):\n",
    "            if re.search(ps, word):\n",
    "                # find the ocurrences of the punctuation signs\n",
    "                occurrences = re.finditer(ps, word)\n",
    "                indexes = [(match.start(), match.end()) for match in occurrences]\n",
    "                # create the new tokens\n",
    "                last = 0\n",
    "                for j, (beg, end) in enumerate(indexes):\n",
    "                    if beg > last:\n",
    "                        new_annotation.append((word[last:beg], tag))\n",
    "                    if tag != \"O\":\n",
    "                        label = f'I-{tag.split(\"-\")[1]}'\n",
    "                    else:\n",
    "                        label = \"O\"\n",
    "                    if end < len(word) or (i < len(annotation) - 1 and annotation[i+1][1] == label):\n",
    "                        new_annotation.append((word[beg:end], label))\n",
    "                    else:\n",
    "                        new_annotation.append((word[beg:end], 'O')) \n",
    "                    last = end\n",
    "                if last < len(word):\n",
    "                    new_annotation.append((word[last:], label))\n",
    "            else:\n",
    "                new_annotation.append((word, tag))\n",
    "        true_ann_aux[file].append(new_annotation)\n",
    "true_ann = true_ann_aux\n",
    "len(true_ann)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 656,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('Subject', 'O'),\n",
       " ('understands', 'O'),\n",
       " ('the', 'O'),\n",
       " ('investigational', 'O'),\n",
       " ('nature', 'O'),\n",
       " ('of', 'O'),\n",
       " ('the', 'O'),\n",
       " ('study', 'O'),\n",
       " ('and', 'O'),\n",
       " ('provides', 'O'),\n",
       " ('written', 'O'),\n",
       " (',', 'O'),\n",
       " ('informed', 'B-Mood'),\n",
       " ('consent', 'I-Mood'),\n",
       " ('.', 'O')]"
      ]
     },
     "execution_count": 656,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mistral_ann[ann_files[0]][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 658,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('Subject', 'O'),\n",
       " ('understands', 'O'),\n",
       " ('the', 'O'),\n",
       " ('investigational', 'O'),\n",
       " ('nature', 'O'),\n",
       " ('of', 'O'),\n",
       " ('the', 'O'),\n",
       " ('study', 'O'),\n",
       " ('and', 'O'),\n",
       " ('provides', 'O'),\n",
       " ('written', 'O'),\n",
       " (',', 'O'),\n",
       " ('informed', 'O'),\n",
       " ('consent', 'O'),\n",
       " ('.', 'O')]"
      ]
     },
     "execution_count": 658,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "true_ann[ann_files[0]][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 661,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1147"
      ]
     },
     "execution_count": 661,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mistral_ann_dict = []\n",
    "\n",
    "for file in ann_files:\n",
    "    for i in range(len(mistral_ann[file])):\n",
    "        dict_sent = {\"tokens\": [], \"ner_tags\": [], \"file\": file, \"index\": i}\n",
    "        for word, tag in mistral_ann[file][i]:\n",
    "            dict_sent[\"tokens\"].append(word)\n",
    "            # add the int representation of the entity\n",
    "            dict_sent[\"ner_tags\"].append(sel_ent[tag])\n",
    "        mistral_ann_dict.append(dict_sent)\n",
    "len(mistral_ann_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 662,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1147"
      ]
     },
     "execution_count": 662,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "true_ann_dict = []\n",
    "\n",
    "for file in ann_files:\n",
    "    for i in range(len(true_ann[file])):\n",
    "        dict_sent = {\"tokens\": [], \"ner_tags\": [], \"file\": file, \"index\": i}\n",
    "        for word, tag in true_ann[file][i]:\n",
    "            dict_sent[\"tokens\"].append(word)\n",
    "            # add the int representation of the entity\n",
    "            dict_sent[\"ner_tags\"].append(sel_ent[tag])\n",
    "        true_ann_dict.append(dict_sent)\n",
    "len(true_ann_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 663,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 664,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tokenize and align the labels in the dataset\n",
    "def tokenize_and_align_labels(sentence, flag = 'I'):\n",
    "    \"\"\"\n",
    "    Tokenize the sentence and align the labels\n",
    "    inputs:\n",
    "        sentence: dict, the sentence from the dataset\n",
    "        flag: str, the flag to indicate how to deal with the labels for subwords\n",
    "            - 'I': use the label of the first subword for all subwords but as intermediate (I-ENT)\n",
    "            - 'B': use the label of the first subword for all subwords as beginning (B-ENT)\n",
    "            - None: use -100 for subwords\n",
    "    outputs:\n",
    "        tokenized_sentence: dict, the tokenized sentence now with a field for the labels\n",
    "    \"\"\"\n",
    "    tokenized_sentence = tokenizer(sentence['tokens'], is_split_into_words=True, truncation=True)\n",
    "\n",
    "    labels = []\n",
    "    for i, labels_s in enumerate(sentence['ner_tags']):\n",
    "        word_ids = tokenized_sentence.word_ids(batch_index=i)\n",
    "        previous_word_idx = None\n",
    "        label_ids = []\n",
    "        for word_idx in word_ids:\n",
    "            # if the word_idx is None, assign -100\n",
    "            if word_idx is None:\n",
    "                label_ids.append(-100)\n",
    "            # if it is a new word, assign the corresponding label\n",
    "            elif word_idx != previous_word_idx:\n",
    "                label_ids.append(labels_s[word_idx])\n",
    "            # if it is the same word, check the flag to assign\n",
    "            else:\n",
    "                if flag == 'I':\n",
    "                    if entities_list[labels_s[word_idx]].startswith('I'):\n",
    "                      label_ids.append(labels_s[word_idx])\n",
    "                    else:\n",
    "                      label_ids.append(labels_s[word_idx] + 1)\n",
    "                elif flag == 'B':\n",
    "                    label_ids.append(labels_s[word_idx])\n",
    "                elif flag == None:\n",
    "                    label_ids.append(-100)\n",
    "            previous_word_idx = word_idx\n",
    "        labels.append(label_ids)\n",
    "    tokenized_sentence['labels'] = labels\n",
    "    return tokenized_sentence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 665,
   "metadata": {},
   "outputs": [],
   "source": [
    "mis_df = pd.DataFrame(mistral_ann_dict)\n",
    "true_df = pd.DataFrame(true_ann_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 666,
   "metadata": {},
   "outputs": [],
   "source": [
    "mistral_ann_dataset = Dataset.from_pandas(mis_df)\n",
    "true_ann_dataset = Dataset.from_pandas(true_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 667,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Dataset({\n",
       "     features: ['tokens', 'ner_tags', 'file', 'index'],\n",
       "     num_rows: 1147\n",
       " }),\n",
       " Dataset({\n",
       "     features: ['tokens', 'ner_tags', 'file', 'index'],\n",
       "     num_rows: 1147\n",
       " }))"
      ]
     },
     "execution_count": 667,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mistral_ann_dataset, true_ann_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 668,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 1147/1147 [00:00<00:00, 15088.51 examples/s]\n",
      "Map: 100%|██████████| 1147/1147 [00:00<00:00, 18242.88 examples/s]\n"
     ]
    }
   ],
   "source": [
    "mistral_ann_dataset = mistral_ann_dataset.map(tokenize_and_align_labels, batched=True)\n",
    "true_ann_dataset = true_ann_dataset.map(tokenize_and_align_labels, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 669,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Dataset({\n",
       "     features: ['tokens', 'ner_tags', 'file', 'index', 'input_ids', 'attention_mask', 'labels'],\n",
       "     num_rows: 1147\n",
       " }),\n",
       " Dataset({\n",
       "     features: ['tokens', 'ner_tags', 'file', 'index', 'input_ids', 'attention_mask', 'labels'],\n",
       "     num_rows: 1147\n",
       " }))"
      ]
     },
     "execution_count": 669,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mistral_ann_dataset, true_ann_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Evaluation of the annotations made by Mistral using seqeval**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 670,
   "metadata": {},
   "outputs": [],
   "source": [
    "seqeval = evaluate.load('seqeval')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 671,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(p):\n",
    "    \"\"\"\n",
    "    Compute the metrics for the model\n",
    "    inputs:\n",
    "        p: tuple, the predictions and the ground true\n",
    "    outputs:\n",
    "        dict: the metrics\n",
    "    \"\"\"\n",
    "    predictions, ground_true = p\n",
    "\n",
    "    # Remove ignored index (special tokens)\n",
    "    predictions_labels = []\n",
    "    true_labels = []\n",
    "\n",
    "    for preds, labels in zip(predictions, ground_true):\n",
    "        preds_labels = []\n",
    "        labels_true = []\n",
    "        for pred, label in zip(preds, labels):\n",
    "            if label != -100:\n",
    "                if pred == -100:\n",
    "                    pred = 0\n",
    "                preds_labels.append(entities_list[pred])\n",
    "                labels_true.append(entities_list[label])\n",
    "        predictions_labels.append(preds_labels) \n",
    "        true_labels.append(labels_true)\n",
    "\n",
    "    # predictions_labels = [\n",
    "    #     [entities_list[p] for (p, l) in zip(prediction, ground_true) if l != -100]\n",
    "    #     for prediction, label in zip(predictions, ground_true)\n",
    "    # ]\n",
    "    # true_labels = [\n",
    "    #     [entities_list[l] for (p, l) in zip(prediction, ground_true) if l != -100]\n",
    "    #     for prediction, label in zip(predictions, ground_true)\n",
    "    # ]\n",
    "    # print(predictions_labels[0])\n",
    "    # print(true_labels[0])\n",
    "\n",
    "    results = seqeval.compute(predictions=predictions_labels, references=true_labels)\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 672,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NCT00806273_exc.bio.txt 1\n",
      "NCT00806273_exc.bio.txt\n"
     ]
    }
   ],
   "source": [
    "print(mistral_ann_dataset['file'][14], mistral_ann_dataset['index'][14])\n",
    "print(true_ann_dataset['file'][14])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 680,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Men', 'and', 'women', '<Age>35', 'to', '70', 'years', 'of', 'age</Age>']\n",
      "['Men', 'and', 'women', '35', 'to', '70', 'years', 'of', 'age']\n"
     ]
    }
   ],
   "source": [
    "print(mistral_ann_dataset['tokens'][1140])\n",
    "print(true_ann_dataset['tokens'][1140])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 674,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "14\n",
      "15\n",
      "24\n",
      "33\n",
      "37\n",
      "57\n",
      "71\n",
      "72\n",
      "90\n",
      "92\n",
      "101\n",
      "102\n",
      "109\n",
      "110\n",
      "111\n",
      "127\n",
      "142\n",
      "155\n",
      "162\n",
      "169\n",
      "176\n",
      "188\n",
      "194\n",
      "201\n",
      "205\n",
      "209\n",
      "210\n",
      "226\n",
      "229\n",
      "230\n",
      "231\n",
      "232\n",
      "233\n",
      "241\n",
      "244\n",
      "251\n",
      "271\n",
      "273\n",
      "274\n",
      "307\n",
      "310\n",
      "320\n",
      "324\n",
      "332\n",
      "337\n",
      "339\n",
      "349\n",
      "350\n",
      "363\n",
      "369\n",
      "372\n",
      "383\n",
      "403\n",
      "404\n",
      "408\n",
      "412\n",
      "417\n",
      "421\n",
      "432\n",
      "447\n",
      "448\n",
      "463\n",
      "464\n",
      "465\n",
      "470\n",
      "475\n",
      "487\n",
      "491\n",
      "500\n",
      "512\n",
      "513\n",
      "530\n",
      "536\n",
      "548\n",
      "567\n",
      "573\n",
      "578\n",
      "581\n",
      "596\n",
      "600\n",
      "611\n",
      "623\n",
      "631\n",
      "633\n",
      "646\n",
      "654\n",
      "655\n",
      "656\n",
      "671\n",
      "681\n",
      "686\n",
      "694\n",
      "707\n",
      "708\n",
      "719\n",
      "725\n",
      "727\n",
      "731\n",
      "742\n",
      "745\n",
      "756\n",
      "757\n",
      "759\n",
      "763\n",
      "767\n",
      "769\n",
      "777\n",
      "780\n",
      "782\n",
      "783\n",
      "788\n",
      "789\n",
      "794\n",
      "795\n",
      "800\n",
      "801\n",
      "805\n",
      "806\n",
      "809\n",
      "816\n",
      "817\n",
      "820\n",
      "822\n",
      "823\n",
      "828\n",
      "831\n",
      "837\n",
      "855\n",
      "869\n",
      "873\n",
      "877\n",
      "903\n",
      "912\n",
      "924\n",
      "943\n",
      "945\n",
      "959\n",
      "966\n",
      "972\n",
      "974\n",
      "977\n",
      "980\n",
      "981\n",
      "983\n",
      "994\n",
      "998\n",
      "999\n",
      "1008\n",
      "1010\n",
      "1012\n",
      "1013\n",
      "1015\n",
      "1016\n",
      "1022\n",
      "1026\n",
      "1035\n",
      "1047\n",
      "1052\n",
      "1060\n",
      "1063\n",
      "1066\n",
      "1068\n",
      "1076\n",
      "1077\n",
      "1082\n",
      "1090\n",
      "1093\n",
      "1094\n",
      "1099\n",
      "1100\n",
      "1111\n",
      "1113\n",
      "1119\n",
      "1136\n",
      "1140\n"
     ]
    }
   ],
   "source": [
    "for i in range(len(mistral_ann_dataset)):\n",
    "    if len(mistral_ann_dataset['labels'][i]) != len(true_ann_dataset['labels'][i]):\n",
    "        print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 675,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/miniconda3/envs/TER/lib/python3.10/site-packages/seqeval/metrics/v1.py:57: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'Condition': {'precision': 0.5477680433875678,\n",
       "  'recall': 0.663466397170288,\n",
       "  'f1': 0.6000914076782451,\n",
       "  'number': 3958},\n",
       " 'Device': {'precision': 0.2777777777777778,\n",
       "  'recall': 0.15151515151515152,\n",
       "  'f1': 0.19607843137254904,\n",
       "  'number': 33},\n",
       " 'Drug': {'precision': 0.5060975609756098,\n",
       "  'recall': 0.5684931506849316,\n",
       "  'f1': 0.535483870967742,\n",
       "  'number': 292},\n",
       " 'Measurement': {'precision': 0.13314447592067988,\n",
       "  'recall': 0.14968152866242038,\n",
       "  'f1': 0.1409295352323838,\n",
       "  'number': 314},\n",
       " 'Mood': {'precision': 0.00684931506849315,\n",
       "  'recall': 0.01694915254237288,\n",
       "  'f1': 0.00975609756097561,\n",
       "  'number': 59},\n",
       " 'Observation': {'precision': 0.05454545454545454,\n",
       "  'recall': 0.019736842105263157,\n",
       "  'f1': 0.028985507246376812,\n",
       "  'number': 152},\n",
       " 'Person': {'precision': 0.08108108108108109,\n",
       "  'recall': 0.05056179775280899,\n",
       "  'f1': 0.06228373702422144,\n",
       "  'number': 178},\n",
       " 'Pregnancy_considerations': {'precision': 0.0,\n",
       "  'recall': 0.0,\n",
       "  'f1': 0.0,\n",
       "  'number': 19},\n",
       " 'Procedure': {'precision': 0.3416149068322981,\n",
       "  'recall': 0.3448275862068966,\n",
       "  'f1': 0.343213728549142,\n",
       "  'number': 319},\n",
       " 'Temporal': {'precision': 0.03225806451612903,\n",
       "  'recall': 0.008771929824561403,\n",
       "  'f1': 0.01379310344827586,\n",
       "  'number': 228},\n",
       " 'Value': {'precision': 0.22857142857142856,\n",
       "  'recall': 0.021164021164021163,\n",
       "  'f1': 0.0387409200968523,\n",
       "  'number': 378},\n",
       " 'overall_precision': 0.4783097686375321,\n",
       " 'overall_recall': 0.5020236087689713,\n",
       " 'overall_f1': 0.4898798749382919,\n",
       " 'overall_accuracy': 0.5684853881648183}"
      ]
     },
     "execution_count": 675,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "compute_metrics((mistral_ann_dataset['labels'], true_ann_dataset['labels']))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Kernel-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}