|
a |
|
b/wrapper_functions/EHRKit.py |
|
|
1 |
from scispacy_functions import get_abbreviations, get_hyponyms, get_linked_entities, get_named_entities |
|
|
2 |
from transformer_functions import get_translation, get_supported_translation_languages, get_single_summary, get_multi_summary_joint |
|
|
3 |
from utils import get_sents_stanza, get_multiple_sents_stanza, get_sents_pyrush, get_sents_scispacy |
|
|
4 |
from multi_doc_functions import get_clusters, get_similar_documents |
|
|
5 |
import numpy as np |
|
|
6 |
from stanza_functions import ( |
|
|
7 |
get_named_entities_stanza_biomed, |
|
|
8 |
get_sents_stanza_biomed, |
|
|
9 |
get_tokens_stanza_biomed, |
|
|
10 |
get_part_of_speech_and_morphological_features, |
|
|
11 |
get_lemmas_stanza_biomed, |
|
|
12 |
get_dependency_stanza_biomed |
|
|
13 |
) |
|
|
14 |
|
|
|
15 |
class EHRKit:
    """
    EHRKit is the main class of this toolkit. An EHRKit object stores textual records and default models for various tasks.
    Different tasks can be called from an EHRKit object to perform tasks on the stored textual records.

    Args:
        main_record (str): main textual record
        supporting_records (list): list of auxiliary textual records (used in multi-document tasks).
            The records are copied into a list owned by this object; omitting the argument
            (or passing None) starts with an empty list.
        scispacy_model (str): default model for scispacy tasks
        bert_model (str): default model for pre-trained transformers
        marian_model (str): default model for translation
    """

    def __init__(self,
                 main_record="",
                 supporting_records=None,
                 scispacy_model="en_core_sci_sm",
                 bert_model="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
                 marian_model="Helsinki-NLP/opus-mt-en-ROMANCE"):
        self.main_record = main_record
        # A fresh list per instance: a `[]` default would be shared by every
        # EHRKit object, and the update/add methods below mutate it in place.
        self.supporting_records = self._copy_records(supporting_records)
        self.scispacy_model = scispacy_model
        self.bert_model = bert_model
        self.marian_model = marian_model

    @staticmethod
    def _copy_records(records):
        """Return a new list holding `records` (an empty list when `records` is None)."""
        return [] if records is None else list(records)

    @staticmethod
    def _unsupported_tool(tool, supported):
        """Build the ValueError raised when a method is asked for a tool it does not implement."""
        return ValueError(
            "Unsupported tool {!r}; expected one of: {}".format(tool, ", ".join(supported))
        )

    # ---- Functions for manipulating records and default models ----

    def update_and_delete_main_record(self, main_record):
        """
        update current main_record, current main_record will be deleted

        Raises:
            TypeError: if main_record is falsy (e.g. an empty string or None)
        """
        if not main_record:
            raise TypeError('Invalid type for main_record')
        self.main_record = main_record

    def update_and_keep_main_record(self, main_record):
        """
        update current main_record, current main_record will be added to the end of supporting_records

        Raises:
            TypeError: if main_record is falsy (e.g. an empty string or None)
        """
        if not main_record:
            raise TypeError('Invalid type for main_record')
        self.supporting_records.append(self.main_record)
        self.main_record = main_record

    def replace_supporting_records(self, supporting_records):
        """
        replace current supporting records

        The given records are copied, so later calls that append to or extend
        this object's supporting records do not modify the caller's list.
        """
        self.supporting_records = self._copy_records(supporting_records)

    def add_supporting_records(self, supporting_records):
        """
        add additional supporting records to existing supporting records
        """
        self.supporting_records.extend(supporting_records)

    def update_scispacy_model(self, scispacy_model):
        """Replace the default scispacy model name."""
        self.scispacy_model = scispacy_model

    def update_bert_model(self, bert_model):
        """Replace the default transformer model name."""
        self.bert_model = bert_model

    def update_marian_model(self, marian_model):
        """Replace the default translation model name."""
        self.marian_model = marian_model

    # ---- Functions for textual record processing ----
    # NOTE: inside the methods below, names such as `get_abbreviations` resolve
    # to the module-level imported helpers, not to the same-named methods.

    def get_abbreviations(self):
        """Return the result of the imported `get_abbreviations` helper for the main record, using the default scispacy model."""
        return get_abbreviations(self.scispacy_model, self.main_record)

    def get_hyponyms(self):
        """Return the result of the imported `get_hyponyms` helper for the main record, using the default scispacy model."""
        return get_hyponyms(self.scispacy_model, self.main_record)

    def get_linked_entities(self):
        """Return the result of the imported `get_linked_entities` helper for the main record, using the default scispacy model."""
        return get_linked_entities(self.scispacy_model, self.main_record)

    def get_named_entities(self, tool='scispacy'):
        """
        Return named entities found in the main record.

        Args:
            tool (str): 'scispacy' (uses the default scispacy model) or 'stanza'

        Raises:
            ValueError: if tool is not one of the supported values
        """
        if tool == 'scispacy':
            return get_named_entities(self.scispacy_model, self.main_record)
        if tool == 'stanza':
            return get_named_entities_stanza_biomed(self.main_record)
        raise self._unsupported_tool(tool, ('scispacy', 'stanza'))

    def get_translation(self, target_language='Spanish'):
        """
        Translate the main record with the default translation model.

        Args:
            target_language (str): language name forwarded to the imported `get_translation` helper
        """
        return get_translation(self.main_record, self.marian_model, target_language)

    def get_supported_translation_languages(self):
        """Return the result of the imported `get_supported_translation_languages` helper."""
        return get_supported_translation_languages()

    def get_sentences(self, tool='stanza'):
        """
        Split the main record into sentences.

        Args:
            tool (str): 'pyrush', 'stanza', 'scispacy' or 'stanza-biomed'

        Raises:
            ValueError: if tool is not one of the supported values
        """
        if tool == 'pyrush':
            # The pyrush helper yields spans exposing `begin`/`end` offsets;
            # slice the record to turn each span into its text.
            spans = get_sents_pyrush(self.main_record)
            return [self.main_record[span.begin:span.end] for span in spans]
        if tool == 'stanza':
            return get_sents_stanza(self.main_record)
        if tool == 'scispacy':
            return get_sents_scispacy(self.main_record)
        if tool == 'stanza-biomed':
            return get_sents_stanza_biomed(self.main_record)
        raise self._unsupported_tool(tool, ('pyrush', 'stanza', 'scispacy', 'stanza-biomed'))

    def get_tokens(self, tool='stanza-biomed'):
        """
        Tokenize the main record.

        Args:
            tool (str): only 'stanza-biomed' is supported

        Raises:
            ValueError: if tool is not supported
        """
        if tool == 'stanza-biomed':
            return get_tokens_stanza_biomed(self.main_record)
        raise self._unsupported_tool(tool, ('stanza-biomed',))

    def get_pos_tags(self, tool='stanza-biomed'):
        """
        Return the result of `get_part_of_speech_and_morphological_features` for the main record.

        Args:
            tool (str): only 'stanza-biomed' is supported

        Raises:
            ValueError: if tool is not supported
        """
        if tool == 'stanza-biomed':
            return get_part_of_speech_and_morphological_features(self.main_record)
        raise self._unsupported_tool(tool, ('stanza-biomed',))

    def get_lemmas(self, tool='stanza-biomed'):
        """
        Return the result of `get_lemmas_stanza_biomed` for the main record.

        Args:
            tool (str): only 'stanza-biomed' is supported

        Raises:
            ValueError: if tool is not supported
        """
        if tool == 'stanza-biomed':
            return get_lemmas_stanza_biomed(self.main_record)
        raise self._unsupported_tool(tool, ('stanza-biomed',))

    def get_dependency(self, tool='stanza-biomed'):
        """
        Return the result of `get_dependency_stanza_biomed` for the main record.

        Args:
            tool (str): only 'stanza-biomed' is supported

        Raises:
            ValueError: if tool is not supported
        """
        if tool == 'stanza-biomed':
            return get_dependency_stanza_biomed(self.main_record)
        raise self._unsupported_tool(tool, ('stanza-biomed',))

    def get_clusters(self, k=2):
        """
        Cluster the main record together with the supporting records.

        Args:
            k: forwarded to the imported `get_clusters` helper (presumably the number of clusters; confirm against that helper)
        """
        # combine main record and candidate records for clustering
        docs = [self.main_record] + self.supporting_records
        return get_clusters(self.bert_model, docs, k)

    def get_similar_documents(self, k=2):
        """
        Query the supporting records for documents similar to the main record.

        Args:
            k: forwarded to the imported `get_similar_documents` helper as `top_k`
        """
        query_note = self.main_record
        candidate_notes = self.supporting_records
        # ids of candidates: positions 0..len-1 in supporting_records
        candidates = np.array(range(len(candidate_notes)))
        return get_similar_documents(self.bert_model, query_note, candidate_notes, candidates, top_k=k)

    def get_single_record_summary(self):
        """Return the result of the imported `get_single_summary` helper for the main record."""
        return get_single_summary(self.main_record)

    def get_multi_record_summary(self):
        """Return the result of `get_multi_summary_joint` for the main record followed by the supporting records."""
        docs = [self.main_record] + self.supporting_records
        return get_multi_summary_joint(docs)