Switch to side-by-side view

--- a
+++ b/src/Parser/entity_recognition.py
@@ -0,0 +1,534 @@
+
+import os
+import joblib
+from joblib import delayed
+from tqdm.auto import tqdm
+import requests
+from typing import List, Dict, Union
+import numpy as np
+import pandas as pd
+import glob
+import json
+
+import torch
+import medspacy
+import spacy
+from spacy.matcher import PhraseMatcher
+from spacy.tokens import Span
+from spacy.language import Language
+from spacy.util import filter_spans
+from spacy.tokens import Doc, Token
+from spacy.matcher import Matcher
+from srsly import read_json
+import re
+import transformers
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
+import warnings
+
+# Filepaths
+INPUT_FILEPATH = '/home/mabdallah/TrialMatchAI/data/preprocessed_data'
+OUTPUT_FILEPATH_CT = '/home/mabdallah/TrialMatchAI/data/ner_clinical_trials/'
+# OUTPUT_FILEPATH_PAT = "../data/ner_patients_clinical_notes/"
+
+# List of auxiliary entities
+AUXILIARY_ENTITIES_LIST = ["Sign_symptom", "Biological_structure", "Date", "Duration", "Time", "Frequency", 
+                           "Severity", "Lab_value", "Dosage", "Diagnostic_procedure", "Therapeutic_procedure", "Medication", 
+                           "Clinical_event", "Outcome", "History", "Subject", "Family_history", "Detailed_description", "Area"]
+
+# Check if CUDA is available
+device0 = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+device1 = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
+device2 = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
+# Load auxiliary tokenizer and pipeline
+aux_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all", model_max_length=512, max_length=512, truncation=True)
+aux_pipeline = pipeline("ner", model="d4data/biomedical-ner-all", tokenizer=aux_tokenizer, aggregation_strategy="first", device=device0)
+
+# Load mutations tokenizer and pipeline
+mutations_tokenizer = AutoTokenizer.from_pretrained("Brizape/tmvar-PubMedBert-finetuned-24-02", model_max_length=512, max_length=512, truncation=True)
+mutations_pipeline = pipeline("ner", model="Brizape/tmvar-PubMedBert-finetuned-24-02", tokenizer=mutations_tokenizer, aggregation_strategy="first",  device=device1)
+
+neg_tokenizer = AutoTokenizer.from_pretrained("bvanaken/clinical-assertion-negation-bert", model_max_length=512, max_length=512, truncation=True)
+neg_model = AutoModelForSequenceClassification.from_pretrained("bvanaken/clinical-assertion-negation-bert")
+neg_classifier = pipeline("text-classification", model=neg_model, tokenizer=neg_tokenizer, device=device2)
+
+
+def query_plain(text, url="http://localhost:8888/plain"):
+    """
+    Send a plain text query to a specified URL.
+
+    This function sends a plain text query to a specified URL using the POST method. The query is sent as a JSON object with the 'text' key.
+    The response is received as a JSON object and is decoded into a string.
+
+    Parameters:
+        text (str): The plain text query to be sent.
+        url (str): The URL to which the query is sent. Default is "http://localhost:8888/plain".
+
+    Returns:
+        dict: The response received as a JSON object.
+
+    Example:
+        query_plain("Hello, world!")
+        # Output: {'response': 'Hello, world!'}
+    """
+    return json.loads(requests.post(url, json={'text': text}).content.decode('utf-8'))
+
+# Memory caching for function calls
+memory = joblib.Memory(".")
+
+def ParallelExecutor(use_bar="tqdm", **joblib_args):
+    """
+    Utility function for tqdm progress bar in joblib.Parallel.
+
+    This function is a utility for using tqdm progress bar with joblib.Parallel. It returns a function that can be used as a wrapper
+    for the operation iterator in joblib.Parallel. The function takes a 'bar' argument which specifies the type of progress bar to use.
+    The available options are 'tqdm', 'False', and 'None'. The function also accepts additional arguments that are passed to tqdm.
+
+    Parameters:
+        use_bar (str): The type of progress bar to use. Default is "tqdm".
+        **tq_args: Additional arguments to be passed to tqdm.
+
+    Returns:
+        function: The wrapper function that can be used with joblib.Parallel.
+
+    Example:
+        executor = ParallelExecutor(use_bar="tqdm", ncols=80)
+        results = executor(op_iter)
+    """
+    all_bar_funcs = {
+        "tqdm": lambda args: lambda x: tqdm(x, **args),
+        "False": lambda args: lambda x: x,
+        "None": lambda args: lambda x: x,
+    }
+
+    def aprun(bar=use_bar, **tq_args):
+        def tmp(op_iter):
+            if str(bar) in all_bar_funcs.keys():
+                bar_func = all_bar_funcs[str(bar)](tq_args)
+            else:
+                raise ValueError("Value %s not supported as bar type" % bar)
+            # Pass n_jobs from joblib_args
+            return joblib.Parallel(n_jobs=joblib_args.get("n_jobs", 10))(bar_func(op_iter))
+
+        return tmp
+
+    return aprun
+
+def get_dictionaries_of_specific_entities(list_of_dicts, key, values):
+    """
+    Filter a list of dictionaries based on the presence of specific values in a specified key.
+
+    This function takes a list of dictionaries and filters them based on the presence of specific values in a specified key.
+    The function checks each dictionary in the input list and includes only those dictionaries where any of the given values
+    are present in the specified key. The filtering is performed using list comprehensions.
+
+    Parameters:
+        list_of_dicts (list): A list of dictionaries to be filtered.
+        key (str): The key in the dictionaries where the filtering is applied.
+        values (list): A list of values. The function will filter dictionaries where any of these values are present in the specified key.
+
+    Returns:
+        list: A list of dictionaries that meet the filtering criteria.
+
+    Example:
+        list_of_dicts = [
+            {"name": "Alice", "age": 30},
+            {"name": "Bob", "age": 25},
+            {"name": "Charlie", "age": 35},
+            {"name": "David", "age": 30},
+        ]
+
+        get_dictionaries_of_specific_entities(list_of_dicts, "age", [30, 35])
+        # Output: [
+        #   {"name": "Alice", "age": 30},
+        #   {"name": "Charlie", "age": 35},
+        #   {"name": "David", "age": 30}
+        # ]
+    """
+    return [d for d in list_of_dicts if any(val in d.get(key, []) for val in values)]
+
+def add_custom_entity(doc, entity):
+    """
+    Add a custom entity to a spaCy document.
+
+    This function takes a spaCy document and a custom entity dictionary and adds the custom entity to the document.
+    The function finds the token indices corresponding to the character span of the entity and sets the entity span in the document.
+
+    Parameters:
+        doc (spacy.tokens.Doc): The spaCy document to which the entity is added.
+        entity (dict): The custom entity dictionary containing the text, start, end, and entity_group.
+
+    Returns:
+        spacy.tokens.Doc: The modified spaCy document with the custom entity added.
+
+    Example:
+        doc = nlp("The patient has a fever.")
+        entity = {"text": "fever", "start": 16, "end": 21, "entity_group": "Symptom"}
+        doc = add_custom_entity(doc, entity)
+    """
+    entity["text"] = re.sub(r'([,.-])\s+', r'\1', entity["text"]) 
+    entity_text = entity["text"].lower()
+    start_char = entity["start"] 
+    end_char = entity["end"] 
+    # Find the token indices corresponding to the character span
+    start_indices = [i for i, token in enumerate(doc) if (start_char <= token.idx <= end_char) or (entity_text in token.text and token.idx <= start_char)]
+    if start_indices:
+        # You can choose the first matching window or handle multiple matches
+        start_index = start_indices[0]
+        start_token = doc[start_index]
+        end_index = min(start_index + len(entity_text.split()) - 1, len(doc) - 1)
+        end_token = doc[end_index]
+        doc.set_ents([Span(doc, start_token.i, end_token.i + 1, entity["entity_group"])])
+        
+    return doc
+
+
+def negation_handling(sentence, entity):
+    """
+    Perform negation handling on a sentence with a given entity.
+
+    This function takes a sentence and an entity dictionary and performs negation handling on the sentence.
+    The function uses medSpaCy to identify negation cues and determines if the entity is negated or not.
+
+    Parameters:
+        sentence (str): The sentence in which the entity is present.
+        entity (dict): The entity dictionary containing the text, start, and end.
+
+    Returns:
+        dict: The modified entity dictionary with the "is_negated" field indicating if the entity is negated or not.
+
+    Example:
+        sentence = "The patient does not have a fever."
+        entity = {"text": "fever", "start": 23, "end": 28}
+        entity = negation_handling(sentence, entity)
+    """
+    nlp = spacy.load("en_core_web_sm", disable={"ner"})
+    doc = nlp(sentence.lower())
+    nlp = medspacy.load(nlp)
+    nlp.disable_pipe('medspacy_target_matcher')
+    nlp.disable_pipe('medspacy_pyrush')
+    doc = nlp(add_custom_entity(doc, entity))
+    for e in doc.ents:
+        rs = str(e._.is_negated)
+        if rs:
+            if rs == "True": 
+                entity["is_negated"] = "yes"
+            elif rs == 'False':
+                entity["is_negated"] = "no"
+        else:
+            entity["is_negated"] = "no"
+    return entity
+
+def is_entity_negated(sentence, entity):
+    # Surround the entity with [entity] on both sides
+    entity_text = entity["text"]
+    sentence_with_entity = re.sub(rf'\b{re.escape(entity_text)}\b', f"[entity]{entity_text}[entity]", sentence)
+
+    # Classify the modified sentence to check for negation
+    classification = neg_classifier(sentence_with_entity, max_length=512, truncation=True)[0]
+    is_negated = classification['label'] == 'ABSENT'
+    if is_negated:
+        entity["is_negated"] = "yes"
+    else:
+        entity["is_negated"] = "no"
+    return entity
+
+class EntityRecognizer:
+    def __init__(self, id_list, n_jobs, data_source="clinical trials"):
+        self.id_list = id_list
+        self.n_jobs = n_jobs
+        self.data_source = data_source
+        
+    def data_loader(self, id_list):
+        to_concat = []
+        for idx in id_list:
+            if self.data_source == "clinical trials":
+                file_path = os.path.join(INPUT_FILEPATH, "clinical_trials", f"{idx}_preprocessed.csv")
+                if os.path.exists(file_path):
+                    df = pd.read_csv(file_path)
+                    to_concat.append(df)
+            elif self.data_source=="patient notes":
+                df = pd.read_csv(INPUT_FILEPATH + "patient_notes/" + "%s_preprocessed.csv"%idx)
+                to_concat.append(df)
+            else:
+                warnings.warn("Unexpected data source encountered. Please choose between 'clinical trials' or 'patient notes'", UserWarning)
+        return to_concat
+    
+    def mtner_normalize_format(self, json_data):
+        spacy_format_entities = []
+        for annotation in json_data["annotations"]:
+            start = annotation["span"]["begin"]
+            end = annotation["span"]["end"]
+            label = annotation["obj"]
+            mention = annotation["mention"]
+            score = annotation["prob"]
+            normalized_id = annotation["id"]
+            spacy_format_entities.append({
+                "entity_group": label,
+                "text": mention,
+                "score": score,
+                "start": start,
+                "end": end,
+                "normalized_id": normalized_id
+            })
+        spacy_result = {
+            "text": json_data["text"],
+            "ents": spacy_format_entities,
+        }
+        return spacy_result
+    
+    def merge_lists_with_priority_to_first(self, list1, list2):
+        merged_list = list1.copy()  
+        for dict2 in list2:
+            overlap = False
+            for dict1 in list1:
+                if (dict1['start'] <= dict2['end'] and dict2['start'] <= dict1['start']) or (dict2['start'] <= dict1['end'] and dict1['start'] <= dict2['start']):
+                    overlap = True
+                    break
+            
+            if not overlap:
+                merged_list.append(dict2)
+        return merged_list
+    
+    def merge_lists_without_priority(self, list1, list2):
+        merged_list = list1.copy()  
+        for dict2 in list2:
+            merged_list.append(dict2)
+        return merged_list
+    
+    def find_and_remove_overlaps(self, dictionary_list, if_overlap_keep):
+        # Create a dictionary to store non-overlapping entries
+        non_overlapping = {}
+        # Create a set of entity groups to keep
+        preferred_set = set(if_overlap_keep)
+
+        # Iterate through the input list
+        for entry in dictionary_list:
+            if 'text' in entry and 'entity_group' in entry:
+                text = entry['text']
+                group = entry['entity_group']
+
+                # Check if the text is already in the non_overlapping dictionary
+                if text in non_overlapping:
+                    # Compare groups and keep the entry if it belongs to one of the preferred groups
+                    if group in preferred_set:
+                        non_overlapping[text] = entry
+                else:
+                    non_overlapping[text] = entry
+
+        # Convert the non-overlapping dictionary back to a list
+        result_list = list(non_overlapping.values())
+
+        return result_list
+    
+    def aberration_type_recognizer(self, text):
+        med_nlp = medspacy.load()
+        med_nlp.disable_pipe('medspacy_target_matcher')
+        @Language.component("aberrations-ner")
+        def regex_pattern_matcher_for_aberrations(doc):
+            df_regex = pd.read_csv("/home/mabdallah/TrialMatchAI/data/regex_variants.tsv", sep="\t", header=None)
+            df_regex = df_regex.rename(columns={1 : "label", 2:"regex_pattern"}).drop(columns=[0])
+            dict_regex = df_regex.set_index('label')['regex_pattern'].to_dict()
+            original_ents = list(doc.ents)
+            # Compile the regex patterns
+            compiled_patterns = {
+                label: re.compile(pattern)
+                for label, pattern in dict_regex.items()
+            }
+            mwt_ents = []
+            for label, pattern in compiled_patterns.items():
+                for match in re.finditer(pattern, doc.text):
+                    start, end = match.span()
+                    span = doc.char_span(start, end)
+                    if span is not None:
+                        mwt_ents.append((label, span.start, span.end, span.text))
+                        
+            for ent in mwt_ents:
+                label, start, end, name = ent
+                per_ent = Span(doc, start, end, label=label)
+                original_ents.append(per_ent)
+
+            doc.ents = filter_spans(original_ents)
+            
+            return doc
+        med_nlp.add_pipe("aberrations-ner", before='medspacy_context')
+        doc = med_nlp(text)
+        ent_list =[] 
+        for entity in doc.ents:
+            ent_list.append({"entity_group" : entity.label_, 
+                            "text" : entity.text, 
+                            "start": entity.start_char, 
+                            "end": entity.end_char, 
+                            "is_negated" : "yes" if entity._.is_negated else "no"})
+        return ent_list
+    
+
+    def pregnancy_recognizer(self, text):
+        med_nlp = medspacy.load()
+        med_nlp.disable_pipe('medspacy_target_matcher')
+        
+        # Updated regex pattern
+        regex_pattern = r"(?i)\b(?:pregn\w+|matern\w+|gestat\w+|lactat\w+|breastfeed\w+|prenat\w+|antenat\w+|postpartum|childbear\w+|parturient|conceiv\w+|obstetr\w+)\b"
+
+        @Language.component("pregnancy-ner")
+        def regex_pattern_matcher_for_pregnancy(doc):
+            compiled_pattern = re.compile(regex_pattern)
+
+            original_ents = list(doc.ents)
+            mwt_ents = []
+
+            for match in re.finditer(compiled_pattern, doc.text):
+                start, end = match.span()
+                span = doc.char_span(start, end)
+                if span is not None:
+                    mwt_ents.append((span.start, span.end, span.text))
+
+            for ent in mwt_ents:
+                start, end, name = ent
+                per_ent = Span(doc, start, end, label="pregnancy")  # Assigning the label "pregnancy"
+                original_ents.append(per_ent)
+
+            doc.ents = filter_spans(original_ents)
+
+            return doc
+
+        med_nlp.add_pipe("pregnancy-ner", before='medspacy_context')
+        doc = med_nlp(text)
+        
+        ent_list =[] 
+        for entity in doc.ents:
+            ent_list.append({
+                "entity_group": entity.label_,
+                "text": entity.text,
+                "start": entity.start_char,
+                "end": entity.end_char,
+                "is_negated": "yes" if entity._.is_negated else "no"
+            })
+    
+        return ent_list
+    
+    def merge_similar_consecutive_entities(self, entities):
+        combined_entities = []
+        if entities:
+            current_entity = entities[0]
+            for next_entity in entities[1:]:
+                if (
+                    'text' in current_entity
+                    and 'text' in next_entity
+                    and 'entity_group' in current_entity
+                    and 'entity_group' in next_entity
+                    and 'start' in current_entity
+                    and 'end' in current_entity
+                    and 'start' in next_entity
+                    and 'end' in next_entity
+                    and current_entity['entity_group'] == next_entity['entity_group']
+                    and next_entity['start'] - current_entity['end'] - 1 <= 3
+                ):
+                    current_entity['text'] += ' ' + next_entity['text']
+                    current_entity['end'] = next_entity['end']
+                else:
+                    combined_entities.append(current_entity)
+                    current_entity = next_entity
+
+            combined_entities.append(current_entity)
+        return combined_entities
+
+    def recognize_entities(self, df):
+        _ids = []
+        sentences = []
+        entities_groups = []
+        entities_texts = []
+        normalized_ids = []
+        is_negated = []
+        field = []
+        start= []
+        end = []
+        df = df.dropna()
+        for _,row in df.iterrows():
+            sent = row["sentence"].replace(",", "")
+            main_entities = self.mtner_normalize_format(query_plain(sent))["ents"]
+            variants_entities = mutations_pipeline(sent)
+            aberration_type_entities = self.aberration_type_recognizer(sent)
+            pregnancy_entities = self.pregnancy_recognizer(sent)
+            aux_entities = aux_pipeline(sent)
+            aux_entities = get_dictionaries_of_specific_entities(aux_entities, "entity_group", AUXILIARY_ENTITIES_LIST)
+            aux_entities = [{"text" if k == "word" else k: v for k, v in d.items()} for d in aux_entities]
+            
+            combined_entities  = self.merge_lists_with_priority_to_first(variants_entities, main_entities)
+            combined_entities  = self.merge_lists_with_priority_to_first(combined_entities, aux_entities)
+            combined_entities  = self.merge_lists_without_priority(combined_entities, pregnancy_entities)
+            combined_entities  = self.merge_lists_with_priority_to_first(combined_entities, aberration_type_entities)
+            combined_entities  = self.merge_similar_consecutive_entities(combined_entities)
+            
+            # Convert the selected_entries dictionary back to a list
+            if len(combined_entities) > 0:
+                # clean_entities = self.find_and_remove_overlaps(combined_entities, if_overlap_keep=["gene", "ProteinMutation", "DNAMutation", "SNP"])
+                for e in combined_entities:
+                    if 'text' in e and 'entity_group' in e:
+                        if (("score" in e and e["score"] > 0.7) or ("score" not in e)) and len(e["text"]) > 1:
+                            ent = is_entity_negated(sent, e)
+                            ent["text"] = re.sub(r'([,.-])\s+', r'\1', e["text"]) 
+                            is_negated.append(ent["is_negated"]) 
+                            _ids.append(row["id"])
+                            sentences.append(sent)
+                            entities_groups.append(ent['entity_group'])
+                            entities_texts.append(ent['text'])
+                            start.append(ent["start"])
+                            end.append(ent["end"])
+                            if "normalized_id" in ent:
+                                normalized_ids.append(ent["normalized_id"])
+                            else: 
+                                normalized_ids.append("CUI-less")
+                            if self.data_source=="clinical trials":
+                                field.append(row["criteria"])
+                            elif self.data_source=="patient notes":
+                                field.append(row["field"])
+                    else:
+                        continue
+        return pd.DataFrame({
+                            'nct_id': _ids,
+                            'text': sentences,
+                            'entity_text': entities_texts,
+                            'entity_group': entities_groups,
+                            'normalized_id': normalized_ids,
+                            'field' : field,
+                            "is_negated" : is_negated,
+                        })
+            
+    def save_output(self, df, output_filepath):
+        df.to_csv(output_filepath, index=False)
+
+    def __call__(self):
+        all_df = self.data_loader(self.id_list)
+
+        def process_dataframe(df):
+            output_filepath = OUTPUT_FILEPATH_CT + df["id"].iloc[0] + ".csv"
+            if not os.path.exists(output_filepath):
+                result_df = self.recognize_entities(df)
+                if self.data_source == "clinical trials":
+                    self.save_output(result_df, output_filepath)
+                return result_df
+
+        parallel_runner = ParallelExecutor(n_jobs=self.n_jobs)(total=len(self.id_list))
+        
+        parallel_runner(delayed(process_dataframe)(df) for df in all_df)
+        
+        return 
+
+if __name__ == "__main__":
+    # Load the list of NCT IDs
+    folder_path = '/home/mabdallah/TrialMatchAI/data/trials_xmls' # Replace this with the path to your folder
+    file_names = []
+    # List all files in the folder
+    for file in os.listdir(folder_path):
+        if os.path.isfile(os.path.join(folder_path, file)):
+            file_name, file_extension = os.path.splitext(file)
+            file_names.append(file_name)
+    nct_ids = file_names
+    reco = EntityRecognizer(n_jobs=5, id_list=nct_ids, data_source="clinical trials")
+    entities = reco()
+    # # Load the list of patient IDs
+    # pat_ids = pd.read_csv("../data/patient_ids.csv")
+    # pat_ids = pat_ids["id"].tolist()
+    # reco = EntityRecognizer(n_jobs=50, id_list=pat_ids, data_source="patient notes")
+    # entities = reco()
+    # entities.to_csv("../data/ner_patients_clinical_notes/entities_parsed.csv", index = False)
\ No newline at end of file