|
b/scripts/utils.py
|
|
import re
import os
import pickle

import spacy
from spacy import displacy

import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences

import nltk
nltk.download('punkt')
nltk.download('stopwords')
from nltk.corpus import stopwords

STOP_WORDS = stopwords.words('english')

# Load the tokenizer from file
with open('../data/tokenizer.pickle', 'rb') as handle:
    tokenizer = pickle.load(handle)
|
|
|
def load_data(data_dir):
    """
    Loads the preprocessed dataset from `data.npz` inside `data_dir`.

    Returns:
    - (train_sequences_padded, train_labels)
    - (val_sequences_padded, val_labels)
    - (test_sequences_padded, test_labels)
    - label_to_index (dict)
    - index_to_label (dict)
    """
    data = np.load(os.path.join(data_dir, 'data.npz'), allow_pickle=True)

    train_sequences_padded = data['train_sequences_padded']
    train_labels = data['train_labels']

    val_sequences_padded = data['val_sequences_padded']
    val_labels = data['val_labels']

    test_sequences_padded = data['test_sequences_padded']
    test_labels = data['test_labels']

    # use .item() to convert the 0-d numpy object arrays back to Python dictionaries
    label_to_index = data['label_to_index'].item()
    index_to_label = data['index_to_label'].item()

    return (train_sequences_padded, train_labels), \
           (val_sequences_padded, val_labels), \
           (test_sequences_padded, test_labels), \
           label_to_index, index_to_label
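# Example usage (a sketch; assumes data.npz was written by the preprocessing step
# into ../data):
#
#   (train_X, train_y), (val_X, val_y), (test_X, test_y), label_to_index, index_to_label = \
#       load_data('../data')
#   print(train_X.shape, len(label_to_index))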
|
|
|
def clean_word(word):
    """
    Cleans a word by removing non-alphanumeric characters and extra whitespaces,
    converting it to lowercase, and checking if it is a stopword.

    Args:
    - word (str): the word to clean

    Returns:
    - str: the cleaned word, or an empty string if it is a stopword
    """
    # remove non-alphanumeric characters and extra whitespaces
    word = re.sub(r'[^\w\s]', '', word)
    word = re.sub(r'\s+', ' ', word)

    # convert to lowercase
    word = word.lower()

    if word not in STOP_WORDS:
        return word

    return ''
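# For example (using the NLTK English stopword list loaded above):
#   clean_word("Hello,")  ->  "hello"
#   clean_word("The")     ->  ""        (stopword)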
|
|
|
def tokenize_text(text):
    """
    Tokenizes a text into lists of cleaned words, one list per sentence.

    Args:
    - text (str): the text to tokenize

    Returns:
    - tokens (list of list of str): the cleaned words of each sentence
    - start_end_ranges (list of list of tuples): the start and end character positions for each token
    """
    # split on whitespace and on the various Unicode hyphen/dash characters
    regex_match = r'[^\s\u200a\-\u2010-\u2015\u2212\uff0d]+'
    tokens = []
    start_end_ranges = []

    # Tokenize the sentences in the text
    sentences = nltk.sent_tokenize(text)

    start = 0
    for sentence in sentences:

        sentence_tokens = re.findall(regex_match, sentence)
        curr_sent_tokens = []
        curr_sent_ranges = []

        for word in sentence_tokens:
            word = clean_word(word)
            if word.strip():
                # locate the cleaned word in the lowercased text to recover its offsets
                idx = text.lower().find(word, start)
                if idx == -1:
                    # cleaning removed characters inside the word, so it no longer
                    # appears verbatim in the text; skip it rather than emit a bogus range
                    continue
                start = idx
                end = start + len(word)
                curr_sent_ranges.append((start, end))
                curr_sent_tokens.append(word)
                start = end
        if len(curr_sent_tokens) > 0:
            tokens.append(curr_sent_tokens)
            start_end_ranges.append(curr_sent_ranges)

    return tokens, start_end_ranges
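# Example (a sketch; exact output depends on NLTK's sentence splitter and stopword list):
#
#   tokens, ranges = tokenize_text("Patient denies chest pain. No fever.")
#   # tokens -> [['patient', 'denies', 'chest', 'pain'], ['fever']]
#   # ranges -> matching per-sentence lists of (start, end) offsets into the original text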
|
|
|
def predict(text, model, index_to_label, acronyms_to_entities, MAX_LENGTH):
    """
    Predicts named entities in a text using a trained NER model.

    Uses the module-level `tokenizer` loaded above to convert tokens into sequences.

    Args:
    - text (str): the text to predict named entities in
    - model: the trained NER model
    - index_to_label (list of str): a list mapping each index in the predicted sequence to a named entity label
    - acronyms_to_entities (dict): a dictionary mapping acronyms to their corresponding named entity labels
    - MAX_LENGTH (int): the maximum sequence length for the model

    Returns:
    - None
    """
    tokens, start_end_ranges = tokenize_text(text)

    # Flatten the per-sentence tokens and character ranges
    all_tokens = []
    all_ranges = []
    for sent_tokens, sent_ranges in zip(tokens, start_end_ranges):
        for token, (start, end) in zip(sent_tokens, sent_ranges):
            all_tokens.append(token)
            all_ranges.append((start, end))

    sequence = tokenizer.texts_to_sequences([' '.join(all_tokens)])
    padded_sequence = pad_sequences(sequence, maxlen=MAX_LENGTH, padding='post')

    # Make the prediction
    prediction = model.predict(padded_sequence)

    # Decode the prediction
    predicted_labels = np.argmax(prediction, axis=-1)
    predicted_labels = [index_to_label[i] for i in predicted_labels[0]]

    # Collect (start, end, entity_type) spans for every non-'O' prediction
    entities = []
    for label, (start, end) in zip(predicted_labels, all_ranges):
        if label != 'O':
            entity_type = acronyms_to_entities[label[2:]]
            entities.append((start, end, entity_type))

    # Print the predicted named entities
    print("Predicted Named Entities:")
    for token, label in zip(all_tokens, predicted_labels):
        if label == 'O':
            print(f"{token}: {label}")
        else:
            print(f"{token}: {acronyms_to_entities[label[2:]]}")

    display_pred(text, entities)
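# Example call (a sketch; ner_model, acronyms_to_entities and MAX_LENGTH come from
# the training pipeline and are not defined in this module):
#
#   predict("Patient reports severe chest pain.",
#           ner_model, index_to_label, acronyms_to_entities, MAX_LENGTH=128)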
|
|
|
def display_pred(text, entities):
    """
    Renders the predicted entity spans over the original text with displaCy.

    Args:
    - text (str): the original text
    - entities (list of tuples): (start_char, end_char, label) spans

    Returns:
    - str: the rendered HTML markup
    """
    nlp = spacy.load("en_core_web_sm", disable=['ner'])
    # Parse the text with spaCy (its own NER is disabled so only our predictions are shown)
    doc = nlp(text)
    # Add the predicted named entities to the Doc object
    for start, end, label in entities:
        span = doc.char_span(start, end, label=label)
        if span is not None:
            doc.ents = list(doc.ents) + [span]

    colors = {"Activity": "#f9d5e5",
              "Administration": "#f7a399",
              "Age": "#f6c3d0",
              "Area": "#fde2e4",
              "Biological_attribute": "#d5f5e3",
              "Biological_structure": "#9ddfd3",
              "Clinical_event": "#77c5d5",
              "Color": "#a0ced9",
              "Coreference": "#e3b5a4",
              "Date": "#f1f0d2",
              "Detailed_description": "#ffb347",
              "Diagnostic_procedure": "#c5b4e3",
              "Disease_disorder": "#c4b7ea",
              "Distance": "#bde0fe",
              "Dosage": "#b9e8d8",
              "Duration": "#ffdfba",
              "Family_history": "#e6ccb2",
              "Frequency": "#e9d8a6",
              "Height": "#f2eecb",
              "History": "#e2f0cb",
              "Lab_value": "#f4b3c2",
              "Mass": "#f4c4c3",
              "Medication": "#f9d5e5",
              "Nonbiological_location": "#f7a399",
              "Occupation": "#f6c3d0",
              "Other_entity": "#d5f5e3",
              "Other_event": "#9ddfd3",
              "Outcome": "#77c5d5",
              "Personal_background": "#a0ced9",
              "Qualitative_concept": "#e3b5a4",
              "Quantitative_concept": "#f1f0d2",
              "Severity": "#ffb347",
              "Sex": "#c5b4e3",
              "Shape": "#c4b7ea",
              "Sign_symptom": "#bde0fe",
              "Subject": "#b9e8d8",
              "Texture": "#ffdfba",
              "Therapeutic_procedure": "#e6ccb2",
              "Time": "#e9d8a6",
              "Volume": "#f2eecb",
              "Weight": "#e2f0cb"}
    options = {"compact": True, "bg": "#F8F8F8",
               "ents": list(colors.keys()),
               "colors": colors}

    # Generate the HTML visualization
    html = displacy.render(doc, style="ent", options=options)
    return html
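# display_pred returns the rendered markup rather than displaying it; in a notebook
# you could instead pass jupyter=True to displacy.render. A minimal sketch for saving
# the output (the file name is arbitrary):
#
#   html = display_pred(text, entities)
#   with open('ner_prediction.html', 'w', encoding='utf-8') as f:
#       f.write(html)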
|
|
|
def predict_multi_line_text(text, model, index_to_label, acronyms_to_entities, MAX_LENGTH):
    """
    Predicts named entities in a multi-line text using a trained NER model.

    Unlike `predict`, each sentence is converted into its own padded sequence,
    so long documents are not truncated to a single MAX_LENGTH window.
    """
    sent_tokens, sent_start_end = tokenize_text(text)

    # Build one sequence per sentence
    sequences = []
    for i in range(len(sent_tokens)):
        sequence = tokenizer.texts_to_sequences([' '.join(sent_tokens[i])])
        sequences.extend(sequence)

    padded_sequence = pad_sequences(sequences, maxlen=MAX_LENGTH, padding='post')

    # Make the prediction
    prediction = model.predict(padded_sequence)

    # Decode the prediction
    predicted_labels = np.argmax(prediction, axis=-1)
    predicted_labels = [
        [index_to_label[i] for i in sent_predicted_labels]
        for sent_predicted_labels in predicted_labels
    ]

    # Collect (start, end, entity_type) spans for every non-'O' prediction
    entities = []
    for sent_pred_labels, start_end_ranges in zip(predicted_labels, sent_start_end):
        for label, (start, end) in zip(sent_pred_labels, start_end_ranges):
            if label != 'O':
                entity_type = acronyms_to_entities[label[2:]]
                entities.append((start, end, entity_type))

    # Print the predicted named entities, one block per sentence
    print("Predicted Named Entities:")
    for i in range(len(sent_tokens)):
        for j in range(len(sent_tokens[i])):
            if predicted_labels[i][j] == 'O':
                print(f"{sent_tokens[i][j]}: {predicted_labels[i][j]}")
            else:
                print(f"{sent_tokens[i][j]}: {acronyms_to_entities[predicted_labels[i][j][2:]]}")
        print("\n\n\n")

    display_pred(text, entities)
    # return entities
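# Example call (a sketch; the model and label mappings come from the training
# pipeline and are not defined in this module):
#
#   report = "Patient admitted with fever.\nCT scan showed no abnormality."
#   predict_multi_line_text(report, ner_model, index_to_label,
#                           acronyms_to_entities, MAX_LENGTH=128)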