--- a +++ b/wrapper_functions/EHRKit.py @@ -0,0 +1,161 @@ +from scispacy_functions import get_abbreviations, get_hyponyms, get_linked_entities, get_named_entities +from transformer_functions import get_translation, get_supported_translation_languages, get_single_summary, get_multi_summary_joint +from utils import get_sents_stanza, get_multiple_sents_stanza, get_sents_pyrush, get_sents_scispacy +from multi_doc_functions import get_clusters, get_similar_documents +import numpy as np +from stanza_functions import ( + get_named_entities_stanza_biomed, + get_sents_stanza_biomed, + get_tokens_stanza_biomed, + get_part_of_speech_and_morphological_features, + get_lemmas_stanza_biomed, + get_dependency_stanza_biomed +) + +class EHRKit: + """ + EHRKit is the main class of this toolkit. An EHRKit object stores textual records and default models for various tasks. + Different tasks can be called from an EHRKit object to perform tasks on the stored textual records. + + Args: + main_record (str): main textual record + support_records (list): list of auxiliary textual records (used in multi-document tasks) + scispacy_model (str): default model for scispacy tasks + bert_model (str): default model for pre-trained transformers + marian_model (str): default model for translation + """ + def __init__(self, + main_record="", + supporting_records=[], + scispacy_model="en_core_sci_sm", + bert_model="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", + marian_model="Helsinki-NLP/opus-mt-en-ROMANCE"): + + self.main_record = main_record + self.supporting_records = supporting_records + self.scispacy_model = scispacy_model + self.bert_model = bert_model + self.marian_model = marian_model + + ''' Functions for manipulating records and default models ''' + def update_and_delete_main_record(self, main_record): + """ + update current main_record, current main_record will be deleted + """ + if main_record: + self.main_record = main_record + else: + raise TypeError('Invalid type for main_record') + + def update_and_keep_main_record(self, main_record): + """ + update current main_record, current main_record will be added to the end of supporting_records + """ + if main_record: + self.supporting_records.append(self.main_record) + self.main_record = main_record + else: + raise TypeError('Invalid type for main_record') + + def replace_supporting_records(self, supporting_records): + """ + replace current supporting records + """ + self.supporting_records = supporting_records + + def add_supporting_records(self, supporting_records): + """ + add additional supporting records to existing supporting records + """ + self.supporting_records.extend(supporting_records) + + def update_scispacy_model(self, scispacy_model): + self.scispacy_model = scispacy_model + + def update_bert_model(self, bert_model): + self.bert_model = bert_model + + def update_marian_model(self, marian_model): + self.marian_model = marian_model + + ''' Functions for textual record processing ''' + def get_abbreviations(self): + abbreviations = get_abbreviations(self.scispacy_model, self.main_record) + return abbreviations + + def get_hyponyms(self): + hyponyms = get_hyponyms(self.scispacy_model, self.main_record) + return hyponyms + + def get_linked_entities(self): + linked_entities = get_linked_entities(self.scispacy_model, self.main_record) + return linked_entities + + def get_named_entities(self, tool='scispacy'): + if tool == 'scispacy': + named_entities = get_named_entities(self.scispacy_model, self.main_record) + elif tool == 'stanza': + named_entities = get_named_entities_stanza_biomed(self.main_record) + return named_entities + + def get_translation(self, target_language='Spanish'): + translation = get_translation(self.main_record, self.marian_model, target_language) + return translation + + def get_supported_translation_languages(self): + return get_supported_translation_languages() + + def get_sentences(self, tool='stanza'): + if tool == 'pyrush': + sents = get_sents_pyrush(self.main_record) + sents = [self.main_record[sent.begin:sent.end] for sent in sents] + elif tool == 'stanza': + sents = get_sents_stanza(self.main_record) + elif tool == 'scispacy': + sents = get_sents_scispacy(self.main_record) + elif tool == 'stanza-biomed': + sents = get_sents_stanza_biomed(self.main_record) + return sents + + def get_tokens(self, tool='stanza-biomed'): + if tool == 'stanza-biomed': + tokens = get_tokens_stanza_biomed(self.main_record) + return tokens + + def get_pos_tags(self, tool='stanza-biomed'): + if tool == 'stanza-biomed': + tags = get_part_of_speech_and_morphological_features(self.main_record) + return tags + + def get_lemmas(self, tool='stanza-biomed'): + if tool == 'stanza-biomed': + lemmas = get_lemmas_stanza_biomed(self.main_record) + return lemmas + + def get_dependency(self, tool='stanza-biomed'): + if tool == 'stanza-biomed': + dependencies = get_dependency_stanza_biomed(self.main_record) + return dependencies + + def get_clusters(self, k=2): + # combine main record and candidate records for clustering + docs = [self.main_record] + self.supporting_records + clusters = get_clusters(self.bert_model, docs, k) + return clusters + + def get_similar_documents(self, k=2): + query_note = self.main_record + candidate_notes = self.supporting_records + # ids of candidates + candidates = np.array(range(len(candidate_notes))) + similar_docs = get_similar_documents(self.bert_model, query_note, candidate_notes, candidates, top_k=k) + return similar_docs + + def get_single_record_summary(self): + summary = get_single_summary(self.main_record) + return summary + + def get_multi_record_summary(self): + docs = [self.main_record] + self.supporting_records + summary = get_multi_summary_joint(docs) + return summary