In [None]:
# Mount Google Drive when running in Colab (the project root configured below lives in MyDrive).
# Outside Colab the import fails and the mount is skipped instead of crashing the notebook.
try:
    from google.colab import drive
except ImportError:
    drive = None
if drive is not None:
    drive.mount('/content/drive')

In [None]:
# uncomment if using colab
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U datasets
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install seqeval
!pip install -q -U evaluate

In [None]:
import evaluate
import numpy as np
import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
)

# `datasets.load_metric` was removed in datasets 3.0; `evaluate.load` is its replacement.
try:
    from datasets import load_metric
except ImportError:
    load_metric = evaluate.load

# transformers has no `AutoModelForTextGeneration`; decoder-only LLMs with `.generate()`
# are loaded through `AutoModelForCausalLM`. Keep the old name as an alias.
AutoModelForTextGeneration = AutoModelForCausalLM

In [None]:
# Authenticate against the Hugging Face Hub with an interactive token prompt
# (presumably needed to download the hub model/dataset used below).
import huggingface_hub

notebook_login = huggingface_hub.notebook_login
notebook_login()

In [1]:
# Project locations on the mounted Google Drive.
# (The previously commented-out alternative was root = '..', presumably for local runs.)
root = './drive/MyDrive/TER-LISN-2024'
data_path, model_path = (root + '/' + sub_dir for sub_dir in ('data', 'models'))

In [None]:
# Entity types used in the annotations.
simple_ent = {"Condition", "Value", "Drug", "Procedure", "Measurement", "Temporal", "Observation", "Person", "Device"}

# BIO tag -> integer id. "O" is 0 and every entity type gets a consecutive (B-, I-) pair,
# so the I- tag of an entity always has the id of its B- tag plus one.
_tag_order = ["Condition", "Value", "Drug", "Procedure", "Measurement", "Temporal", "Observation", "Person", "Device"]
sel_ent = {"O": 0}
for _entity in _tag_order:
    for _prefix in ("B", "I"):
        sel_ent[f"{_prefix}-{_entity}"] = len(sel_ent)
del _tag_order, _entity, _prefix

# id -> tag, both as a list (index = id) and as a dict
entities_list = list(sel_ent.keys())
sel_ent_inv = {tag_id: tag for tag, tag_id in sel_ent.items()}

