**Evaluation of the Zero-shot approach for NER**

In [641]:
# imports
import os
import torch
from transformers import AutoTokenizer
from datasets import Dataset, DatasetDict, load_metric
import pandas as pd
import evaluate

In [642]:
# paths
root_path = ".."
# root_path = "./drive/MyDrive/HandsOnNLP" # for google colab
data_path = f'{root_path}/data'
annotations_path = f'{data_path}/Annotations_Mistral_Prompt_2'
chia_bio_path = f'{data_path}/chia/chia_bio'

In [643]:
ann_files = os.listdir(annotations_path)
len(ann_files)

200

In [644]:
true_ann = {} # list with real annotations
mistral_ann = {} # list with mistral annotations

In [645]:
# dict for the entities (entity to int value)
simple_ent = {"Condition", "Value", "Drug", "Procedure", "Measurement", "Temporal", "Observation", "Person", "Mood", "Pregnancy_considerations", "Device"}
sel_ent = {
    "O": 0,
    "B-Condition": 1,
    "I-Condition": 2,
    "B-Value": 3,
    "I-Value": 4,
    "B-Drug": 5,
    "I-Drug": 6,
    "B-Procedure": 7,
    "I-Procedure": 8,
    "B-Measurement": 9,
    "I-Measurement": 10,
    "B-Temporal": 11,
    "I-Temporal": 12,
    "B-Observation": 13,
    "I-Observation": 14,
    "B-Person": 15,
    "I-Person": 16,
    "B-Mood": 17,
    "I-Mood": 18,
    "B-Pregnancy_considerations": 19,
    "I-Pregnancy_considerations": 20,
    "B-Device": 21,
    "I-Device": 22
}

entities_list = list(sel_ent.keys())
sel_ent_inv = {v: k for k, v in sel_ent.items()}

In [646]:
import re

In [647]:
def parse_ann2bio(sentence, pattern, pattern1, pattern2):
    if sentence[-1] == "\n":
        sentence = sentence[:-2] # remove the \n and a final point wrongly added
    else:
        sentence = sentence[:-1] # remove the final point wrongly added
    
    # find the entities
    occurrences = re.finditer(pattern, sentence)
    indexes = [(match.start(), match.end()) for match in occurrences]

    annotation = []
    i = 0
    # create the bio list
    for beg, end in indexes:
        if beg > i:
            annotation.extend([(word, "O") for word in sentence[i:beg].split()])
        entity = sentence[beg:end]
        entity_name = re.search(pattern1, entity).group(1)
        entity = entity.replace(f'<{entity_name}>', "").replace(f'</{entity_name}>', "")
        split_entity = entity.split()
        annotation.append((split_entity[0], "B-" + entity_name))
        annotation.extend([(word, "I-" + entity_name) for word in split_entity[1:]])
        i = end
    annotation.extend([(word, "O") for word in sentence[i:].split()])

    # check punctuation sign in tokens and put them as individual tokens
    ps = r'(\.|\,|\:|\;|\!|\?|\-|\(|\)|\[|\]|\{|\}|\")'
    new_annotation = []
    for i,(word, tag) in enumerate(annotation):
        if re.search(ps, word):
            # find the ocurrences of the punctuation signs
            occurrences = re.finditer(ps, word)
            indexes = [(match.start(), match.end()) for match in occurrences]
            # create the new tokens
            last = 0
            for j, (beg, end) in enumerate(indexes):
                if beg > last:
                    new_annotation.append((word[last:beg], tag))
                if tag != "O":
                    label = f'I-{tag.split("-")[1]}'
                else:
                    label = "O"
                if end < len(word) or (i < len(annotation) - 1 and annotation[i+1][1] == label):
                    new_annotation.append((word[beg:end], label))
                else:
                    new_annotation.append((word[beg:end], 'O')) 
                last = end
            if last < len(word):
                new_annotation.append((word[last:], label))   
                
        else:
            new_annotation.append((word, tag))

    
    return new_annotation

In [648]:
pattern1 = r'<(Person|Condition|Value|Drug|Procedure|Measurement|Temporal|Observation|Mood|Pregnancy_considerations|Device)>'
pattern2 = r'</(Person|Condition|Value|Drug|Procedure|Measurement|Temporal|Observation|Mood|Pregnancy_considerations|Device)>'
pattern = f'{pattern1}.*?{pattern2}'

