Switch to unified view

a b/Roberta+LLM/ensemble_model.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "execution_count": null,
6
   "metadata": {},
7
   "outputs": [],
8
   "source": [
9
    "# Colab only: mount Google Drive (comment this cell out when running locally)\n",
10
    "from google.colab import drive\n",
11
    "drive.mount('/content/drive')"
12
   ]
13
  },
14
  {
15
   "cell_type": "code",
16
   "execution_count": null,
17
   "metadata": {},
18
   "outputs": [],
19
   "source": [
20
    "# Colab only: install dependencies (comment this cell out when running locally)\n",
21
    "!pip install -q -U git+https://github.com/huggingface/transformers.git\n",
22
    "!pip install -q -U datasets\n",
23
    "!pip install -q -U git+https://github.com/huggingface/accelerate.git\n",
24
    "!pip install seqeval\n",
25
    "!pip install -q -U evaluate"
26
   ]
27
  },
28
  {
29
   "cell_type": "code",
30
   "execution_count": null,
31
   "metadata": {},
32
   "outputs": [],
33
   "source": [
34
    "import numpy as np\n",
35
    "from transformers import AutoTokenizer, AutoModelForTokenClassification, DataCollatorForTokenClassification,  Trainer, TrainingArguments, AutoModelForCausalLM\n",
36
    "from datasets import load_dataset, load_metric\n",
37
    "import evaluate\n",
38
    "import torch"
39
   ]
40
  },
41
  {
42
   "cell_type": "code",
43
   "execution_count": null,
44
   "metadata": {},
45
   "outputs": [],
46
   "source": [
47
    "from huggingface_hub import notebook_login\n",
48
    "\n",
49
    "notebook_login()"
50
   ]
51
  },
52
  {
53
   "cell_type": "code",
54
   "execution_count": 1,
55
   "metadata": {},
56
   "outputs": [],
57
   "source": [
58
    "# paths\n",
59
    "# root = '..'\n",
60
    "root = './drive/MyDrive/TER-LISN-2024'\n",
61
    "data_path = f'{root}/data'\n",
62
    "model_path = f'{root}/models'"
63
   ]
64
  },
65
  {
66
   "cell_type": "code",
67
   "execution_count": null,
68
   "metadata": {},
69
   "outputs": [],
70
   "source": [
71
    "# dict for the entities (entity to int value)\n",
72
    "simple_ent = {\"Condition\", \"Value\", \"Drug\", \"Procedure\", \"Measurement\", \"Temporal\", \"Observation\", \"Person\", \"Device\"}\n",
73
    "sel_ent = {\n",
74
    "    \"O\": 0,\n",
75
    "    \"B-Condition\": 1,\n",
76
    "    \"I-Condition\": 2,\n",
77
    "    \"B-Value\": 3,\n",
78
    "    \"I-Value\": 4,\n",
79
    "    \"B-Drug\": 5,\n",
80
    "    \"I-Drug\": 6,\n",
81
    "    \"B-Procedure\": 7,\n",
82
    "    \"I-Procedure\": 8,\n",
83
    "    \"B-Measurement\": 9,\n",
84
    "    \"I-Measurement\": 10,\n",
85
    "    \"B-Temporal\": 11,\n",
86
    "    \"I-Temporal\": 12,\n",
87
    "    \"B-Observation\": 13,\n",
88
    "    \"I-Observation\": 14,\n",
89
    "    \"B-Person\": 15,\n",
90
    "    \"I-Person\": 16,\n",
91
    "    \"B-Device\": 17,\n",
92
    "    \"I-Device\": 18\n",
93
    "}\n",
94
    "\n",
95
    "entities_list = list(sel_ent.keys())\n",
96
    "sel_ent_inv = {v: k for k, v in sel_ent.items()}"
97
   ]
98
  },