In [None]:
class EnsembleModelNER:
    """
    Two-stage NER: a token-classification model assigns a BIO tag to every word, then a
    causal language model is prompted to review those tags.

    inputs (constructor):
        ner_model_name: str, Hugging Face id of the NER checkpoint. Its tokenizer is always
            loaded from here, and so are the weights when ner_from_local is False.
        llm_name: str, Hugging Face id of the causal LLM (model and tokenizer).
        ner_from_local: bool, if True load a whole pickled NER model from path_to_model.
        path_to_model: str, file produced by torch.save(model); required if ner_from_local.
        llm_from_local: bool, stored on the instance but not used by this class.
        device: str, torch device used for inference ('cpu', 'cuda', ...).
        llm_dtype: optional torch dtype for the LLM weights (e.g. torch.float16);
            None keeps the transformers default.
    """

    def __init__(self, ner_model_name, llm_name, ner_from_local=True, path_to_model=None,
                 llm_from_local=False, device='cpu', llm_dtype=None):
        # Local import: the class does not depend on the notebook's import cell.
        from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, AutoTokenizer

        self.ner_model_name = ner_model_name
        self.llm_name = llm_name
        self.ner_from_local = ner_from_local
        self.llm_from_local = llm_from_local
        self.device = device

        self.ner_tokenizer = AutoTokenizer.from_pretrained(self.ner_model_name)
        if getattr(self.ner_tokenizer, 'add_prefix_space', None) is False:
            # Byte-level BPE tokenizers (e.g. RoBERTa) refuse pre-split words
            # (is_split_into_words=True) unless add_prefix_space is enabled.
            self.ner_tokenizer = AutoTokenizer.from_pretrained(self.ner_model_name, add_prefix_space=True)
        self.llm_tokenizer = AutoTokenizer.from_pretrained(self.llm_name)

        if self.ner_from_local:
            if path_to_model is None:
                raise ValueError('path_to_model is required when ner_from_local=True')
            # The checkpoint is a whole pickled nn.Module, so weights_only must be False
            # (its default became True in torch 2.6). SECURITY: unpickling can run
            # arbitrary code -- only load checkpoint files you trust.
            self.ner_model = torch.load(path_to_model, map_location='cpu', weights_only=False)
        else:
            self.ner_model = AutoModelForTokenClassification.from_pretrained(self.ner_model_name)
        # A model pickled right after training may still be in train mode (dropout on).
        self.ner_model.eval()

        self.llm_model = AutoModelForCausalLM.from_pretrained(self.llm_name, torch_dtype=llm_dtype)
        self.llm_model.eval()

    # tokenize and align the labels in the dataset
    def _tokenize_and_align_labels(self, sentence, labels_s=None, flag='I', label_names=None):
        """
        Tokenize a batch of pre-split sentences and align word-level tags to sub-tokens.
        Meant to be used with Dataset.map(..., batched=True).
        inputs:
            sentence: dict-like batch with 'tokens' (list of word lists) and optionally
                'ner_tags' (list of tag-id lists, one id per word)
            labels_s: unused; kept so older call sites keep working
            flag: str, how to label the 2nd and later sub-tokens of a word:
                - 'I': continue the entity as intermediate (B-ENT -> I-ENT; O and I- unchanged)
                - 'B': repeat the word's own tag
                - None: use -100 for those sub-tokens
            label_names: list of tag names indexed by id, where each B-ENT is immediately
                followed by I-ENT; defaults to the notebook's global entities_list
        outputs:
            tokenized_sentence: the tokenizer output plus 'word_ids' (sub-token -> word index,
                None for special tokens) and, when 'ner_tags' is given, 'labels'
        """
        if flag not in ('I', 'B', None):
            raise ValueError(f"flag must be 'I', 'B' or None, got {flag!r}")
        if label_names is None:
            label_names = entities_list

        tokenized_sentence = self.ner_tokenizer(sentence['tokens'], is_split_into_words=True, truncation=True)
        all_word_ids = [tokenized_sentence.word_ids(batch_index=i) for i in range(len(sentence['tokens']))]
        # kept so predictions can be mapped back to words later (see annotate_sentences)
        tokenized_sentence['word_ids'] = all_word_ids

        if 'ner_tags' not in sentence:
            return tokenized_sentence

        labels = []
        for word_ids, word_tags in zip(all_word_ids, sentence['ner_tags']):
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                if word_idx is None:
                    label_ids.append(-100)  # special token
                elif word_idx != previous_word_idx:
                    label_ids.append(word_tags[word_idx])  # first sub-token of a word
                elif flag == 'I':
                    tag = word_tags[word_idx]
                    # only a B- tag is promoted to its I- counterpart (id + 1);
                    # O and I- tags are already right for a continuation sub-token
                    label_ids.append(tag + 1 if label_names[tag].startswith('B') else tag)
                elif flag == 'B':
                    label_ids.append(word_tags[word_idx])
                else:  # flag is None
                    label_ids.append(-100)
                previous_word_idx = word_idx
            labels.append(label_ids)
        tokenized_sentence['labels'] = labels
        return tokenized_sentence

    @staticmethod
    def annotate_sentences(dataset, labels, entities_list, criteria='first_label'):
        """
        Collapse sub-token predictions back to one label per word.
        inputs:
            dataset: iterable of rows with 'tokens' (the words) and 'word_ids'
                (sub-token -> word index, None for special tokens)
            labels: list, per row, of predicted label ids aligned with 'word_ids'
            entities_list: list, list of entities (not used here; kept for compatibility)
            criteria: str, how to pick a word's label when its word pieces disagree
                - first_label: the label of the first word piece
                - majority: the most frequent label (ties -> the one seen first)
        outputs:
            annotated_sentences: list, one list of label ids per row. Words with no
                sub-tokens (cut off by truncation) get 0, the "O" id in this notebook.
        """
        from collections import Counter

        if criteria not in ('first_label', 'majority'):
            raise ValueError(f"criteria must be 'first_label' or 'majority', got {criteria!r}")

        annotated_sentences = []
        for sentence, sentence_labels in zip(dataset, labels):
            pieces_per_word = [[] for _ in sentence['tokens']]
            for word_id, label in zip(sentence['word_ids'], sentence_labels):
                if word_id is not None:
                    pieces_per_word[word_id].append(label)
            if criteria == 'first_label':
                chosen = [pieces[0] if pieces else 0 for pieces in pieces_per_word]
            else:
                chosen = [Counter(pieces).most_common(1)[0][0] if pieces else 0 for pieces in pieces_per_word]
            annotated_sentences.append(chosen)
        return annotated_sentences

    def annotate_with_NER_model(self, dataset, entities_list, split='test', batch_size=16):
        """
        Annotate one split of the dataset with the NER model.
        inputs:
            dataset: DatasetDict with a `split` split holding 'tokens' (and optionally 'ner_tags')
            entities_list: list, the list of labels (tag names indexed by id)
            split: str, which split of `dataset` to annotate
            batch_size: int, number of sentences per forward pass
        outputs:
            annotated_dataset: list, one list of predicted label ids per sentence (one per word)
        """
        from tqdm.auto import tqdm

        eval_split = dataset[split]
        tokenized_split = eval_split.map(
            lambda batch: self._tokenize_and_align_labels(batch, label_names=entities_list),
            batched=True,
        )

        predicted = []
        self.ner_model.to(self.device)
        try:
            for start in tqdm(range(0, len(tokenized_split), batch_size)):
                chunk = tokenized_split[start:start + batch_size]  # dict of column lists
                # pad each batch to its longest sequence
                encoded = self.ner_tokenizer.pad(
                    {'input_ids': chunk['input_ids'], 'attention_mask': chunk['attention_mask']},
                    return_tensors='pt',
                ).to(self.device)
                with torch.no_grad():
                    logits = self.ner_model(**encoded).logits
                batch_pred = torch.argmax(logits, dim=-1).cpu().tolist()
                # drop padding positions so each row lines up with its own word_ids
                predicted.extend(row[:len(ids)] for row, ids in zip(batch_pred, chunk['input_ids']))
        finally:
            # release the accelerator even if inference fails
            self.ner_model.to('cpu')
            if self.device != 'cpu':
                torch.cuda.empty_cache()

        # recover original annotations split by words
        return self.annotate_sentences(tokenized_split, predicted, entities_list)

    def generate_prompts(self, annotated_sentences, entities_list, tokens=None):
        """
        Generate the prompts for the LLM model
        inputs:
            annotated_sentences: list, per sentence a list of labels, given either as ids
                (indices into entities_list) or as tag strings
            entities_list: list, tag names indexed by id
            tokens: optional list of word lists; when given, each prompt line is "word TAG"
        outputs:
            prompts: list, the list of prompts (one per sentence)
        """
        prompt_main = (
            "I am working in a named entity recognition task for Clinical trial\n"
            "eligibility criteria. I have annotated a sentence with the NER model and I would like to\n"
            f"check if the annotations are correct. The list of possible entities is {','.join(entities_list)}.\n"
            "Please keep the same BIO-format for annotations and do not change the words, just\n"
            "check the labels annotations. The sentence you must check is:\n\n"
        )

        def as_tag(label):
            return label if isinstance(label, str) else entities_list[int(label)]

        prompts = []
        for idx, sentence_labels in enumerate(annotated_sentences):
            tags = [as_tag(label) for label in sentence_labels]
            if tokens is None:
                lines = tags
            else:
                lines = [f'{word} {tag}' for word, tag in zip(tokens[idx], tags)]
            prompts.append(prompt_main + '\n'.join(lines))
        return prompts

    def annotate_with_LLM_model(self, dataset, entities_list, tokens=None, max_new_tokens=512,
                                max_input_length=512):
        """
        Ask the LLM to review NER annotations.
        inputs:
            dataset: list, per-sentence labels as accepted by generate_prompts
            entities_list: list, tag names indexed by id
            tokens: optional list of word lists, forwarded to generate_prompts
            max_new_tokens: int, generation budget per prompt
            max_input_length: int, prompts longer than this (in tokens) are truncated
        outputs:
            llm_annotations: list of str, the text generated for each prompt (prompt excluded)
        """
        prompts = self.generate_prompts(dataset, entities_list, tokens=tokens)

        # Single prompts need no padding; generate() still wants a pad id, and many
        # causal-LM tokenizers have none, so fall back to EOS.
        pad_token_id = self.llm_tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.llm_tokenizer.eos_token_id

        llm_annotations = []
        self.llm_model.to(self.device)
        try:
            for prompt in prompts:
                inputs = self.llm_tokenizer(
                    prompt, return_tensors='pt', truncation=True, max_length=max_input_length
                ).to(self.device)
                with torch.no_grad():
                    outputs = self.llm_model.generate(
                        **inputs, max_new_tokens=max_new_tokens, pad_token_id=pad_token_id
                    )
                # generate() returns prompt + continuation; keep only the new tokens
                prompt_len = inputs['input_ids'].shape[1]
                llm_annotations.append(self.llm_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True))
        finally:
            self.llm_model.to('cpu')
            if self.device != 'cpu':
                torch.cuda.empty_cache()
        return llm_annotations

    def predict(self, dataset, entities_list, split='test'):
        """
        Predict the NER tags for the dataset
        inputs:
            dataset: DatasetDict, the dataset to predict (uses its `split` split)
            entities_list: list, the list of entities
            split: str, which split to annotate
        outputs:
            predictions: list of str, the LLM-reviewed annotation text per sentence
        """
        # first step: annotate sentences with the NER model
        ner_annotations = self.annotate_with_NER_model(dataset, entities_list, split=split)

        # second step: show the LLM each word with its NER tag so it can correct them
        return self.annotate_with_LLM_model(ner_annotations, entities_list, tokens=dataset[split]['tokens'])