In [649]:
# get BIO annotations for mistral outputs
for file in ann_files:
    mistral_ann[file] = []
    with open(f"{annotations_path}/{file}", "r") as f:
        sentences = [line for line in f.readlines() if line != "\n" and line != " \n" and line != '']

    for sentence in sentences:
        mistral_ann[file].append(parse_ann2bio(sentence, pattern, pattern1, pattern2))
len(mistral_ann)

200

In [650]:
sent = "Severely to isolate for <Procedure>procedure</Procedure>."
parse_ann2bio(sent, pattern, pattern1, pattern2)

[('Severely', 'O'),
 ('to', 'O'),
 ('isolate', 'O'),
 ('for', 'O'),
 ('procedure', 'B-Procedure')]

In [651]:
# read real annotations from chia_bio
for file in ann_files:
    true_ann[file] = []
    with open(f"{chia_bio_path}/{file}", "r") as fd:
        sentences_ann = fd.read().split("\n\n")
    sentences_ann = [sentence for sentence in sentences_ann if sentence != "" and sentence != '\n']
    for sentence in sentences_ann:
        true_ann[file].append(sentence)
len(true_ann)

200

In [652]:
i = 0
corrupted_files = []
for file in ann_files:
    if len(true_ann[file]) != len(mistral_ann[file]):
        i += 1
        print(f"Error in file {file}")
        print(f"True: {len(true_ann[file])}, Mistral: {len(mistral_ann[file])}")
        corrupted_files.append(file)
print(i/len(ann_files))

Error in file NCT03132259_exc.bio.txt
True: 12, Mistral: 0
0.005


In [653]:
# remove corructed file
for file in corrupted_files:
    del true_ann[file]
    del mistral_ann[file]
    ann_files.remove(file)

In [654]:
ann_files[0]

'NCT02322203_inc.bio.txt'

In [655]:
true_ann_aux = {}

for file in ann_files:
    true_ann_aux[file] = []
    for i in range(len(true_ann[file])):
        annotation = []
        lines = true_ann[file][i].split("\n")
        for line in lines:
            if line != "":
                spt_line = line.split()
                annotation.append((spt_line[0], spt_line[-1]))
        new_annotation = []
        ps = r'(\.|\,|\:|\;|\!|\?|\-|\(|\)|\[|\]|\{|\}|\")'
        for i,(word, tag) in enumerate(annotation):
            if re.search(ps, word):
                # find the ocurrences of the punctuation signs
                occurrences = re.finditer(ps, word)
                indexes = [(match.start(), match.end()) for match in occurrences]
                # create the new tokens
                last = 0
                for j, (beg, end) in enumerate(indexes):
                    if beg > last:
                        new_annotation.append((word[last:beg], tag))
                    if tag != "O":
                        label = f'I-{tag.split("-")[1]}'
                    else:
                        label = "O"
                    if end < len(word) or (i < len(annotation) - 1 and annotation[i+1][1] == label):
                        new_annotation.append((word[beg:end], label))
                    else:
                        new_annotation.append((word[beg:end], 'O')) 
                    last = end
                if last < len(word):
                    new_annotation.append((word[last:], label))
            else:
                new_annotation.append((word, tag))
        true_ann_aux[file].append(new_annotation)
true_ann = true_ann_aux
len(true_ann)

199

In [656]:
mistral_ann[ann_files[0]][1]

[('Subject', 'O'),
 ('understands', 'O'),
 ('the', 'O'),
 ('investigational', 'O'),
 ('nature', 'O'),
 ('of', 'O'),
 ('the', 'O'),
 ('study', 'O'),
 ('and', 'O'),
 ('provides', 'O'),
 ('written', 'O'),
 (',', 'O'),
 ('informed', 'B-Mood'),
 ('consent', 'I-Mood'),
 ('.', 'O')]

In [658]:
true_ann[ann_files[0]][1]

