--- a
+++ b/Roberta+LLM/ensemble_model.ipynb
@@ -0,0 +1,459 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# uncomment if working in colab\n",
+    "from google.colab import drive\n",
+    "drive.mount('/content/drive')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# uncomment if using colab\n",
+    "!pip install -q -U git+https://github.com/huggingface/transformers.git\n",
+    "!pip install -q -U datasets\n",
+    "!pip install -q -U git+https://github.com/huggingface/accelerate.git\n",
+    "!pip install seqeval\n",
+    "!pip install -q -U evaluate"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "from transformers import AutoTokenizer, AutoModelForTokenClassification, DataCollatorForTokenClassification,  Trainer, TrainingArguments, AutoModelForTextGeneration\n",
+    "from datasets import load_dataset, load_metric\n",
+    "import evaluate\n",
+    "import torch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from huggingface_hub import notebook_login\n",
+    "\n",
+    "notebook_login()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# paths\n",
+    "# root = '..'\n",
+    "root = './drive/MyDrive/TER-LISN-2024'\n",
+    "data_path = f'{root}/data'\n",
+    "model_path = f'{root}/models'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# dict for the entities (entity to int value)\n",
+    "simple_ent = {\"Condition\", \"Value\", \"Drug\", \"Procedure\", \"Measurement\", \"Temporal\", \"Observation\", \"Person\", \"Device\"}\n",
+    "sel_ent = {\n",
+    "    \"O\": 0,\n",
+    "    \"B-Condition\": 1,\n",
+    "    \"I-Condition\": 2,\n",
+    "    \"B-Value\": 3,\n",
+    "    \"I-Value\": 4,\n",
+    "    \"B-Drug\": 5,\n",
+    "    \"I-Drug\": 6,\n",
+    "    \"B-Procedure\": 7,\n",
+    "    \"I-Procedure\": 8,\n",
+    "    \"B-Measurement\": 9,\n",
+    "    \"I-Measurement\": 10,\n",
+    "    \"B-Temporal\": 11,\n",
+    "    \"I-Temporal\": 12,\n",
+    "    \"B-Observation\": 13,\n",
+    "    \"I-Observation\": 14,\n",
+    "    \"B-Person\": 15,\n",
+    "    \"I-Person\": 16,\n",
+    "    \"B-Device\": 17,\n",
+    "    \"I-Device\": 18\n",
+    "}\n",
+    "\n",
+    "entities_list = list(sel_ent.keys())\n",
+    "sel_ent_inv = {v: k for k, v in sel_ent.items()}"
+   ]
+  },
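+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative example (not part of the pipeline): a toy eligibility criterion encoded\n",
+    "# with the BIO scheme above, using sel_ent / sel_ent_inv to move between tags and ids.\n",
+    "example_tokens = ['Patients', 'with', 'type', '2', 'diabetes']\n",
+    "example_tags = ['O', 'O', 'B-Condition', 'I-Condition', 'I-Condition']\n",
+    "example_ids = [sel_ent[tag] for tag in example_tags]\n",
+    "print(example_ids)                             # [0, 0, 1, 2, 2]\n",
+    "print([sel_ent_inv[i] for i in example_ids])   # back to the BIO tags"
+   ]
+  },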
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class EnsembleModelNER:\n",
+    "    def __init__(self, ner_model_name, llm_name, ner_from_local=True, path_to_model = None, llm_from_local=False, device='cpu'):\n",
+    "        self.ner_model_name = ner_model_name\n",
+    "        self.llm_name = llm_name\n",
+    "        self.ner_from_local = ner_from_local\n",
+    "        self.llm_from_local = llm_from_local\n",
+    "        self.ner_tokenizer = AutoTokenizer.from_pretrained(self.ner_model_name)\n",
+    "        self.llm_tokenizer = AutoTokenizer.from_pretrained(self.llm_name)\n",
+    "        if self.ner_from_local:\n",
+    "            self.ner_model = torch.load(path_to_model)\n",
+    "        else:\n",
+    "            self.ner_model = AutoModelForTokenClassification.from_pretrained(self.ner_model_name)\n",
+    "        self.llm_model = AutoModelForTextGeneration.from_pretrained(self.llm_name)\n",
+    "        self.device = device\n",
+    "    \n",
+    "    # tokenize and align the labels in the dataset\n",
+    "    def _tokenize_and_align_labels(self, sentence, labels_s, flag = 'I'):\n",
+    "        \"\"\"\n",
+    "        Tokenize the sentence and align the labels\n",
+    "        inputs:\n",
+    "            sentence: dict, the sentence from the dataset\n",
+    "            flag: str, the flag to indicate how to deal with the labels for subwords\n",
+    "                - 'I': use the label of the first subword for all subwords but as intermediate (I-ENT)\n",
+    "                - 'B': use the label of the first subword for all subwords as beginning (B-ENT)\n",
+    "                - None: use -100 for subwords\n",
+    "        outputs:\n",
+    "            tokenized_sentence: dict, the tokenized sentence now with a field for the labels\n",
+    "        \"\"\"\n",
+    "        tokenized_sentence = tokenizer(sentence['tokens'], is_split_into_words=True, truncation=True)\n",
+    "\n",
+    "        labels = []\n",
+    "        for i, labels_s in enumerate(sentence['ner_tags']):\n",
+    "            word_ids = tokenized_sentence.word_ids(batch_index=i)\n",
+    "            previous_word_idx = None\n",
+    "            label_ids = []\n",
+    "            for word_idx in word_ids:\n",
+    "                # if the word_idx is None, assign -100\n",
+    "                if word_idx is None:\n",
+    "                    label_ids.append(-100)\n",
+    "                # if it is a new word, assign the corresponding label\n",
+    "                elif word_idx != previous_word_idx:\n",
+    "                    label_ids.append(labels_s[word_idx])\n",
+    "                # if it is the same word, check the flag to assign\n",
+    "                else:\n",
+    "                    if flag == 'I':\n",
+    "                        if entities_list[labels_s[word_idx]].startswith('I'):\n",
+    "                          label_ids.append(labels_s[word_idx])\n",
+    "                        else:\n",
+    "                          label_ids.append(labels_s[word_idx] + 1)\n",
+    "                    elif flag == 'B':\n",
+    "                        label_ids.append(labels_s[word_idx])\n",
+    "                    elif flag == None:\n",
+    "                        label_ids.append(-100)\n",
+    "                previous_word_idx = word_idx\n",
+    "            labels.append(label_ids)\n",
+    "        tokenized_sentence['labels'] = labels\n",
+    "        return tokenized_sentence\n",
+    "    def annotate_sentences(dataset, labels, entities_list,criteria = 'first_label'):\n",
+    "        \"\"\"\n",
+    "        Annotate the sentences with the predicted labels\n",
+    "        inputs:\n",
+    "            dataset: dataset, dataset with the sentences\n",
+    "            labels: list, list of labels\n",
+    "            entities_list: list, list of entities\n",
+    "            criteria: str, criteria to use to select the label when the words pices have different labels\n",
+    "                - first_label: select the first label\n",
+    "                - majority: select the label with the majority\n",
+    "        outputs:\n",
+    "            annotated_sentences: list, list of annotated sentences\n",
+    "        \"\"\"\n",
+    "        annotated_sentences = []\n",
+    "        for i in range(len(dataset)):\n",
+    "            # get just the tokens different from None\n",
+    "            sentence = dataset[i]\n",
+    "            word_ids = sentence['word_ids']\n",
+    "            sentence_labels = labels[i]\n",
+    "            annotated_sentence = [[] for _ in range(len(dataset[i]['tokens']))]\n",
+    "            for word_id, label in zip(word_ids, sentence_labels):\n",
+    "                if word_id is not None:\n",
+    "                    annotated_sentence[word_id].append(label)\n",
+    "            annotated_sentence_filtered = []\n",
+    "            if criteria == 'first_label':\n",
+    "                annotated_sentence_filtered = [annotated_sentence[i][0] for i in range(len(annotated_sentence))]\n",
+    "            elif criteria == 'majority':\n",
+    "                annotated_sentence_filtered = [max(set(annotated_sentence[i]), key=annotated_sentence[i].count) for i in range(len(annotated_sentence))]\n",
+    "\n",
+    "            annotated_sentences.append(annotated_sentence_filtered)\n",
+    "        return annotated_sentences\n",
+    "\n",
+    "    def annotate_with_NER_model(self, dataset, entities_list):\n",
+    "        \"\"\"\n",
+    "        Annotate the dataset with the NER model\n",
+    "        inputs:\n",
+    "            dataset: dataset, the dataset to annotate\n",
+    "            entities_list: list, the list of labels\n",
+    "        outputs:\n",
+    "            annotated_dataset: dataset, the annotated dataset\n",
+    "        \"\"\"\n",
+    "        # tokenize and align the labels\n",
+    "        tokenized_dataset = dataset.map(lambda x: self._tokenize_and_align_labels(x, labels_s))\n",
+    "        # prepare the dataset for the model\n",
+    "        test_dataset = dataset['test']\n",
+    "\n",
+    "        data_for_model = test_dataset.remove_columns(['file', 'tokens', 'word_ids'])\n",
+    "\n",
+    "        data_loader = torch.utils.data.DataLoader(data_for_model, batch_size=16)\n",
+    "        \n",
+    "        self.ner_model.to(self.device)\n",
+    "        # predict the NER tags\n",
+    "        labels = []\n",
+    "        for batch in tqdm(data_loader):\n",
+    "        \n",
+    "            batch['input_ids'] = torch.LongTensor(np.column_stack(np.array(batch['input_ids']))).to(device)\n",
+    "            batch['attention_mask'] = torch.LongTensor(np.column_stack(np.array(batch['attention_mask']))).to(device)\n",
+    "            batch_tokenizer = {'input_ids': batch['input_ids'], 'attention_mask': batch['attention_mask']}\n",
+    "            # break\n",
+    "            with torch.no_grad():\n",
+    "                outputs = model(**batch_tokenizer)\n",
+    "\n",
+    "            labels_batch = torch.argmax(outputs.logits, dim=2).to('cpu').numpy()\n",
+    "            labels.extend([list(labels_batch[i]) for i in range(labels_batch.shape[0])])\n",
+    "\n",
+    "            del batch\n",
+    "            del outputs\n",
+    "            torch.cuda.empty_cache()\n",
+    "\n",
+    "        # recover original annotations split by words\n",
+    "        annotated_dataset = annotate_sentences(test_dataset, labels, entities_list)\n",
+    "        self.ner_model.to('cpu')\n",
+    "        if self.device != 'cpu':\n",
+    "            torch.cuda.empty_cache()\n",
+    "\n",
+    "        return annotated_dataset\n",
+    "\n",
+    "    def generate_prompts(self, annotated_sentences, entities_list):\n",
+    "        \"\"\"\n",
+    "        Generate the prompts for the LLM model\n",
+    "        inputs:\n",
+    "            annotated_sentences: list, the list of annotated sentences\n",
+    "        outputs:\n",
+    "            prompts: list, the list of prompts\n",
+    "        \"\"\"\n",
+    "        prompt_main = f\"\"\"I am working in a named entity recognition task for Clinical trial\n",
+    "        eligibility criteria. I have annotated a sentence with the NER model and I would like to\n",
+    "        check if the annotations are correct. The list of possible entities is {','.join(entities_list)}.\n",
+    "        Please keep the same BIO-format for annotations and do not change the words, just\n",
+    "        check the labels annontations. The sentence you must check is:\\n\\n\"\"\"\n",
+    "\n",
+    "        prompts = [prompt_main + '\\n'.join(annotated_sentences[i]) for i in range(len(annotated_sentences))]\n",
+    "        return prompts\n",
+    "\n",
+    "    def annotate_with_LLM_model(self, dataset, entities_list):\n",
+    "        \"\"\"\n",
+    "        Annotate the dataset with the LLM model\n",
+    "        inputs:\n",
+    "            dataset: dataset, the dataset to annotate\n",
+    "        outputs:\n",
+    "            annotated_dataset: dataset, the annotated dataset\n",
+    "        \"\"\"\n",
+    "        self.llm_model.to(self.device)\n",
+    "\n",
+    "        prompts = generate_prompts(dataset, entities_list)\n",
+    "\n",
+    "        # predict the NER tags\n",
+    "        llm_annotations = []\n",
+    "        for prompt in prompts:\n",
+    "            inputs = llm_tokenizer(prompt, return_tensors=\"pt\", padding=True, truncation=True, max_length=512).to(self.device)\n",
+    "            with torch.no_grad():\n",
+    "                outputs = llm_model.generate(**inputs)\n",
+    "            llm_annotations.append(llm_tokenizer.decode(outputs[0], skip_special_tokens=True))\n",
+    "        llm_model.to('cpu')\n",
+    "        if self.device != 'cpu':\n",
+    "            torch.cuda.empty_cache()\n",
+    "        return llm_annotations\n",
+    "\n",
+    "    def predict(self, dataset, entities_list):\n",
+    "        \"\"\"\n",
+    "        Predict the NER tags for the dataset\n",
+    "        inputs:\n",
+    "            dataset: dataset, the dataset to predict\n",
+    "            entities_list: list, the list of entities\n",
+    "        outputs:\n",
+    "            predictions: list, the list of predictions\n",
+    "        \"\"\"\n",
+    "\n",
+    "        # first step: annotate sentences with the NER model\n",
+    "        annotated_dataset = self.annotate_with_NER_model(dataset, entities_list)\n",
+    "\n",
+    "        # second step: use the annotated sentences as input for the LLM model to try\n",
+    "        # to improve the annotations\n",
+    "        annotated_sentences_after_llm = self.annotate_with_LLM_model(annotated_dataset, entities_list)\n",
+    "\n",
+    "        return annotated_sentences_after_llm\n"
+   ]
+  },
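+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Standalone sketch (toy, hand-made data) of the aggregation rule implemented by\n",
+    "# EnsembleModelNER.annotate_sentences: sub-word predictions are grouped by word_id and\n",
+    "# collapsed to one label per word, either 'first_label' or 'majority'.\n",
+    "toy_word_ids = [None, 0, 0, 1, 2, 2, 2, None]   # as returned by the tokenizer's word_ids()\n",
+    "toy_subword_labels = [0, 1, 1, 3, 5, 6, 6, 0]   # label ids predicted for each sub-word\n",
+    "\n",
+    "grouped = [[] for _ in range(3)]                # 3 words in the toy sentence\n",
+    "for word_id, label in zip(toy_word_ids, toy_subword_labels):\n",
+    "    if word_id is not None:\n",
+    "        grouped[word_id].append(label)\n",
+    "\n",
+    "first_label = [g[0] for g in grouped]\n",
+    "majority = [max(set(g), key=g.count) for g in grouped]\n",
+    "print(grouped)                                  # [[1, 1], [3], [5, 6, 6]]\n",
+    "print(first_label)                              # [1, 3, 5]\n",
+    "print(majority)                                 # [1, 3, 6]\n",
+    "print([entities_list[i] for i in majority])     # ['B-Condition', 'B-Value', 'I-Drug']"
+   ]
+  },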
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ner_model_name = 'roberta-base'\n",
+    "llm_name = 'BioMistral/BioMistral-7B'\n",
+    "ner_from_local = True\n",
+    "local_path = f'{model_path}/roberta-chia-ner.pt'\n",
+    "\n",
+    "# load the dataset\n",
+    "dataset = load_dataset('JavierLopetegui/chia_v1')"
+   ]
+  },
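+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Quick sanity check of the splits and the fields the pipeline assumes ('tokens', 'ner_tags');\n",
+    "# adjust the field names here if the published schema differs.\n",
+    "print(dataset)\n",
+    "example = dataset['test'][0]\n",
+    "print(example['tokens'][:15])\n",
+    "print([entities_list[t] for t in example['ner_tags'][:15]])"
+   ]
+  },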
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# load ensemble model\n",
+    "ensemble_model = EnsembleModelNER(ner_model_name, llm_name, ner_from_local, local_path, device=device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "annotations = ensemble_model.predict(dataset, entities_list)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "annotations[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "annotations_entities = []\n",
+    "\n",
+    "for annotation in annotations:\n",
+    "    annotations_entities.append([int(a.split()[1]) for a in annotation])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def compute_metrics(p):\n",
+    "    \"\"\"\n",
+    "    Compute the metrics for the model\n",
+    "    inputs:\n",
+    "        p: tuple, the predictions and the ground true\n",
+    "    outputs:\n",
+    "        dict: the metrics\n",
+    "    \"\"\"\n",
+    "    predictions, ground_true = p\n",
+    "\n",
+    "    # Remove ignored index (special tokens)\n",
+    "    predictions_labels = []\n",
+    "    true_labels = []\n",
+    "\n",
+    "    for preds, labels in zip(predictions, ground_true):\n",
+    "        preds_labels = []\n",
+    "        labels_true = []\n",
+    "        for pred, label in zip(preds, labels):\n",
+    "            if label != -100:\n",
+    "                if pred == -100:\n",
+    "                    pred = 0\n",
+    "                preds_labels.append(entities_list[pred])\n",
+    "                labels_true.append(entities_list[label])\n",
+    "        predictions_labels.append(preds_labels) \n",
+    "        true_labels.append(labels_true)\n",
+    "\n",
+    "    # predictions_labels = [\n",
+    "    #     [entities_list[p] for (p, l) in zip(prediction, ground_true) if l != -100]\n",
+    "    #     for prediction, label in zip(predictions, ground_true)\n",
+    "    # ]\n",
+    "    # true_labels = [\n",
+    "    #     [entities_list[l] for (p, l) in zip(prediction, ground_true) if l != -100]\n",
+    "    #     for prediction, label in zip(predictions, ground_true)\n",
+    "    # ]\n",
+    "    # print(predictions_labels[0])\n",
+    "    # print(true_labels[0])\n",
+    "\n",
+    "    results = seqeval.compute(predictions=predictions_labels, references=true_labels)\n",
+    "    return results"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "metric = load_metric(\"seqeval\")"
+   ]
+  },
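+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Minimal sanity check of the seqeval metric: it expects lists of BIO tag strings\n",
+    "# (not label ids), which is why predictions and references are mapped through entities_list.\n",
+    "toy_preds = [['B-Condition', 'I-Condition', 'O', 'B-Drug']]\n",
+    "toy_refs = [['B-Condition', 'I-Condition', 'O', 'O']]\n",
+    "metric.compute(predictions=toy_preds, references=toy_refs)"
+   ]
+  },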
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# evaluate the model\n",
+    "results = metric.compute(predictions=annotations_entities, references=dataset['test']['ner_tags'])"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "TER",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}