In [None]:
# Models used by the ensemble: a RoBERTa NER checkpoint loaded from a local file
# (presumably fine-tuned on CHIA, judging by its name) and BioMistral-7B as reviewer.
ner_model_name, llm_name = 'roberta-base', 'BioMistral/BioMistral-7B'
ner_from_local = True
local_path = f'{model_path}/roberta-chia-ner.pt'

# load the CHIA dataset from the Hugging Face Hub
dataset = load_dataset('JavierLopetegui/chia_v1')

In [None]:
# Run inference on the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# load ensemble model: NER checkpoint from local_path plus the LLM named by llm_name
ensemble_model = EnsembleModelNER(
    ner_model_name,
    llm_name,
    ner_from_local=ner_from_local,
    path_to_model=local_path,
    device=device,
)

In [None]:
# Tag the sentences with the NER model, then let the LLM review those tags.
annotations = ensemble_model.predict(dataset=dataset, entities_list=entities_list)

In [None]:
# Inspect the LLM's answer for the first sentence.
annotations[0]

In [None]:
def parse_llm_annotation(text):
    """Extract tag ids from one LLM answer.

    Each useful line is expected to look like "word TAG", where TAG is either a tag
    name from sel_ent (e.g. "B-Drug") or a numeric tag id. Other lines (blank lines,
    free-form comments from the LLM) are skipped.
    """
    tag_ids = []
    for line in text.splitlines():
        fields = line.split()
        if len(fields) < 2:
            continue
        tag = fields[1]
        if tag in sel_ent:
            tag_ids.append(sel_ent[tag])
        elif tag.isdigit() and int(tag) < len(entities_list):
            tag_ids.append(int(tag))
    return tag_ids