99
  {
100
   "cell_type": "code",
101
   "execution_count": null,
102
   "metadata": {},
103
   "outputs": [],
104
   "source": [
105
    "class EnsembleModelNER:\n",
106
    "    def __init__(self, ner_model_name, llm_name, ner_from_local=True, path_to_model = None, llm_from_local=False, device='cpu'):\n",
107
    "        self.ner_model_name = ner_model_name\n",
108
    "        self.llm_name = llm_name\n",
109
    "        self.ner_from_local = ner_from_local\n",
110
    "        self.llm_from_local = llm_from_local\n",
111
    "        self.ner_tokenizer = AutoTokenizer.from_pretrained(self.ner_model_name)\n",
112
    "        self.llm_tokenizer = AutoTokenizer.from_pretrained(self.llm_name)\n",
113
    "        if self.ner_from_local:\n",
114
    "            self.ner_model = torch.load(path_to_model, map_location='cpu')\n",
115
    "        else:\n",
116
    "            self.ner_model = AutoModelForTokenClassification.from_pretrained(self.ner_model_name)\n",
117
    "        self.llm_model = AutoModelForCausalLM.from_pretrained(self.llm_name)\n",
118
    "        self.device = device\n",
119
    "    \n",
120
    "    # tokenize and align the labels in the dataset\n",
121
    "    def _tokenize_and_align_labels(self, sentence, labels_s, flag = 'I'):\n",
122
    "        \"\"\"\n",
123
    "        Tokenize the sentence and align the labels\n",
124
    "        inputs:\n",
125
    "            sentence: dict, the sentence from the dataset\n",
126
    "            flag: str, the flag to indicate how to deal with the labels for subwords\n",
127
    "                - 'I': use the label of the first subword for all subwords but as intermediate (I-ENT)\n",
128
    "                - 'B': use the label of the first subword for all subwords as beginning (B-ENT)\n",
129
    "                - None: use -100 for subwords\n",
130
    "        outputs:\n",
131
    "            tokenized_sentence: dict, the tokenized sentence now with a field for the labels\n",
132
    "        \"\"\"\n",
133
    "        tokenized_sentence = self.ner_tokenizer(sentence['tokens'], is_split_into_words=True, truncation=True)\n",
134
    "\n",
135
    "        labels = []\n",
136
    "        for i, labels_s in enumerate(sentence['ner_tags']):\n",
137
    "            word_ids = tokenized_sentence.word_ids(batch_index=i)\n",
138
    "            previous_word_idx = None\n",
139
    "            label_ids = []\n",
140
    "            for word_idx in word_ids:\n",
141
    "                # if the word_idx is None, assign -100\n",
142
    "                if word_idx is None:\n",
143
    "                    label_ids.append(-100)\n",
144
    "                # if it is a new word, assign the corresponding label\n",
145
    "                elif word_idx != previous_word_idx:\n",
146
    "                    label_ids.append(labels_s[word_idx])\n",
147
    "                # if it is the same word, check the flag to assign\n",
148
    "                else:\n",
149
    "                    if flag == 'I':\n",
150
    "                        if entities_list[labels_s[word_idx]].startswith('I'):\n",
151
    "                          label_ids.append(labels_s[word_idx])\n",
152
    "                        else:\n",
153
    "                          label_ids.append(labels_s[word_idx] + 1)\n",
154
    "                    elif flag == 'B':\n",
155
    "                        label_ids.append(labels_s[word_idx])\n",
156
    "                    elif flag == None:\n",
157
    "                        label_ids.append(-100)\n",
158
    "                previous_word_idx = word_idx\n",
159
    "            labels.append(label_ids)\n",
160
    "        tokenized_sentence['labels'] = labels\n",
161
    "        return tokenized_sentence\n",
162
    "    def annotate_sentences(dataset, labels, entities_list,criteria = 'first_label'):\n",
163
    "        \"\"\"\n",
164
    "        Annotate the sentences with the predicted labels\n",
165
    "        inputs:\n",
166
    "            dataset: dataset, dataset with the sentences\n",
167
    "            labels: list, list of labels\n",
168
    "            entities_list: list, list of entities\n",
169
    "            criteria: str, criterion used to select a word's label when its word pieces have different labels\n",
170
    "                - first_label: select the first label\n",
171
    "                - majority: select the label with the majority\n",
172
    "        outputs:\n",
173
    "            annotated_sentences: list, list of annotated sentences\n",
174
    "        \"\"\"\n",
175
    "        annotated_sentences = []\n",
176
    "        for i in range(len(dataset)):\n",
177
    "            # get just the tokens different from None\n",
178
    "            sentence = dataset[i]\n",
179
    "            word_ids = sentence['word_ids']\n",
180
    "            sentence_labels = labels[i]\n",
181
    "            annotated_sentence = [[] for _ in range(len(dataset[i]['tokens']))]\n",
182
    "            for word_id, label in zip(word_ids, sentence_labels):\n",
183
    "                if word_id is not None:\n",
184
    "                    annotated_sentence[word_id].append(label)\n",
185
    "            annotated_sentence_filtered = []\n",
186
    "            if criteria == 'first_label':\n",
187
    "                annotated_sentence_filtered = [annotated_sentence[i][0] for i in range(len(annotated_sentence))]\n",
188
    "            elif criteria == 'majority':\n",
189
    "                annotated_sentence_filtered = [max(set(annotated_sentence[i]), key=annotated_sentence[i].count) for i in range(len(annotated_sentence))]\n",
190
    "\n",
191
    "            annotated_sentences.append(annotated_sentence_filtered)\n",
192
    "        return annotated_sentences\n",
193
    "\n",
194
    "    def annotate_with_NER_model(self, dataset, entities_list):\n",
195
    "        \"\"\"\n",
196
    "        Annotate the dataset with the NER model\n",
197
    "        inputs:\n",
198
    "            dataset: dataset, the dataset to annotate\n",
199
    "            entities_list: list, the list of labels\n",
200
    "        outputs:\n",
201
    "            annotated_dataset: dataset, the annotated dataset\n",
202
    "        \"\"\"\n",
203
    "        # tokenize and align the labels\n",
204
    "        tokenized_dataset = dataset.map(lambda x: self._tokenize_and_align_labels(x, None), batched=True)  # helper reads x['ner_tags'] per batch index\n",
205
    "        # prepare the dataset for the model\n",
206
    "        test_dataset = dataset['test']\n",
207
    "\n",
208
    "        data_for_model = test_dataset.remove_columns(['file', 'tokens', 'word_ids'])\n",
209
    "\n",
210
    "        data_loader = torch.utils.data.DataLoader(data_for_model, batch_size=16)\n",
211
    "        \n",
212
    "        self.ner_model.to(self.device)\n",
213
    "        # predict the NER tags\n",
214
    "        labels = []\n",
215
    "        for batch in data_loader:\n",
216
    "        \n",
217
    "            batch['input_ids'] = torch.LongTensor(np.column_stack(np.array(batch['input_ids']))).to(self.device)\n",
218
    "            batch['attention_mask'] = torch.LongTensor(np.column_stack(np.array(batch['attention_mask']))).to(self.device)\n",
219
    "            batch_tokenizer = {'input_ids': batch['input_ids'], 'attention_mask': batch['attention_mask']}\n",
220
    "            # inference only: no gradients are needed\n",
221
    "            with torch.no_grad():\n",
222
    "                outputs = self.ner_model(**batch_tokenizer)\n",
223
    "\n",
224
    "            labels_batch = torch.argmax(outputs.logits, dim=2).to('cpu').numpy()\n",
225
    "            labels.extend([list(labels_batch[i]) for i in range(labels_batch.shape[0])])\n",
226
    "\n",
227
    "            del batch\n",
228
    "            del outputs\n",
229
    "            torch.cuda.empty_cache()\n",
230
    "\n",
231
    "        # recover original annotations split by words\n",
232
    "        annotated_dataset = EnsembleModelNER.annotate_sentences(test_dataset, labels, entities_list)\n",
233
    "        self.ner_model.to('cpu')\n",
234
    "        if self.device != 'cpu':\n",
235
    "            torch.cuda.empty_cache()\n",
236
    "\n",
237
    "        return annotated_dataset\n",
238
    "\n",
239
    "    def generate_prompts(self, annotated_sentences, entities_list):\n",
240
    "        \"\"\"\n",
241
    "        Generate the prompts for the LLM model\n",
242
    "        inputs:\n",
243
    "            annotated_sentences: list, the list of annotated sentences\n",
244
    "        outputs:\n",
245
    "            prompts: list, the list of prompts\n",
246
    "        \"\"\"\n",
247
    "        prompt_main = f\"\"\"I am working in a named entity recognition task for Clinical trial\n",
248
    "        eligibility criteria. I have annotated a sentence with the NER model and I would like to\n",
249
    "        check if the annotations are correct. The list of possible entities is {','.join(entities_list)}.\n",
250
    "        Please keep the same BIO-format for annotations and do not change the words, just\n",
251
    "        check the label annotations. The sentence you must check is:\\n\\n\"\"\"\n",
252
    "\n",
253
    "        prompts = [prompt_main + '\\n'.join(label if isinstance(label, str) else entities_list[label] for label in sentence) for sentence in annotated_sentences]\n",
254
    "        return prompts\n",
255
    "\n",
256
    "    def annotate_with_LLM_model(self, dataset, entities_list):\n",
257
    "        \"\"\"\n",
258
    "        Annotate the dataset with the LLM model\n",
259
    "        inputs:\n",
260
    "            dataset: dataset, the dataset to annotate\n",
261
    "        outputs:\n",
262
    "            annotated_dataset: dataset, the annotated dataset\n",
263
    "        \"\"\"\n",
264
    "        self.llm_model.to(self.device)\n",
265
    "\n",
266
    "        prompts = self.generate_prompts(dataset, entities_list)\n",
267
    "\n",
268
    "        # predict the NER tags\n",
269
    "        llm_annotations = []\n",
270
    "        for prompt in prompts:\n",
271
    "            inputs = self.llm_tokenizer(prompt, return_tensors=\"pt\", padding=True, truncation=True, max_length=512).to(self.device)\n",
272
    "            with torch.no_grad():\n",
273
    "                outputs = self.llm_model.generate(**inputs)\n",
274
    "            llm_annotations.append(self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True))\n",
275
    "        self.llm_model.to('cpu')\n",
276
    "        if self.device != 'cpu':\n",
277
    "            torch.cuda.empty_cache()\n",
278
    "        return llm_annotations\n",
279
    "\n",
280
    "    def predict(self, dataset, entities_list):\n",
281
    "        \"\"\"\n",
282
    "        Predict the NER tags for the dataset\n",
283
    "        inputs:\n",
284
    "            dataset: dataset, the dataset to predict\n",
285
    "            entities_list: list, the list of entities\n",
286
    "        outputs:\n",
287
    "            predictions: list, the list of predictions\n",
288
    "        \"\"\"\n",
289
    "\n",
290
    "        # first step: annotate sentences with the NER model\n",
291
    "        annotated_dataset = self.annotate_with_NER_model(dataset, entities_list)\n",
292
    "\n",
293
    "        # second step: use the annotated sentences as input for the LLM model to try\n",
294
    "        # to improve the annotations\n",
295
    "        annotated_sentences_after_llm = self.annotate_with_LLM_model(annotated_dataset, entities_list)\n",
296
    "\n",
297
    "        return annotated_sentences_after_llm\n"
298
   ]
299
  },
