Diff of /evaluation.py [000000] .. [9063a2]

from flair.data import Sentence
from flair.models import SequenceTagger, TextClassifier
from flair.tokenization import SciSpacyTokenizer
from transformers import pipeline, TextClassificationPipeline, AutoTokenizer, TFBertForTokenClassification, BertForSequenceClassification, AutoModelForSequenceClassification
from stqdm import stqdm
from allennlp.predictors.predictor import Predictor
import os
import fitz  # PyMuPDF
import streamlit
import wikipedia
import nltk.data


os.environ["TOKENIZERS_PARALLELISM"] = "false"


class InferenceADE:
    ''' Voting classifier combining 3 different models for ADE detection '''

    def __init__(self, pipeline_scibert, pipeline_biolink, model_hunflair):
        # placeholder F1 scores for now; replace with the real validation scores
        self.f1_score_biolink = 0.96
        self.f1_score_scibert = 0.80
        self.f1_score_hunflair = 0.90
        self.pipeline_scibert = pipeline_scibert
        self.pipeline_biolink = pipeline_biolink
        self.model_hunflair = model_hunflair

    def __call__(self, sentence):
        result_bert = self.pipeline_scibert(sentence)[0]
        result_biolink = self.pipeline_biolink(sentence)[0]
        s = Sentence(sentence)
        self.model_hunflair.predict(s)
        result_hunflair = s.labels[0].to_dict()

        # turn each model's output into [P(no ADE), P(ADE)]
        if result_bert['label'] == 'LABEL_0':
            pred_scibert = [result_bert['score'], 1 - result_bert['score']]
        else:  # 'LABEL_1'
            pred_scibert = [1 - result_bert['score'], result_bert['score']]

        if result_biolink['label'] == 'LABEL_0':
            pred_biolink = [result_biolink['score'], 1 - result_biolink['score']]
        else:  # 'LABEL_1'
            pred_biolink = [1 - result_biolink['score'], result_biolink['score']]

        if result_hunflair['value'] == '0':
            pred_hunflair = [result_hunflair['confidence'], 1 - result_hunflair['confidence']]
        else:  # '1'
            pred_hunflair = [1 - result_hunflair['confidence'], result_hunflair['confidence']]

        # voting classifier: F1-weighted average of the class probabilities
        total_weight = self.f1_score_biolink + self.f1_score_scibert + self.f1_score_hunflair
        weighted_average_1 = float((self.f1_score_biolink * pred_biolink[0] + self.f1_score_scibert * pred_scibert[0] + self.f1_score_hunflair * pred_hunflair[0]) / total_weight)
        weighted_average_2 = float((self.f1_score_biolink * pred_biolink[1] + self.f1_score_scibert * pred_scibert[1] + self.f1_score_hunflair * pred_hunflair[1]) / total_weight)

        return [weighted_average_1, weighted_average_2]
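

# A worked example of InferenceADE's weighting (hypothetical probabilities, no
# models involved): with pred_scibert = [0.10, 0.90], pred_biolink = [0.20, 0.80]
# and pred_hunflair = [0.30, 0.70], the returned P(ADE) is
#   (0.96 * 0.80 + 0.80 * 0.90 + 0.90 * 0.70) / (0.96 + 0.80 + 0.90) ≈ 0.80,
# i.e. each model's vote counts in proportion to its (placeholder) F1 score.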


def extraction(filename: str, choices: list[bool], use_streamlit: bool = True):
    ''' Takes the name of a pdf file as input and extracts the chosen entities from it.
        Outputs a dictionary with the entity types as keys and the lists of extracted entities as values.
    '''
    tokenizer_split_sentences = nltk.data.load('tokenizers/punkt/english.pickle')
    root = './NER-Medical-Document/processed_files/' + filename[:-4] + '/'
    results = {}
    models = {}
    limit = 20  # maximum number of text files to process; keeps test runs short when the pdf is very large
    ind = 0
    # only add the necessary models to the dictionary 'models' (avoid unnecessary loading);
    # the coreference model is only used outside streamlit (see the coreference block below)
    model_url = 'https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2021.03.10.tar.gz'
    predictor = Predictor.from_path(model_url) if not use_streamlit else None

    if choices[0]:
        tagger_chemicals: SequenceTagger = SequenceTagger.load('./NER-Medical-Document/training_results/best-model.pt')
        models[0] = tagger_chemicals

    if choices[1]:
        tagger_diseases: SequenceTagger = SequenceTagger.load('hunflair-disease')
        models[1] = tagger_diseases

    if choices[2]:
        tagger_dates = SequenceTagger.load("flair/ner-english-ontonotes-fast")
        models[2] = tagger_dates
    if choices[3]:
        # 3 different methods for ADE detection
        method = 2

        if method == 1:
            # use a token classification model (classifies each token as an ADE or not, rather than the whole sentence)
            model_adverse_name = "abhibisht89/spanbert-large-cased-finetuned-ade_corpus_v2"  # model name from huggingface.co/models
            model_adverse = TFBertForTokenClassification.from_pretrained(model_adverse_name, from_pt=True)
            tokenizer_adverse = AutoTokenizer.from_pretrained(model_adverse_name)
            models[3] = pipeline("token-classification", model=model_adverse, tokenizer=tokenizer_adverse, grouped_entities=True)

        elif method == 2:
            # sentence classification: use a HunFlair model + negation detection
            tokenizer_neg = AutoTokenizer.from_pretrained("bvanaken/clinical-assertion-negation-bert")
            model_neg = AutoModelForSequenceClassification.from_pretrained("bvanaken/clinical-assertion-negation-bert")
            pipeline_neg = TextClassificationPipeline(model=model_neg, tokenizer=tokenizer_neg)
            model_hunflair = TextClassifier.load('./NER-Medical-Document/training_results/flair_bert/best-model.pt')
            models[3] = model_hunflair
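
            # Hypothetical sanity check of the negation pipeline (not run here).
            # The assertion model labels each sentence, and the extraction loop
            # below only keeps an ADE sentence when the label is not 'ABSENT':
            #   pipeline_neg("No nausea was reported.")
            #   -> e.g. [{'label': 'ABSENT', 'score': ...}]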

        elif method == 3:
            # sentence classification: use the InferenceADE class above as a voting classifier
            model_scibert_name = 'NER-Medical-Document/training_results/scibert_scivocab_uncased'
            model_scibert = BertForSequenceClassification.from_pretrained(model_scibert_name)
            tokenizer_scibert = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
            pipeline_scibert = pipeline("text-classification", model=model_scibert, tokenizer=tokenizer_scibert)

            model_biolink_name = 'NER-Medical-Document/training_results/BioLinkBERT-base'
            model_biolink = BertForSequenceClassification.from_pretrained(model_biolink_name)
            tokenizer_biolink = AutoTokenizer.from_pretrained('michiyasunaga/BioLinkBERT-base')
            pipeline_biolink = pipeline("text-classification", model=model_biolink, tokenizer=tokenizer_biolink)

            model_hunflair = TextClassifier.load('./NER-Medical-Document/training_results/flair_bert/best-model.pt')

            models[3] = InferenceADE(pipeline_scibert, pipeline_biolink, model_hunflair)

    if choices[4]:
        tagger_doses = SequenceTagger.load("flair/ner-english-ontonotes-fast")
        models[4] = tagger_doses


    dic = {0: 'Chemicals', 1: 'Diseases', 2: 'Dates', 3: 'Adverse effects', 4: 'Doses'}

    info = 'Extracting '
    for j, c in enumerate(choices):
        if c:
            info += dic[j].lower() + ' and '
    if use_streamlit:
        streamlit.write(info[:-5] + ' entities...')

    local_results_chemicals = []
    local_results_diseases = []
    local_results_dates = []
    local_results_adverse = []
    local_results_doses = []
    dic_doses_chemicals = {}
    detected_chemicals = False  # avoids a NameError when doses are extracted without chemicals

    for i, file in enumerate(os.listdir(root)):
        if file.endswith('.txt') and ind < limit:
            ind += 1
            with open(root+file, 'r') as f:
                paragraphs = f.read().split('\n\n')
                sentences = []
                for p in paragraphs:
                    sentences.extend(tokenizer_split_sentences.tokenize(p))
                    if sentences:
                        sentences[-1] = sentences[-1] + '\n\n'  # keep a paragraph-boundary marker

                if not use_streamlit:
                    # sliding window of sentences used as context for coreference resolution
                    size_window = 3
                    list_coref = [sentences[i] for i in range(min(size_window, len(sentences)))]
                    try:
                        prediction = predictor.predict(document=' '.join(list_coref))
                    except Exception:
                        pass

                for s in stqdm(range(len(sentences))):
                    sentence_ = sentences[s]

                    ### Coreference resolution ###
                    if not use_streamlit:
                        if s >= size_window:
                            list_coref.append(sentence_)
                            list_coref.pop(0)
                            try:
                                prediction = predictor.predict(document=' '.join(list_coref))
                                transformed_chunk = predictor.coref_resolved(' '.join(list_coref))
                                paragraphs2 = transformed_chunk.split('\n\n')
                                for p in paragraphs2:
                                    sentences2 = tokenizer_split_sentences.tokenize(p)
                                # keep the resolved version of the current (last) sentence of the window
                                sentence_transformed = sentences2[-1]
                            except Exception:
                                sentence_transformed = sentence_
                        else:
                            sentence_transformed = sentence_
                        # note: sentence_transformed is currently not used by the extraction below

                    sentence_ = sentence_.replace('\n', ' ')
                    if len(sentence_) >= 4:
                        sentence = Sentence(sentence_, use_tokenizer=SciSpacyTokenizer())
                        for j, c in enumerate(choices):
                            if c:
                                tagger = models[j]
                                if dic[j] == 'Adverse effects':
                                    if method == 1:
                                        result = tagger(sentence_)
                                        for entity in result:
                                            if entity['entity_group'] == 'ADR':
                                                local_results_adverse.append(entity['word'])
                                    elif method == 2:
                                        sentence = Sentence(sentence_, use_tokenizer=SciSpacyTokenizer())  # create a fresh Sentence instance
                                        tagger.predict(sentence)
                                        result = sentence.labels[0].to_dict()
                                        if result['value'] == '1':
                                            # keep the sentence only if the negation model does not mark it as absent
                                            if not pipeline_neg(sentence_)[0]['label'] == 'ABSENT':
                                                local_results_adverse.append(sentence_)
                                    elif method == 3:
                                        result = tagger(sentence_)
                                        if result[1] > 0.5:
                                            local_results_adverse.append(sentence_)
                                        else:
                                            # second chance: replace the detected drug name with 'aspirin' and retry
                                            # (assumes the chemicals tagger, models[0], was loaded)
                                            models[0].predict(sentence)
                                            found = False
                                            for annotation_layer in sentence.annotation_layers.keys():
                                                for entity in sentence.get_spans(annotation_layer):
                                                    found = True
                                                    sentence_2 = sentence_.replace(entity.text, 'aspirin')
                                            if found:
                                                result = tagger(sentence_2)
                                                if result[1] > 0.5:
                                                    local_results_adverse.append(sentence_)
                                else:
                                    tagger.predict(sentence)
                                    for annotation_layer in sentence.annotation_layers.keys():
                                        for entity in sentence.get_spans(annotation_layer):
                                            if dic[j] == 'Chemicals':
                                                local_results_chemicals.append(entity.text)
                                                # remember the chemical so a dose found in the same sentence can be linked to it
                                                detected_chemicals = True
                                                entity_chemical = entity.text
                                            elif dic[j] == 'Diseases':
                                                local_results_diseases.append(entity.text)
                                            elif dic[j] == 'Dates':
                                                if entity.tag == 'DATE':
                                                    local_results_dates.append(entity.text)
                                            elif dic[j] == 'Doses':
                                                if entity.tag == 'QUANTITY':
                                                    local_results_doses.append(entity.text)
                                                    if detected_chemicals:
                                                        dic_doses_chemicals[entity.text] = entity_chemical
                                                    else:
                                                        dic_doses_chemicals[entity.text] = 'unknown'
                        detected_chemicals = False  # reset for the next sentence

    for j, c in enumerate(choices):
        if c:
            if dic[j] == 'Chemicals':
                # filter out stray characters that were sometimes detected as 'drugs'
                local_results_chemicals = [x for x in local_results_chemicals if x not in ['(', ')', '[', ']', '{', '}', ' ', '']]
                results[dic[j]] = list(set(local_results_chemicals))
            elif dic[j] == 'Diseases':
                results[dic[j]] = list(set(local_results_diseases))
            elif dic[j] == 'Dates':
                results[dic[j]] = list(set(local_results_dates))
            elif dic[j] == 'Adverse effects':
                results[dic[j]] = list(set(local_results_adverse))
            elif dic[j] == 'Doses':
                results[dic[j]] = list(set(local_results_doses))

    if use_streamlit:
        streamlit.write('Done!')
    return results
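

# The returned dictionary maps each chosen entity type to a deduplicated list of
# surface forms, e.g. (illustrative values only):
#   {'Chemicals': ['aspirin'], 'Diseases': ['headache'], 'Dates': ['March 2020'],
#    'Adverse effects': ['The patient developed a rash.'], 'Doses': ['100 mg']}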


def highlight(filename: str, choices: list[bool]):
    ''' Highlight the chosen entities in the pdf file whose name is 'filename' '''

    root = './NER-Medical-Document/processed_files/' + filename[:-4] + '/'
    pdf_input = fitz.open(root+filename)
    results = extraction(filename, choices)
    streamlit.write('Highlighting entities...')
    text_instances = {}

    # go through all the pages of the chosen pdf file
    for page in pdf_input:

        # search for all the occurrences of the entities in the page
        for name, entities in results.items():
            text_instances[name] = [page.search_for(text) for text in entities]

        for name, instances in text_instances.items():
            add_definition = False
            if name == 'Chemicals':
                color = (1, 1, 0)
                add_definition = True
            elif name == 'Diseases':
                color = (0, 1, 0)
                add_definition = True
            elif name == 'Dates':
                color = (0, 0.7, 1)
            elif name == 'Adverse effects':
                color = (1, 0, 0)
            elif name == 'Doses':
                color = (1, 0, 1)

            # highlight each occurrence of the entity in the page
            for i, inst in enumerate(instances):
                for x in inst:
                    # handle the case where an entity should not be highlighted (see README): the idea is to check the
                    # surrounding characters to detect whether this occurrence of the entity is part of another word

                    # estimate the typical width of a letter in the word (it depends on the font size)
                    dist_letters = (x[2]-x[0])/len(results[name][i])
                    # draw a slightly larger rectangle to capture the surrounding characters
                    rect_larger = fitz.Rect(x[0]-dist_letters, x[1], x[2]+dist_letters, x[3])
                    non_accepted_chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
                    word = page.get_textbox(rect_larger).lower()
                    for sub in results[name][i].split():
                        word = word.replace(sub.lower(), '')
                    keep = True
                    for l in word:
                        if l in non_accepted_chars:
                            keep = False
                    if not keep:
                        continue  # skip this occurrence of the entity: it is part of another word

                    annot = page.add_highlight_annot(x)
                    annot.set_colors({"stroke": color})
                    annot.set_opacity(0.4)
                    if add_definition:
                        # attach a short Wikipedia definition as a popup note
                        try:
                            annot.set_popup(x)
                            info = annot.info
                            info["title"] = "Definition"
                            if name == 'Chemicals':
                                info["content"] = wikipedia.summary(results[name][i] + ' (drug)').split('.')[0]
                            else:
                                info["content"] = wikipedia.summary(results[name][i] + f' ({name.lower()[:-1]})').split('.')[0]
                            annot.set_info(info)
                        except Exception:
                            pass
                    annot.update()

    output_path = root + filename[:-4] + '_highlighted.pdf'
    if os.path.exists(output_path):
        os.remove(output_path)
    pdf_input.save(output_path)
    streamlit.write('Done!')
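

# Minimal pure-string sketch of the boundary check performed in highlight()
# above (illustrative only; the real check reads the characters around the
# match from a slightly enlarged PyMuPDF rectangle):
def _is_standalone_occurrence(context: str, entity: str) -> bool:
    ''' Strip the entity's own tokens from the surrounding text and reject the
        match if any letters remain, i.e. the hit sits inside a larger word. '''
    letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
    leftover = context.lower()
    for sub in entity.split():
        leftover = leftover.replace(sub.lower(), '')
    return not any(l in letters for l in leftover)

# _is_standalone_occurrence(' aspirin ', 'aspirin')  -> True
# _is_standalone_occurrence('taspirine', 'aspirin')  -> False (part of another word)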


if __name__ == '__main__':
    print(extraction('0.txt', [True, True, True, True, True], use_streamlit=False))
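    # To also annotate the source pdf (assumes the matching '0.pdf' sits in the
    # processed_files folder), one could run:
    # highlight('0.pdf', [True, True, True, True, True])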