# one list of tag ids per sentence (the LLM output is a string, not a list of lines)
annotations_entities = [parse_llm_annotation(annotation) for annotation in annotations]

In [None]:
def compute_metrics(p, label_names=None, metric=None):
    """
    Compute the seqeval metrics for token-level predictions
    inputs:
        p: tuple (predictions, ground_true), e.g. a Trainer EvalPrediction.
            predictions: tag ids per token, or a 3-D logits array
                (batch, tokens, num_labels), which is reduced with argmax on the last axis
            ground_true: tag ids per token, -100 on positions to ignore
        label_names: list of tag names indexed by id; defaults to the global entities_list
        metric: object with compute(predictions=..., references=...);
            defaults to evaluate.load("seqeval")
    outputs:
        dict: the metrics returned by metric.compute
    """
    if label_names is None:
        label_names = entities_list
    if metric is None:
        import evaluate
        metric = evaluate.load("seqeval")

    predictions, ground_true = p
    # Trainer passes raw logits; turn them into one predicted id per token.
    if getattr(predictions, 'ndim', None) == 3:
        predictions = np.argmax(predictions, axis=-1)

    predictions_labels = []
    true_labels = []
    for preds, labels in zip(predictions, ground_true):
        preds_labels = []
        labels_true = []
        for pred, label in zip(preds, labels):
            if label == -100:
                continue  # special tokens / ignored sub-tokens
            if pred == -100:
                pred = 0  # score an ignored prediction as id 0 ("O" in this notebook)
            preds_labels.append(label_names[pred])
            labels_true.append(label_names[label])
        predictions_labels.append(preds_labels)
        true_labels.append(labels_true)

    return metric.compute(predictions=predictions_labels, references=true_labels)

In [None]:
# datasets.load_metric was removed in datasets 3.0; evaluate.load is its replacement.
metric = evaluate.load("seqeval")

In [None]:
# evaluate the model
# seqeval needs tag strings and, per sentence, exactly as many predicted tags as
# reference tags. The LLM may drop or add lines, so each prediction is truncated or
# padded with "O" to the length of its reference sentence.
reference_tags = [[entities_list[tag_id] for tag_id in sentence] for sentence in dataset['test']['ner_tags']]
assert len(annotations_entities) == len(reference_tags), 'expected one prediction per test sentence'
predicted_tags = [
    ([entities_list[tag_id] for tag_id in prediction] + ['O'] * len(reference))[:len(reference)]
    for prediction, reference in zip(annotations_entities, reference_tags)
]
results = metric.compute(predictions=predicted_tags, references=reference_tags)
results