300
  {
301
   "cell_type": "code",
302
   "execution_count": null,
303
   "metadata": {},
304
   "outputs": [],
305
   "source": []
306
  },
307
  {
308
   "cell_type": "code",
309
   "execution_count": null,
310
   "metadata": {},
311
   "outputs": [],
312
   "source": [
313
    "ner_model_name = 'roberta-base'\n",
314
    "llm_name = 'BioMistral/BioMistral-7B'\n",
315
    "ner_from_local = True\n",
316
    "local_path = f'{model_path}/roberta-chia-ner.pt'\n",
317
    "\n",
318
    "# load the dataset\n",
319
    "dataset = load_dataset('JavierLopetegui/chia_v1')"
320
   ]
321
  },
322
  {
323
   "cell_type": "code",
324
   "execution_count": null,
325
   "metadata": {},
326
   "outputs": [],
327
   "source": [
328
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
329
   ]
330
  },
331
  {
332
   "cell_type": "code",
333
   "execution_count": null,
334
   "metadata": {},
335
   "outputs": [],
336
   "source": [
337
    "# load ensemble model\n",
338
    "ensemble_model = EnsembleModelNER(ner_model_name, llm_name, ner_from_local, local_path, device=device)"
339
   ]
340
  },
341
  {
342
   "cell_type": "code",
343
   "execution_count": null,
344
   "metadata": {},
345
   "outputs": [],
346
   "source": [
347
    "annotations = ensemble_model.predict(dataset, entities_list)"
348
   ]
349
  },