[('Subject', 'O'),
 ('understands', 'O'),
 ('the', 'O'),
 ('investigational', 'O'),
 ('nature', 'O'),
 ('of', 'O'),
 ('the', 'O'),
 ('study', 'O'),
 ('and', 'O'),
 ('provides', 'O'),
 ('written', 'O'),
 (',', 'O'),
 ('informed', 'O'),
 ('consent', 'O'),
 ('.', 'O')]

In [661]:
mistral_ann_dict = []

for file in ann_files:
    for i in range(len(mistral_ann[file])):
        dict_sent = {"tokens": [], "ner_tags": [], "file": file, "index": i}
        for word, tag in mistral_ann[file][i]:
            dict_sent["tokens"].append(word)
            # add the int representation of the entity
            dict_sent["ner_tags"].append(sel_ent[tag])
        mistral_ann_dict.append(dict_sent)
len(mistral_ann_dict)

1147

In [662]:
true_ann_dict = []

for file in ann_files:
    for i in range(len(true_ann[file])):
        dict_sent = {"tokens": [], "ner_tags": [], "file": file, "index": i}
        for word, tag in true_ann[file][i]:
            dict_sent["tokens"].append(word)
            # add the int representation of the entity
            dict_sent["ner_tags"].append(sel_ent[tag])
        true_ann_dict.append(dict_sent)
len(true_ann_dict)

1147

In [663]:
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')

In [664]:
# tokenize and align the labels in the dataset
def tokenize_and_align_labels(sentence, flag = 'I'):
    """
    Tokenize the sentence and align the labels
    inputs:
        sentence: dict, the sentence from the dataset
        flag: str, the flag to indicate how to deal with the labels for subwords
            - 'I': use the label of the first subword for all subwords but as intermediate (I-ENT)
            - 'B': use the label of the first subword for all subwords as beginning (B-ENT)
            - None: use -100 for subwords
    outputs:
        tokenized_sentence: dict, the tokenized sentence now with a field for the labels
    """
    tokenized_sentence = tokenizer(sentence['tokens'], is_split_into_words=True, truncation=True)

    labels = []
    for i, labels_s in enumerate(sentence['ner_tags']):
        word_ids = tokenized_sentence.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            # if the word_idx is None, assign -100
            if word_idx is None:
                label_ids.append(-100)
            # if it is a new word, assign the corresponding label
            elif word_idx != previous_word_idx:
                label_ids.append(labels_s[word_idx])
            # if it is the same word, check the flag to assign
            else:
                if flag == 'I':
                    if entities_list[labels_s[word_idx]].startswith('I'):
                      label_ids.append(labels_s[word_idx])
                    else:
                      label_ids.append(labels_s[word_idx] + 1)
                elif flag == 'B':
                    label_ids.append(labels_s[word_idx])
                elif flag == None:
                    label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_sentence['labels'] = labels
    return tokenized_sentence

In [665]:
mis_df = pd.DataFrame(mistral_ann_dict)
true_df = pd.DataFrame(true_ann_dict)

In [666]:
mistral_ann_dataset = Dataset.from_pandas(mis_df)
true_ann_dataset = Dataset.from_pandas(true_df)

In [667]:
mistral_ann_dataset, true_ann_dataset

(Dataset({
     features: ['tokens', 'ner_tags', 'file', 'index'],
     num_rows: 1147
 }),
 Dataset({
     features: ['tokens', 'ner_tags', 'file', 'index'],
     num_rows: 1147
 }))

In [668]:
mistral_ann_dataset = mistral_ann_dataset.map(tokenize_and_align_labels, batched=True)
true_ann_dataset = true_ann_dataset.map(tokenize_and_align_labels, batched=True)

Map: 100%|██████████| 1147/1147 [00:00<00:00, 15088.51 examples/s]
Map: 100%|██████████| 1147/1147 [00:00<00:00, 18242.88 examples/s]


In [669]:
mistral_ann_dataset, true_ann_dataset

(Dataset({
     features: ['tokens', 'ner_tags', 'file', 'index', 'input_ids', 'attention_mask', 'labels'],
     num_rows: 1147
 }),
 Dataset({
     features: ['tokens', 'ner_tags', 'file', 'index', 'input_ids', 'attention_mask', 'labels'],
     num_rows: 1147
 }))

**Evaluation of the annotations made by Mistral using seqeval**

In [670]:
seqeval = evaluate.load('seqeval')

