a b/wrapper_functions/EHRKit.py
1
from scispacy_functions import get_abbreviations, get_hyponyms, get_linked_entities, get_named_entities
2
from transformer_functions import get_translation, get_supported_translation_languages, get_single_summary, get_multi_summary_joint
3
from utils import get_sents_stanza, get_multiple_sents_stanza, get_sents_pyrush, get_sents_scispacy
4
from multi_doc_functions import get_clusters, get_similar_documents
5
import numpy as np
6
from stanza_functions import (
7
    get_named_entities_stanza_biomed,
8
    get_sents_stanza_biomed,
9
    get_tokens_stanza_biomed,
10
    get_part_of_speech_and_morphological_features,
11
    get_lemmas_stanza_biomed,
12
    get_dependency_stanza_biomed
13
)
14
15
class EHRKit:
    """
    EHRKit is the main class of this toolkit. An EHRKit object stores textual
    records and the default model names used by the wrapped task functions.
    Tasks are invoked as methods and operate on the stored records.

    Args:
        main_record (str): main textual record
        supporting_records (list): list of auxiliary textual records (used in
            multi-document tasks). When omitted, each instance gets its own
            new empty list. A list passed by the caller is stored by
            reference, so in-place updates made through this object are
            visible to the caller.
        scispacy_model (str): default model for scispacy tasks
        bert_model (str): default model for pre-trained transformers
        marian_model (str): default model for translation
    """

    def __init__(self,
                 main_record="",
                 supporting_records=None,
                 scispacy_model="en_core_sci_sm",
                 bert_model="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
                 marian_model="Helsinki-NLP/opus-mt-en-ROMANCE"):
        self.main_record = main_record
        # A literal `[]` default would be created once and shared by every
        # instance; the append/extend methods below would then leak records
        # from one EHRKit object into another.
        self.supporting_records = [] if supporting_records is None else supporting_records
        self.scispacy_model = scispacy_model
        self.bert_model = bert_model
        self.marian_model = marian_model

    @staticmethod
    def _unsupported_tool(tool, supported):
        """Build the ValueError raised when `tool` is not one of `supported`."""
        options = ", ".join(repr(name) for name in supported)
        return ValueError(f"Unsupported tool {tool!r}; expected one of: {options}")

    # ---- Functions for manipulating records and default models ----

    def update_and_delete_main_record(self, main_record):
        """
        Replace the current main_record; the previous main_record is discarded.

        Raises:
            TypeError: if main_record is falsy (e.g. None or an empty string).
        """
        if not main_record:
            raise TypeError('Invalid type for main_record')
        self.main_record = main_record

    def update_and_keep_main_record(self, main_record):
        """
        Replace the current main_record; the previous main_record is appended
        to the end of supporting_records.

        Raises:
            TypeError: if main_record is falsy (e.g. None or an empty string).
        """
        if not main_record:
            raise TypeError('Invalid type for main_record')
        self.supporting_records.append(self.main_record)
        self.main_record = main_record

    def replace_supporting_records(self, supporting_records):
        """
        Replace the current supporting records with `supporting_records`
        (stored by reference, not copied).
        """
        self.supporting_records = supporting_records

    def add_supporting_records(self, supporting_records):
        """
        Extend the existing supporting records, in place, with the items of
        `supporting_records`.
        """
        self.supporting_records.extend(supporting_records)

    def update_scispacy_model(self, scispacy_model):
        """Set the default scispacy model name."""
        self.scispacy_model = scispacy_model

    def update_bert_model(self, bert_model):
        """Set the default pre-trained transformer model name."""
        self.bert_model = bert_model

    def update_marian_model(self, marian_model):
        """Set the default translation model name."""
        self.marian_model = marian_model

    # ---- Functions for textual record processing ----
    # NOTE: several methods share their name with the module-level function
    # they wrap. Inside a method body the bare name resolves to the
    # module-level import, not to the method, so these calls do not recurse.

    def get_abbreviations(self):
        """
        Return the result of the imported `get_abbreviations` called with the
        default scispacy model and the main record.
        """
        return get_abbreviations(self.scispacy_model, self.main_record)

    def get_hyponyms(self):
        """
        Return the result of the imported `get_hyponyms` called with the
        default scispacy model and the main record.
        """
        return get_hyponyms(self.scispacy_model, self.main_record)

    def get_linked_entities(self):
        """
        Return the result of the imported `get_linked_entities` called with
        the default scispacy model and the main record.
        """
        return get_linked_entities(self.scispacy_model, self.main_record)

    def get_named_entities(self, tool='scispacy'):
        """
        Run named-entity extraction on the main record.

        Args:
            tool (str): 'scispacy' (uses the default scispacy model) or
                'stanza'.

        Raises:
            ValueError: if `tool` is not a supported value.
        """
        if tool == 'scispacy':
            return get_named_entities(self.scispacy_model, self.main_record)
        if tool == 'stanza':
            return get_named_entities_stanza_biomed(self.main_record)
        raise self._unsupported_tool(tool, ('scispacy', 'stanza'))

    def get_translation(self, target_language='Spanish'):
        """
        Return the result of the imported `get_translation` called with the
        main record, the default translation model and `target_language`.
        """
        return get_translation(self.main_record, self.marian_model, target_language)

    def get_supported_translation_languages(self):
        """Return the result of the imported `get_supported_translation_languages`."""
        return get_supported_translation_languages()

    def get_sentences(self, tool='stanza'):
        """
        Split the main record into sentences.

        Args:
            tool (str): one of 'pyrush', 'stanza', 'scispacy' or
                'stanza-biomed'.

        Raises:
            ValueError: if `tool` is not a supported value.
        """
        if tool == 'pyrush':
            # The pyrush helper yields spans exposing `begin`/`end` offsets;
            # slice the record so this branch returns text rather than spans.
            spans = get_sents_pyrush(self.main_record)
            return [self.main_record[span.begin:span.end] for span in spans]
        if tool == 'stanza':
            return get_sents_stanza(self.main_record)
        if tool == 'scispacy':
            return get_sents_scispacy(self.main_record)
        if tool == 'stanza-biomed':
            return get_sents_stanza_biomed(self.main_record)
        raise self._unsupported_tool(
            tool, ('pyrush', 'stanza', 'scispacy', 'stanza-biomed'))

    def get_tokens(self, tool='stanza-biomed'):
        """
        Tokenize the main record. Only 'stanza-biomed' is supported.

        Raises:
            ValueError: if `tool` is not a supported value.
        """
        if tool == 'stanza-biomed':
            return get_tokens_stanza_biomed(self.main_record)
        raise self._unsupported_tool(tool, ('stanza-biomed',))

    def get_pos_tags(self, tool='stanza-biomed'):
        """
        Return the result of `get_part_of_speech_and_morphological_features`
        for the main record. Only 'stanza-biomed' is supported.

        Raises:
            ValueError: if `tool` is not a supported value.
        """
        if tool == 'stanza-biomed':
            return get_part_of_speech_and_morphological_features(self.main_record)
        raise self._unsupported_tool(tool, ('stanza-biomed',))

    def get_lemmas(self, tool='stanza-biomed'):
        """
        Return the result of `get_lemmas_stanza_biomed` for the main record.
        Only 'stanza-biomed' is supported.

        Raises:
            ValueError: if `tool` is not a supported value.
        """
        if tool == 'stanza-biomed':
            return get_lemmas_stanza_biomed(self.main_record)
        raise self._unsupported_tool(tool, ('stanza-biomed',))

    def get_dependency(self, tool='stanza-biomed'):
        """
        Return the result of `get_dependency_stanza_biomed` for the main
        record. Only 'stanza-biomed' is supported.

        Raises:
            ValueError: if `tool` is not a supported value.
        """
        if tool == 'stanza-biomed':
            return get_dependency_stanza_biomed(self.main_record)
        raise self._unsupported_tool(tool, ('stanza-biomed',))

    def get_clusters(self, k=2):
        """
        Cluster the main record together with the supporting records using
        the default transformer model.

        Args:
            k (int): forwarded to the imported `get_clusters`
                (presumably the number of clusters - confirm against
                multi_doc_functions).
        """
        # combine main record and candidate records for clustering
        docs = [self.main_record] + self.supporting_records
        return get_clusters(self.bert_model, docs, k)

    def get_similar_documents(self, k=2):
        """
        Query the supporting records with the main record using the default
        transformer model.

        Args:
            k (int): forwarded to the imported `get_similar_documents` as
                `top_k`.
        """
        query_note = self.main_record
        candidate_notes = self.supporting_records
        # ids of candidates: positions within supporting_records
        candidates = np.array(range(len(candidate_notes)))
        return get_similar_documents(
            self.bert_model, query_note, candidate_notes, candidates, top_k=k)

    def get_single_record_summary(self):
        """Return the result of `get_single_summary` for the main record."""
        return get_single_summary(self.main_record)

    def get_multi_record_summary(self):
        """
        Return the result of `get_multi_summary_joint` for the main record
        followed by the supporting records.
        """
        docs = [self.main_record] + self.supporting_records
        return get_multi_summary_joint(docs)