b/evaluation.py

from flair.data import Sentence
from flair.models import SequenceTagger, TextClassifier
from flair.tokenization import SciSpacyTokenizer
from transformers import pipeline, TextClassificationPipeline, AutoTokenizer, TFBertForTokenClassification, BertForSequenceClassification, AutoModelForSequenceClassification
from transformers.trainer import Trainer, TrainingArguments
from stqdm import stqdm
from allennlp.predictors.predictor import Predictor
import os
import fitz
import streamlit
import wikipedia
import nltk.data


# silence the HuggingFace tokenizers fork/parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

class InferenceADE:
    ''' Voting classifier using 3 different models for ADE detection '''

    def __init__(self, pipeline_scibert, pipeline_biolink, model_hunflair):
        # per-model F1 scores used as voting weights (not real F1 scores for now)
        self.f1_score_biolink = 0.96
        self.f1_score_scibert = 0.80
        self.f1_score_hunflair = 0.90
        self.pipeline_scibert = pipeline_scibert
        self.pipeline_biolink = pipeline_biolink
        self.model_hunflair = model_hunflair

    def __call__(self, sentence):
        result_bert = self.pipeline_scibert(sentence)[0]
        result_biolink = self.pipeline_biolink(sentence)[0]
        s = Sentence(sentence)
        self.model_hunflair.predict(s)
        result_hunflair = s.labels[0].to_dict()

        # convert each model's output into [P(no ADE), P(ADE)]
        if result_bert['label'] == 'LABEL_0':
            pred_scibert = [result_bert['score'], 1 - result_bert['score']]
        elif result_bert['label'] == 'LABEL_1':
            pred_scibert = [1 - result_bert['score'], result_bert['score']]

        if result_biolink['label'] == 'LABEL_0':
            pred_biolink = [result_biolink['score'], 1 - result_biolink['score']]
        elif result_biolink['label'] == 'LABEL_1':
            pred_biolink = [1 - result_biolink['score'], result_biolink['score']]

        if result_hunflair['value'] == '0':
            pred_hunflair = [result_hunflair['confidence'], 1 - result_hunflair['confidence']]
        elif result_hunflair['value'] == '1':
            pred_hunflair = [1 - result_hunflair['confidence'], result_hunflair['confidence']]

        # voting classifier: F1-weighted average of the three probability estimates
        weighted_average_1 = float((self.f1_score_biolink * pred_biolink[0] + self.f1_score_scibert * pred_scibert[0] + self.f1_score_hunflair * pred_hunflair[0]) / (self.f1_score_biolink + self.f1_score_scibert + self.f1_score_hunflair))
        weighted_average_2 = float((self.f1_score_biolink * pred_biolink[1] + self.f1_score_scibert * pred_scibert[1] + self.f1_score_hunflair * pred_hunflair[1]) / (self.f1_score_biolink + self.f1_score_scibert + self.f1_score_hunflair))

        return [weighted_average_1, weighted_average_2]
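
# Worked example of the weighted soft vote above (illustrative numbers, not real model outputs):
# with the hard-coded weights 0.96 (BioLinkBERT), 0.80 (SciBERT) and 0.90 (HunFlair) and per-model
# probabilities [P(no ADE), P(ADE)] of [0.1, 0.9], [0.3, 0.7] and [0.2, 0.8] respectively, the ensemble returns
#   [(0.96*0.1 + 0.80*0.3 + 0.90*0.2) / 2.66, (0.96*0.9 + 0.80*0.7 + 0.90*0.8) / 2.66] ~= [0.194, 0.806],
# i.e. the sentence is classified as describing an ADE (index 1 is above 0.5).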


def extraction(filename: str, choices: list[bool], use_streamlit: bool = True):
    ''' Takes the name of a pdf file as input and extracts the requested entities from it.
    Outputs a dictionary with the entity types as keys and the lists of extracted entities as values.
    '''
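    # Shape of the returned dictionary (illustrative values, assuming all five choices are enabled):
    # {'Chemicals': ['aspirin', ...], 'Diseases': [...], 'Dates': [...], 'Adverse effects': [...], 'Doses': [...]}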
    tokenizer_split_sentences = nltk.data.load('tokenizers/punkt/english.pickle')
    root = './NER-Medical-Document/processed_files/' + filename[:-4] + '/'
    results = {}
    models = {}
    limit = 20  # maximum number of files to process; used for testing when the pdf file is too big
    ind = 0
    # only add the necessary models to the dictionary 'models' (avoid unnecessary loading)
    model_url = 'https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2021.03.10.tar.gz'
    predictor = Predictor.from_path(model_url)

    if choices[0]:
        tagger_chemicals: SequenceTagger = SequenceTagger.load('./NER-Medical-Document/training_results/best-model.pt')
        models[0] = tagger_chemicals

    if choices[1]:
        tagger_diseases: SequenceTagger = SequenceTagger.load('hunflair-disease')
        models[1] = tagger_diseases

    if choices[2]:
        tagger_dates = SequenceTagger.load("flair/ner-english-ontonotes-fast")
        models[2] = tagger_dates

    if choices[3]:
        # 3 different methods for ADE detection
        method = 2
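        # method 1: token classification (tag individual ADE tokens with a SpanBERT ADE model)
        # method 2: sentence classification with the fine-tuned HunFlair model + negation detection
        # method 3: sentence classification with the InferenceADE voting classifier (SciBERT + BioLinkBERT + HunFlair)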

        if method == 1:
            # Use a token classification model (classify a token as an ADE, not a sentence)
            model_adverse_name = "abhibisht89/spanbert-large-cased-finetuned-ade_corpus_v2"  # model name from huggingface.co/models
            model_adverse = TFBertForTokenClassification.from_pretrained(model_adverse_name, from_pt=True)
            tokenizer_adverse = AutoTokenizer.from_pretrained(model_adverse_name)
            models[3] = pipeline("token-classification", model=model_adverse, tokenizer=tokenizer_adverse, grouped_entities=True)

        elif method == 2:
            # Sentence classification: use HunFlair model + negation detection
            tokenizer_neg = AutoTokenizer.from_pretrained("bvanaken/clinical-assertion-negation-bert")
            model_neg = AutoModelForSequenceClassification.from_pretrained("bvanaken/clinical-assertion-negation-bert")
            pipeline_neg = TextClassificationPipeline(model=model_neg, tokenizer=tokenizer_neg)
            model_hunflair = TextClassifier.load('./NER-Medical-Document/training_results/flair_bert/best-model.pt')
            models[3] = model_hunflair

        elif method == 3:
            # Sentence classification: use the InferenceADE class to build a voting classifier
            model_scibert_name = 'NER-Medical-Document/training_results/scibert_scivocab_uncased'
            model_scibert = BertForSequenceClassification.from_pretrained(model_scibert_name)
            tokenizer_scibert = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
            pipeline_scibert = pipeline("text-classification", model=model_scibert, tokenizer=tokenizer_scibert)

            model_biolink_name = 'NER-Medical-Document/training_results/BioLinkBERT-base'
            model_biolink = BertForSequenceClassification.from_pretrained(model_biolink_name)
            tokenizer_biolink = AutoTokenizer.from_pretrained('michiyasunaga/BioLinkBERT-base')
            pipeline_biolink = pipeline("text-classification", model=model_biolink, tokenizer=tokenizer_biolink)

            model_hunflair = TextClassifier.load('./NER-Medical-Document/training_results/flair_bert/best-model.pt')

            models[3] = InferenceADE(pipeline_scibert, pipeline_biolink, model_hunflair)

    if choices[4]:
        tagger_doses = SequenceTagger.load("flair/ner-english-ontonotes-fast")
        models[4] = tagger_doses

    dic = {0: 'Chemicals', 1: 'Diseases', 2: 'Dates', 3: 'Adverse effects', 4: 'Doses'}
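    # choices[i] toggles extraction of dic[i]; e.g. a (hypothetical) choices = [True, False, False, True, False]
    # would extract chemicals and adverse effects only.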

    info = 'Extracting '
    for j, c in enumerate(choices):
        if c:
            info += dic[j].lower() + ' and '
    streamlit.write(info[:-5] + ' entities...')

    local_results_chemicals = []
    local_results_diseases = []
    local_results_dates = []
    local_results_adverse = []
    local_results_doses = []
    dic_doses_chemicals = {}
    detected_chemicals = False   # tracks whether a chemical was found in the current sentence
    entity_chemical = 'unknown'  # last chemical found in the current sentence

    for i, file in enumerate(os.listdir(root)):
        if file.endswith('.txt') and ind < limit:
            ind += 1
            with open(root + file, 'r') as f:
                paragraphs = f.read().split('\n\n')
            sentences = []
            for p in paragraphs:
                sentences.extend(tokenizer_split_sentences.tokenize(p))
                sentences[-1] = sentences[-1] + '\n\n'

            if not use_streamlit:
                size_window = 3
                list_coref = [sentences[i] for i in range(size_window)]
                try:
                    prediction = predictor.predict(document=' '.join(list_coref))
                except:
                    pass
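
            # Coreference resolution below works on a sliding window of `size_window` consecutive sentences:
            # the window is seeded above, then shifted one sentence at a time inside the loop, and the last
            # sentence of the coref-resolved window is kept as `sentence_transformed` (only outside Streamlit mode).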

            for s in stqdm(range(len(sentences))):
                sentence_ = sentences[s]

                ### Coreference resolution ###
                if not use_streamlit:
                    if s >= size_window:
                        list_coref.append(sentence_)
                        list_coref.pop(0)
                        try:
                            prediction = predictor.predict(document=' '.join(list_coref))
                            transformed_chunk = predictor.coref_resolved(' '.join(list_coref))
                            paragraphs2 = transformed_chunk.split('\n\n')
                            for p in paragraphs2:
                                sentences2 = tokenizer_split_sentences.tokenize(p)
                                sentence_transformed = sentences2[-1]
                        except:
                            sentence_transformed = sentence_
                    else:
                        sentence_transformed = sentence_

                sentence_ = sentence_.replace('\n', ' ')
                if len(sentence_) >= 4:
                    print(sentence_)
                    sentence = Sentence(sentence_, use_tokenizer=SciSpacyTokenizer())
                    for j, c in enumerate(choices):
                        if c:
                            tagger = models[j]
                            if dic[j] == 'Adverse effects':
                                if method == 1:
                                    result = tagger(sentence_)
                                    if result != []:
                                        for entity in result:
                                            if entity['entity_group'] == 'ADR':
                                                local_results_adverse.append(entity['word'])
                                elif method == 2:
                                    sentence = Sentence(sentence_, use_tokenizer=SciSpacyTokenizer())  # create a new instance of Sentence
                                    tagger.predict(sentence)
                                    result = sentence.labels[0].to_dict()
                                    if result['value'] == '1':
                                        if pipeline_neg(sentence_)[0]['label'] != 'ABSENT':
                                            print('DETECTED')
                                            local_results_adverse.append(sentence_)
                                elif method == 3:
                                    result = tagger(sentence_)
                                    print(result)
                                    if result[1] > 0.5:
                                        local_results_adverse.append(sentence_)
                                    else:
                                        # fallback: replace the detected chemical with 'aspirin' and classify the sentence again
                                        models[0].predict(sentence)
                                        found = False
                                        for annotation_layer in sentence.annotation_layers.keys():
                                            for entity in sentence.get_spans(annotation_layer):
                                                found = True
                                                sentence_2 = sentence_.replace(entity.text, 'aspirin')
                                        if found:
                                            result = tagger(sentence_2)
                                            print(result)
                                            if result[1] > 0.5:
                                                local_results_adverse.append(sentence_)
                            else:
                                tagger.predict(sentence)
                                for annotation_layer in sentence.annotation_layers.keys():
                                    for entity in sentence.get_spans(annotation_layer):
                                        if dic[j] == 'Chemicals':
                                            local_results_chemicals.append(entity.text)
                                            detected_chemicals = True
                                            entity_chemical = entity.text
                                        elif dic[j] == 'Diseases':
                                            local_results_diseases.append(entity.text)
                                        elif dic[j] == 'Dates':
                                            if entity.tag == 'DATE':
                                                local_results_dates.append(entity.text)
                                        elif dic[j] == 'Doses':
                                            if entity.tag == 'QUANTITY':
                                                local_results_doses.append(entity.text)
                                                if detected_chemicals:
                                                    print('YES')
                                                    print(detected_chemicals)
                                                    print(sentence)
                                                    dic_doses_chemicals[entity.text] = entity_chemical
                                                else:
                                                    dic_doses_chemicals[entity.text] = 'unknown'
                    detected_chemicals = False

    for j, c in enumerate(choices):
        if c:
            if dic[j] == 'Chemicals':
                # next line avoids keeping lone characters that were sometimes detected as 'drugs'
                local_results_chemicals = [x for x in local_results_chemicals if x not in ['(', ')', '[', ']', '{', '}', ' ', '']]
                results[dic[j]] = list(set(local_results_chemicals))
            elif dic[j] == 'Diseases':
                results[dic[j]] = list(set(local_results_diseases))
            elif dic[j] == 'Dates':
                results[dic[j]] = list(set(local_results_dates))
            elif dic[j] == 'Adverse effects':
                results[dic[j]] = list(set(local_results_adverse))
            elif dic[j] == 'Doses':
                results[dic[j]] = list(set(local_results_doses))

    streamlit.write('Done!')
    return results


def higlight(filename: str, choices: list[bool]):
    ''' Highlight the chosen entities in the pdf file whose name is 'filename' '''

    root = './NER-Medical-Document/processed_files/' + filename[:-4] + '/'
    pdf_input = fitz.open(root + filename)
    results = extraction(filename, choices)
    streamlit.write('Highlighting entities...')
    text_instances = {}

    # go through all the pages of the chosen pdf file
    for page in pdf_input:

        # search for all the occurrences of the entities in the page
        for name, entities in results.items():
            text_instances[name] = [page.search_for(text) for text in entities]

        for name, instances in text_instances.items():
            add_definition = False
            if name == 'Chemicals':
                color = (1, 1, 0)
                add_definition = True
            elif name == 'Diseases':
                color = (0, 1, 0)
                add_definition = True
            elif name == 'Dates':
                color = (0, 0.7, 1)
                add_definition = False
            elif name == 'Adverse effects':
                color = (1, 0, 0)
                add_definition = False
            elif name == 'Doses':
                color = (1, 0, 1)
                add_definition = False

            # highlight each occurrence of the entity in the page
            for i, inst in enumerate(instances):
                for x in inst:
                    # handle the case where an entity should not be highlighted (see README): the idea is to check the surrounding
                    # characters to detect whether this occurrence of the entity is part of another word or not
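                    # e.g. a short name like 'ASA' (hypothetical example) should not be highlighted when it merely
                    # appears inside a longer word; any leftover letter around the matched text rejects the match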

                    # check the typical distance between 2 letters in the word (it depends on the font size)
                    dist_letters = (x[2] - x[0]) / len(results[name][i])
                    # draw a larger rectangle to check the surrounding characters
                    rect_larger = fitz.Rect(x[0] - dist_letters, x[1], x[2] + dist_letters, x[3])
                    non_accepted_chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
                    word = page.get_textbox(rect_larger).lower()
                    for sub in results[name][i].split():
                        word = word.replace(sub.lower(), '')
                    keep = True
                    for l in word:
                        if l in non_accepted_chars:
                            keep = False
                    if not keep:
                        continue  # ignore this occurrence of the entity if it is part of another word

                    annot = page.add_highlight_annot(x)
                    annot.set_colors({"stroke": color})
                    annot.set_opacity(0.4)
                    if add_definition:
                        try:
                            annot.set_popup(x)
                            info = annot.info
                            info["title"] = "Definition"
                            if name == 'Chemicals':
                                info["content"] = wikipedia.summary(results[name][i] + ' (drug)').split('.')[0]
                            else:
                                info["content"] = wikipedia.summary(results[name][i] + f' ({name.lower()[:-1]})').split('.')[0]
                            annot.set_info(info)
                        except:
                            pass
                    annot.update()

    if os.path.exists(root + filename[:-4] + '_highlighted.pdf'):
        os.remove(root + filename[:-4] + '_highlighted.pdf')
    pdf_input.save(root + filename[:-4] + '_highlighted.pdf')
    streamlit.write('Done!')
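
# Minimal usage sketch (assuming the processed text files for the pdf already exist under processed_files/):
#   higlight('report.pdf', [True, True, True, True, True])
# writes 'report_highlighted.pdf' into the same processed_files/report/ folder ('report.pdf' is a hypothetical file name).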


if __name__ == '__main__':
    print(extraction('0.txt', [True, True, True, True, True], use_streamlit=False))