In [671]:
def compute_metrics(p):
    """
    Compute the metrics for the model
    inputs:
        p: tuple, the predictions and the ground true
    outputs:
        dict: the metrics
    """
    predictions, ground_true = p

    # Remove ignored index (special tokens)
    predictions_labels = []
    true_labels = []

    for preds, labels in zip(predictions, ground_true):
        preds_labels = []
        labels_true = []
        for pred, label in zip(preds, labels):
            if label != -100:
                if pred == -100:
                    pred = 0
                preds_labels.append(entities_list[pred])
                labels_true.append(entities_list[label])
        predictions_labels.append(preds_labels) 
        true_labels.append(labels_true)

    # predictions_labels = [
    #     [entities_list[p] for (p, l) in zip(prediction, ground_true) if l != -100]
    #     for prediction, label in zip(predictions, ground_true)
    # ]
    # true_labels = [
    #     [entities_list[l] for (p, l) in zip(prediction, ground_true) if l != -100]
    #     for prediction, label in zip(predictions, ground_true)
    # ]
    # print(predictions_labels[0])
    # print(true_labels[0])

    results = seqeval.compute(predictions=predictions_labels, references=true_labels)
    return results

In [672]:
print(mistral_ann_dataset['file'][14], mistral_ann_dataset['index'][14])
print(true_ann_dataset['file'][14])

NCT00806273_exc.bio.txt 1
NCT00806273_exc.bio.txt


In [680]:
print(mistral_ann_dataset['tokens'][1140])
print(true_ann_dataset['tokens'][1140])

['Men', 'and', 'women', '<Age>35', 'to', '70', 'years', 'of', 'age</Age>']
['Men', 'and', 'women', '35', 'to', '70', 'years', 'of', 'age']


In [674]:
for i in range(len(mistral_ann_dataset)):
    if len(mistral_ann_dataset['labels'][i]) != len(true_ann_dataset['labels'][i]):
        print(i)

14
15
24
33
37
57
71
72
90
92
101
102
109
110
111
127
142
155
162
169
176
188
194
201
205
209
210
226
229
230
231
232
233
241
244
251
271
273
274
307
310
320
324
332
337
339
349
350
363
369
372
383
403
404
408
412
417
421
432
447
448
463
464
465
470
475
487
491
500
512
513
530
536
548
567
573
578
581
596
600
611
623
631
633
646
654
655
656
671
681
686
694
707
708
719
725
727
731
742
745
756
757
759
763
767
769
777
780
782
783
788
789
794
795
800
801
805
806
809
816
817
820
822
823
828
831
837
855
869
873
877
903
912
924
943
945
959
966
972
974
977
980
981
983
994
998
999
1008
1010
1012
1013
1015
1016
1022
1026
1035
1047
1052
1060
1063
1066
1068
1076
1077
1082
1090
1093
1094
1099
1100
1111
1113
1119
1136
1140


In [675]:
compute_metrics((mistral_ann_dataset['labels'], true_ann_dataset['labels']))

  _warn_prf(average, modifier, msg_start, len(result))


{'Condition': {'precision': 0.5477680433875678,
  'recall': 0.663466397170288,
  'f1': 0.6000914076782451,
  'number': 3958},
 'Device': {'precision': 0.2777777777777778,
  'recall': 0.15151515151515152,
  'f1': 0.19607843137254904,
  'number': 33},
 'Drug': {'precision': 0.5060975609756098,
  'recall': 0.5684931506849316,
  'f1': 0.535483870967742,
  'number': 292},
 'Measurement': {'precision': 0.13314447592067988,
  'recall': 0.14968152866242038,
  'f1': 0.1409295352323838,
  'number': 314},
 'Mood': {'precision': 0.00684931506849315,
  'recall': 0.01694915254237288,
  'f1': 0.00975609756097561,
  'number': 59},
 'Observation': {'precision': 0.05454545454545454,
  'recall': 0.019736842105263157,
  'f1': 0.028985507246376812,
  'number': 152},
 'Person': {'precision': 0.08108108108108109,
  'recall': 0.05056179775280899,
  'f1': 0.06228373702422144,
  'number': 178},
 'Pregnancy_considerations': {'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'number': 19},
 'Procedure': {'precisi