--- a/evaluation.py
+++ b/evaluation.py
@@ -0,0 +1,349 @@
+from flair.data import Sentence
+from flair.models import SequenceTagger, TextClassifier
+from flair.tokenization import SciSpacyTokenizer
+from transformers import pipeline, TextClassificationPipeline, AutoTokenizer, TFBertForTokenClassification, BertForSequenceClassification, AutoModelForSequenceClassification
+from stqdm import stqdm
+from allennlp.predictors.predictor import Predictor
+import os
+import fitz
+import streamlit
+import wikipedia
+import nltk.data
+
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+class InferenceADE:
+    ''' Voting classifier combining 3 different models for ADE (adverse drug event) detection '''
+
+    def __init__(self, pipeline_scibert, pipeline_biolink, model_hunflair):
+        self.f1_score_biolink = 0.96  # placeholder weights, not real F1 scores for now
+        self.f1_score_scibert = 0.80
+        self.f1_score_hunflair = 0.90
+        self.pipeline_scibert = pipeline_scibert
+        self.pipeline_biolink = pipeline_biolink
+        self.model_hunflair = model_hunflair
+
+    def __call__(self, sentence):
+        result_bert = self.pipeline_scibert(sentence)[0]
+        result_biolink = self.pipeline_biolink(sentence)[0]
+        s = Sentence(sentence)
+        self.model_hunflair.predict(s)
+        result_hunflair = s.labels[0].to_dict()
+
+        # convert each model output into [p(no ADE), p(ADE)]
+        if result_bert['label'] == 'LABEL_0':
+            pred_scibert = [result_bert['score'], 1 - result_bert['score']]
+        else:
+            pred_scibert = [1 - result_bert['score'], result_bert['score']]
+
+        if result_biolink['label'] == 'LABEL_0':
+            pred_biolink = [result_biolink['score'], 1 - result_biolink['score']]
+        else:
+            pred_biolink = [1 - result_biolink['score'], result_biolink['score']]
+
+        if result_hunflair['value'] == '0':
+            pred_hunflair = [result_hunflair['confidence'], 1 - result_hunflair['confidence']]
+        else:
+            pred_hunflair = [1 - result_hunflair['confidence'], result_hunflair['confidence']]
+
+        # voting classifier: F1-weighted average of the three probability estimates
+        total_weight = self.f1_score_biolink + self.f1_score_scibert + self.f1_score_hunflair
+        weighted_average_1 = float((self.f1_score_biolink * pred_biolink[0] + self.f1_score_scibert * pred_scibert[0] + self.f1_score_hunflair * pred_hunflair[0]) / total_weight)
+        weighted_average_2 = float((self.f1_score_biolink * pred_biolink[1] + self.f1_score_scibert * pred_scibert[1] + self.f1_score_hunflair * pred_hunflair[1]) / total_weight)
+
+        return [weighted_average_1, weighted_average_2]
+
+
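+# Illustrative only (hypothetical numbers): with weights 0.96 / 0.80 / 0.90 and per-model
+# probabilities of 0.9, 0.7 and 0.8 for the "contains an ADE" class, the ensemble returns
+# (0.96*0.9 + 0.80*0.7 + 0.90*0.8) / 2.66 ≈ 0.81 for that class.
+# A minimal usage sketch, assuming the three sub-models have already been loaded:
+#     voter = InferenceADE(pipeline_scibert, pipeline_biolink, model_hunflair)
+#     scores = voter('The patient developed a severe rash after taking aspirin.')
+#     is_adverse_effect = scores[1] > 0.5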
+def extraction(filename: str, choices: list[bool], use_streamlit: bool = True):
+    ''' Takes as input the name of a pdf file and extracts the wanted entities from its
+    pre-processed text chunks. Outputs a dictionary with the entity types as keys and the
+    lists of extracted entities as values.
+    '''
+    tokenizer_split_sentences = nltk.data.load('tokenizers/punkt/english.pickle')
+    root = './NER-Medical-Document/processed_files/' + filename[:-4] + '/'
+    results = {}
+    models = {}
+    limit = 20  # maximum number of files to process, used for testing if the pdf file is too big
+    ind = 0
+
+    # the coreference model is only needed when coreference resolution is enabled (use_streamlit=False)
+    if not use_streamlit:
+        model_url = 'https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2021.03.10.tar.gz'
+        predictor = Predictor.from_path(model_url)
+
+    # only add the necessary models to the dictionary 'models' (avoid unnecessary loading)
+    if choices[0]:
+        tagger_chemicals: SequenceTagger = SequenceTagger.load('./NER-Medical-Document/training_results/best-model.pt')
+        models[0] = tagger_chemicals
+
+    if choices[1]:
+        tagger_diseases: SequenceTagger = SequenceTagger.load('hunflair-disease')
+        models[1] = tagger_diseases
+
+    if choices[2]:
+        tagger_dates = SequenceTagger.load("flair/ner-english-ontonotes-fast")
+        models[2] = tagger_dates
+
+    if choices[3]:
+        # 3 different methods for ADE detection
+        method = 2
+
+        if method == 1:
+            # Use a token classification model (classify a token as an ADE, not a sentence)
+            model_adverse_name = "abhibisht89/spanbert-large-cased-finetuned-ade_corpus_v2"  # model name from huggingface.co/models
+            model_adverse = TFBertForTokenClassification.from_pretrained(model_adverse_name, from_pt=True)
+            tokenizer_adverse = AutoTokenizer.from_pretrained(model_adverse_name)
+            models[3] = pipeline("token-classification", model=model_adverse, tokenizer=tokenizer_adverse, grouped_entities=True)
+
+        elif method == 2:
+            # Sentence classification: use HunFlair model + negation detection
+            tokenizer_neg = AutoTokenizer.from_pretrained("bvanaken/clinical-assertion-negation-bert")
+            model_neg = AutoModelForSequenceClassification.from_pretrained("bvanaken/clinical-assertion-negation-bert")
+            pipeline_neg = TextClassificationPipeline(model=model_neg, tokenizer=tokenizer_neg)
+            model_hunflair = TextClassifier.load('./NER-Medical-Document/training_results/flair_bert/best-model.pt')
+            models[3] = model_hunflair
+
+        elif method == 3:
+            # Sentence classification: use the InferenceADE class to build a voting classifier
+            model_scibert_name = 'NER-Medical-Document/training_results/scibert_scivocab_uncased'
+            model_scibert = BertForSequenceClassification.from_pretrained(model_scibert_name)
+            tokenizer_scibert = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
+            pipeline_scibert = pipeline("text-classification", model=model_scibert, tokenizer=tokenizer_scibert)
+
+            model_biolink_name = 'NER-Medical-Document/training_results/BioLinkBERT-base'
+            model_biolink = BertForSequenceClassification.from_pretrained(model_biolink_name)
+            tokenizer_biolink = AutoTokenizer.from_pretrained('michiyasunaga/BioLinkBERT-base')
+            pipeline_biolink = pipeline("text-classification", model=model_biolink, tokenizer=tokenizer_biolink)
+
+            model_hunflair = TextClassifier.load('./NER-Medical-Document/training_results/flair_bert/best-model.pt')
+
+            models[3] = InferenceADE(pipeline_scibert, pipeline_biolink, model_hunflair)
+
+    if choices[4]:
+        tagger_doses = SequenceTagger.load("flair/ner-english-ontonotes-fast")
+        models[4] = tagger_doses
+
+    # positional mapping between the 'choices' flags and the entity types
+    dic = {0: 'Chemicals', 1: 'Diseases', 2: 'Dates', 3: 'Adverse effects', 4: 'Doses'}
+
+    info = 'Extracting '
+    for j, c in enumerate(choices):
+        if c:
+            info += dic[j].lower() + ' and '
+    if use_streamlit:
+        streamlit.write(info[:-5] + ' entities...')
+
+    local_results_chemicals = []
+    local_results_diseases = []
+    local_results_dates = []
+    local_results_adverse = []
+    local_results_doses = []
+    dic_doses_chemicals = {}
+    detected_chemicals = False  # whether a chemical was found in the current sentence (used to link doses to chemicals)
+    entity_chemical = 'unknown'
+
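+    # Each processed text chunk is split into paragraphs and sentences; when coreference
+    # resolution is enabled (use_streamlit=False), a sliding window of `size_window`
+    # consecutive sentences is fed to the SpanBERT coreference model so that pronouns in
+    # the current sentence can be replaced by their antecedents before tagging.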
+    for file in os.listdir(root):
+        if file.endswith('.txt') and ind < limit:
+            ind += 1
+            with open(root + file, 'r') as f:
+                paragraphs = f.read().split('\n\n')
+                sentences = []
+                for p in paragraphs:
+                    sentences.extend(tokenizer_split_sentences.tokenize(p))
+                    sentences[-1] = sentences[-1] + '\n\n'
+
+            if not use_streamlit:
+                size_window = 3
+                list_coref = sentences[:size_window]
+
+            for s in stqdm(range(len(sentences))):
+                sentence_ = sentences[s]
+
+                ### Coreference resolution ###
+                if not use_streamlit and s >= size_window:
+                    list_coref.append(sentence_)
+                    list_coref.pop(0)
+                    try:
+                        transformed_chunk = predictor.coref_resolved(' '.join(list_coref))
+                        paragraphs2 = transformed_chunk.split('\n\n')
+                        sentences2 = tokenizer_split_sentences.tokenize(paragraphs2[-1])
+                        sentence_transformed = sentences2[-1]
+                    except Exception:
+                        sentence_transformed = sentence_
+                else:
+                    sentence_transformed = sentence_
+
+                # from here on, work with the coreference-resolved sentence (identical to the
+                # original sentence when resolution is disabled or failed)
+                sentence_ = sentence_transformed.replace('\n', ' ')
+                if len(sentence_) >= 4:
+                    sentence = Sentence(sentence_, use_tokenizer=SciSpacyTokenizer())
+                    for j, c in enumerate(choices):
+                        if c:
+                            tagger = models[j]
+                            if dic[j] == 'Adverse effects':
+                                if method == 1:
+                                    result = tagger(sentence_)
+                                    for entity in result:
+                                        if entity['entity_group'] == 'ADR':
+                                            local_results_adverse.append(entity['word'])
+                                elif method == 2:
+                                    sentence = Sentence(sentence_, use_tokenizer=SciSpacyTokenizer())  # create a new instance of Sentence
+                                    tagger.predict(sentence)
+                                    result = sentence.labels[0].to_dict()
+                                    if result['value'] == '1':
+                                        # keep the sentence only if the negation model does not mark the finding as absent
+                                        if not pipeline_neg(sentence_)[0]['label'] == 'ABSENT':
+                                            local_results_adverse.append(sentence_)
+                                elif method == 3:
+                                    result = tagger(sentence_)
+                                    if result[1] > 0.5:
+                                        local_results_adverse.append(sentence_)
+                                    elif 0 in models:
+                                        # second chance: replace the detected chemical with a common drug name and classify again
+                                        models[0].predict(sentence)
+                                        found = False
+                                        for annotation_layer in sentence.annotation_layers.keys():
+                                            for entity in sentence.get_spans(annotation_layer):
+                                                found = True
+                                                sentence_2 = sentence_.replace(entity.text, 'aspirin')
+                                        if found:
+                                            result = tagger(sentence_2)
+                                            if result[1] > 0.5:
+                                                local_results_adverse.append(sentence_)
+                            else:
+                                tagger.predict(sentence)
+                                for annotation_layer in sentence.annotation_layers.keys():
+                                    for entity in sentence.get_spans(annotation_layer):
+                                        if dic[j] == 'Chemicals':
+                                            local_results_chemicals.append(entity.text)
+                                            detected_chemicals = True
+                                            entity_chemical = entity.text
+                                        elif dic[j] == 'Diseases':
+                                            local_results_diseases.append(entity.text)
+                                        elif dic[j] == 'Dates':
+                                            if entity.tag == 'DATE':
+                                                local_results_dates.append(entity.text)
+                                        elif dic[j] == 'Doses':
+                                            if entity.tag == 'QUANTITY':
+                                                local_results_doses.append(entity.text)
+                                                # link the dose to the chemical detected in the same sentence, if any
+                                                if detected_chemicals:
+                                                    dic_doses_chemicals[entity.text] = entity_chemical
+                                                else:
+                                                    dic_doses_chemicals[entity.text] = 'unknown'
+                                                detected_chemicals = False
+
+    for j, c in enumerate(choices):
+        if c:
+            if dic[j] == 'Chemicals':
+                # filter out stray characters that were sometimes detected as 'drugs'
+                local_results_chemicals = [x for x in local_results_chemicals if x not in ['(', ')', '[', ']', '{', '}', ' ', '']]
+                results[dic[j]] = list(set(local_results_chemicals))
+            elif dic[j] == 'Diseases':
+                results[dic[j]] = list(set(local_results_diseases))
+            elif dic[j] == 'Dates':
+                results[dic[j]] = list(set(local_results_dates))
+            elif dic[j] == 'Adverse effects':
+                results[dic[j]] = list(set(local_results_adverse))
+            elif dic[j] == 'Doses':
+                results[dic[j]] = list(set(local_results_doses))
+
+    if use_streamlit:
+        streamlit.write('Done!')
+    return results
+
+
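+# Shape of the dictionary returned by extraction() (values are illustrative, not real output):
+#     {'Chemicals': ['aspirin', 'ibuprofen'],
+#      'Diseases': ['headache'],
+#      'Dates': ['March 2021'],
+#      'Adverse effects': ['The patient developed a severe rash after the second dose.'],
+#      'Doses': ['100 mg']}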
+def higlight(filename: str, choices: list[bool]):
+    ''' Highlight the chosen entities in the pdf file whose name is 'filename' '''
+
+    root = './NER-Medical-Document/processed_files/' + filename[:-4] + '/'
+    pdf_input = fitz.open(root + filename)
+    results = extraction(filename, choices)
+    streamlit.write('Highlighting entities...')
+    text_instances = {}
+
+    # go through all the pages of the chosen pdf file
+    for page in pdf_input:
+
+        # search for all the occurrences of the entities in the page
+        for name, entities in results.items():
+            text_instances[name] = [page.search_for(text) for text in entities]
+
+        for name, instances in text_instances.items():
+            add_definition = False
+            if name == 'Chemicals':
+                color = (1, 1, 0)
+                add_definition = True
+            elif name == 'Diseases':
+                color = (0, 1, 0)
+                add_definition = True
+            elif name == 'Dates':
+                color = (0, 0.7, 1)
+                add_definition = False
+            elif name == 'Adverse effects':
+                color = (1, 0, 0)
+                add_definition = False
+            elif name == 'Doses':
+                color = (1, 0, 1)
+                add_definition = False
+
+            # highlight each occurrence of the entity in the page
+            for i, inst in enumerate(instances):
+                for x in inst:
+                    # handle the case where an entity should not be highlighted (see README): the idea is to check the
+                    # surrounding characters to detect whether this occurrence of the entity is part of another word
+
+                    # estimate the typical distance between 2 letters in the word (it depends on the font size)
+                    dist_letters = (x[2] - x[0]) / len(results[name][i])
+                    # draw a larger rectangle to check the surrounding characters
+                    rect_larger = fitz.Rect(x[0] - dist_letters, x[1], x[2] + dist_letters, x[3])
+                    non_accepted_chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
+                    word = page.get_textbox(rect_larger).lower()
+                    for sub in results[name][i].split():
+                        word = word.replace(sub.lower(), '')
+                    keep = True
+                    for l in word:
+                        if l in non_accepted_chars:
+                            keep = False
+                    if not keep:
+                        continue  # ignore this occurrence of the entity: it is part of another word
+
+                    annot = page.add_highlight_annot(x)
+                    annot.set_colors({"stroke": color})
+                    annot.set_opacity(0.4)
+                    if add_definition:
+                        try:
+                            annot.set_popup(x)
+                            info = annot.info
+                            info["title"] = "Definition"
+                            if name == 'Chemicals':
+                                info["content"] = wikipedia.summary(results[name][i] + ' (drug)').split('.')[0]
+                            else:
+                                info["content"] = wikipedia.summary(results[name][i] + f' ({name.lower()[:-1]})').split('.')[0]
+                            annot.set_info(info)
+                        except Exception:
+                            pass
+                    annot.update()
+
+    if os.path.exists(root + filename[:-4] + '_highlighted.pdf'):
+        os.remove(root + filename[:-4] + '_highlighted.pdf')
+    pdf_input.save(root + filename[:-4] + '_highlighted.pdf')
+    streamlit.write('Done!')
+
+
+if __name__ == '__main__':
+    print(extraction('0.txt', [True, True, True, True, True], use_streamlit=False))
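+    # higlight() can be exercised the same way once a PDF and its processed text chunks exist
+    # under ./NER-Medical-Document/processed_files/<name>/, e.g. (hypothetical file name):
+    #     higlight('report.pdf', [True, True, True, True, True])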