350
  {
351
   "cell_type": "code",
352
   "execution_count": null,
353
   "metadata": {},
354
   "outputs": [],
355
   "source": [
356
    "annotations[0]"
357
   ]
358
  },
359
  {
360
   "cell_type": "code",
361
   "execution_count": null,
362
   "metadata": {},
363
   "outputs": [],
364
   "source": [
365
    "annotations_entities = []\n",
366
    "\n",
367
    "for annotation in annotations:\n",
368
    "    annotations_entities.append([int(a.split()[1]) for a in annotation])"
369
   ]
370
  },
371
  {
372
   "cell_type": "code",
373
   "execution_count": null,
374
   "metadata": {},
375
   "outputs": [],
376
   "source": [
377
    "def compute_metrics(p):\n",
378
    "    \"\"\"\n",
379
    "    Compute the metrics for the model\n",
380
    "    inputs:\n",
381
    "        p: tuple, the predictions and the ground-truth labels\n",
382
    "    outputs:\n",
383
    "        dict: the metrics\n",
384
    "    \"\"\"\n",
385
    "    predictions, ground_true = p\n",
386
    "\n",
387
    "    # Remove ignored index (special tokens)\n",
388
    "    predictions_labels = []\n",
389
    "    true_labels = []\n",
390
    "\n",
391
    "    for preds, labels in zip(predictions, ground_true):\n",
392
    "        preds_labels = []\n",
393
    "        labels_true = []\n",
394
    "        for pred, label in zip(preds, labels):\n",
395
    "            if label != -100:\n",
396
    "                if pred == -100:\n",
397
    "                    pred = 0\n",
398
    "                preds_labels.append(entities_list[pred])\n",
399
    "                labels_true.append(entities_list[label])\n",
400
    "        predictions_labels.append(preds_labels) \n",
401
    "        true_labels.append(labels_true)\n",
402
    "\n",
403
    "    # predictions_labels = [\n",
404
    "    #     [entities_list[p] for (p, l) in zip(prediction, ground_true) if l != -100]\n",
405
    "    #     for prediction, label in zip(predictions, ground_true)\n",
406
    "    # ]\n",
407
    "    # true_labels = [\n",
408
    "    #     [entities_list[l] for (p, l) in zip(prediction, ground_true) if l != -100]\n",
409
    "    #     for prediction, label in zip(predictions, ground_true)\n",
410
    "    # ]\n",
411
    "    # print(predictions_labels[0])\n",
412
    "    # print(true_labels[0])\n",
413
    "\n",
414
    "    results = metric.compute(predictions=predictions_labels, references=true_labels)\n",
415
    "    return results"
416
   ]
417
  },
418
  {
419
   "cell_type": "code",
420
   "execution_count": null,
421
   "metadata": {},
422
   "outputs": [],
423
   "source": [
424
    "metric = evaluate.load(\"seqeval\")"
425
   ]
426
  },
427
  {
428
   "cell_type": "code",
429
   "execution_count": null,
430
   "metadata": {},
431
   "outputs": [],
432
   "source": [
433
    "# evaluate the model\n",
434
    "results = metric.compute(predictions=[[entities_list[p] for p in sent] for sent in annotations_entities], references=[[entities_list[t] for t in sent] for sent in dataset['test']['ner_tags']])"
435
   ]
436
  }
437
 ],
438
 "metadata": {
439
  "kernelspec": {
440
   "display_name": "TER",
441
   "language": "python",
442
   "name": "python3"
443
  },
444
  "language_info": {
445
   "codemirror_mode": {
446
    "name": "ipython",
447
    "version": 3
448
   },
449
   "file_extension": ".py",
450
   "mimetype": "text/x-python",
451
   "name": "python",
452
   "nbconvert_exporter": "python",
453
   "pygments_lexer": "ipython3",
454
   "version": "3.10.13"
455
  }
456
 },
457
 "nbformat": 4,
458
 "nbformat_minor": 2
459
}